/*
-----------------------------------------------------------------------
Copyright: 2010-2021, imec Vision Lab, University of Antwerp
2014-2021, CWI, Amsterdam
Contact: astra@astra-toolbox.com
Website: http://www.astra-toolbox.com/
This file is part of the ASTRA Toolbox.
The ASTRA Toolbox is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
The ASTRA Toolbox is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>.
-----------------------------------------------------------------------
*/
/** \file astra_mex_plugin_c.cpp
*
* \brief Manages Python plugins.
*/
#include
#include "mexHelpFunctions.h"
#include "mexInitFunctions.h"
#include "astra/PluginAlgorithmFactory.h"
#include
using namespace std;
using namespace astra;
static void fixLapackLoading()
{
// When running in Matlab, we need to force numpy
// to use its internal lapack library instead of
// Matlab's MKL library to avoid errors. To do this,
// we set Python's dlopen flags to RTLD_NOW|RTLD_DEEPBIND
// and import 'numpy.linalg.lapack_lite' here. We reset
// Python's dlopen flags afterwards.
PyObject *sys = PyImport_ImportModule("sys");
if (sys != NULL) {
PyObject *curFlags = PyObject_CallMethod(sys, "getdlopenflags", NULL);
if (curFlags != NULL) {
PyObject *retVal = PyObject_CallMethod(sys, "setdlopenflags", "i", 10); // RTLD_NOW|RTLD_DEEPBIND
if (retVal != NULL) {
PyObject *lapack = PyImport_ImportModule("numpy.linalg.lapack_lite");
if (lapack != NULL) {
Py_DECREF(lapack);
}
PyObject *retVal2 = PyObject_CallMethod(sys, "setdlopenflags", "O",curFlags);
if (retVal2 != NULL) {
Py_DECREF(retVal2);
}
Py_DECREF(retVal);
}
Py_DECREF(curFlags);
}
Py_DECREF(sys);
}
}
//-----------------------------------------------------------------------------------------
/** astra_mex_plugin('init');
*
* Initialize plugin support by initializing python and importing astra
*/
void astra_mex_plugin_init()
{
if(!Py_IsInitialized()){
Py_Initialize();
PyEval_InitThreads();
}
#ifndef _MSC_VER
fixLapackLoading();
#endif
// Importing astra may be overkill, since we only need to initialize
// PythonPluginAlgorithmFactory from astra.plugin_c.
PyObject *mod = PyImport_ImportModule("astra");
Py_XDECREF(mod);
}
//-----------------------------------------------------------------------------------------
/** astra_mex_plugin('get_registered');
 *
 * Print the registered plugins: one "key: value" line per entry of the
 * plugin factory's registered map. Prints a notice instead if plugin
 * support has not been initialized.
 */
void astra_mex_plugin_get_registered(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
	astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getFactory();
	if (!fact) {
		mexPrintf("Plugin support not initialized.\n");
		return;
	}

	const std::map<std::string, std::string> mp = fact->getRegisteredMap();
	for (std::map<std::string, std::string>::const_iterator it = mp.begin(); it != mp.end(); ++it) {
		mexPrintf("%s: %s\n", it->first.c_str(), it->second.c_str());
	}
}
//-----------------------------------------------------------------------------------------
/** astra_mex_plugin('register', class_name);
*
* Register plugin.
*/
void astra_mex_plugin_register(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getFactory();
if (!fact) {
mexPrintf("Plugin support not initialized.");
return;
}
if (2 <= nrhs) {
string class_name = mexToString(prhs[1]);
fact->registerPlugin(class_name);
}else{
mexPrintf("astra_mex_plugin('register', class_name);\n");
}
}
//-----------------------------------------------------------------------------------------
/** astra_mex_plugin('get_help', name);
*
* Get help about plugin.
*/
void astra_mex_plugin_get_help(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getFactory();
if (!fact) {
mexPrintf("Plugin support not initialized.");
return;
}
if (2 <= nrhs) {
string name = mexToString(prhs[1]);
mexPrintf((fact->getHelp(name)+"\n").c_str());
}else{
mexPrintf("astra_mex_plugin('get_help', name);\n");
}
}
//-----------------------------------------------------------------------------------------
/**
 * Print a short usage message listing every mode accepted by
 * astra_mex_plugin. 'init' is listed because the other modes only work
 * after plugin support has been initialized.
 */
static void printHelp()
{
	mexPrintf("Please specify a mode of operation.\n");
	mexPrintf(" Valid modes: init, register, get_registered, get_help\n");
}
//-----------------------------------------------------------------------------------------
/**
* ... = astra_mex(type,...);
*/
void mexFunction(int nlhs, mxArray* plhs[],
int nrhs, const mxArray* prhs[])
{
// INPUT0: Mode
string sMode = "";
if (1 <= nrhs) {
sMode = mexToString(prhs[0]);
} else {
printHelp();
return;
}
initASTRAMex();
// SWITCH (MODE)
if (sMode == "init") {
astra_mex_plugin_init();
} else if (sMode == std::string("get_registered")) {
astra_mex_plugin_get_registered(nlhs, plhs, nrhs, prhs);
}else if (sMode == std::string("get_help")) {
astra_mex_plugin_get_help(nlhs, plhs, nrhs, prhs);
}else if (sMode == std::string("register")) {
astra_mex_plugin_register(nlhs, plhs, nrhs, prhs);
} else {
printHelp();
}
return;
}