diff options
Diffstat (limited to 'matlab/mex')
-rw-r--r-- | matlab/mex/astra_mex_c.cpp | 51 | ||||
-rw-r--r-- | matlab/mex/astra_mex_plugin_c.cpp | 86 | ||||
-rw-r--r-- | matlab/mex/mexInitFunctions.cpp | 8 |
3 files changed, 135 insertions, 10 deletions
diff --git a/matlab/mex/astra_mex_c.cpp b/matlab/mex/astra_mex_c.cpp index fdf4f33..f499528 100644 --- a/matlab/mex/astra_mex_c.cpp +++ b/matlab/mex/astra_mex_c.cpp @@ -36,10 +36,14 @@ $Id$ #include "mexInitFunctions.h" #include "astra/Globals.h" +#include "astra/AstraObjectManager.h" + #ifdef ASTRA_CUDA #include "../cuda/2d/darthelper.h" #include "astra/CompositeGeometryManager.h" #endif + + using namespace std; using namespace astra; @@ -144,10 +148,51 @@ void astra_mex_version(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[ //----------------------------------------------------------------------------------------- +void astra_mex_info(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) +{ + if (nrhs < 2) { + mexErrMsgTxt("Usage: astra_mex('info', index/indices);\n"); + return; + } + + for (int i = 1; i < nrhs; i++) { + int iDataID = (int)(mxGetScalar(prhs[i])); + CAstraObjectManagerBase *ptr; + ptr = CAstraIndexManager::getSingleton().get(iDataID); + if (ptr) { + mexPrintf("%s\t%s\n", ptr->getType().c_str(), ptr->getInfo(iDataID).c_str()); + } + } + +} + +//----------------------------------------------------------------------------------------- + +void astra_mex_delete(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) +{ + if (nrhs < 2) { + mexErrMsgTxt("Usage: astra_mex('delete', index/indices);\n"); + return; + } + + for (int i = 1; i < nrhs; i++) { + int iDataID = (int)(mxGetScalar(prhs[i])); + CAstraObjectManagerBase *ptr; + ptr = CAstraIndexManager::getSingleton().get(iDataID); + if (ptr) + ptr->remove(iDataID); + } + +} + + + +//----------------------------------------------------------------------------------------- + static void printHelp() { mexPrintf("Please specify a mode of operation.\n"); - mexPrintf(" Valid modes: version, use_cuda, credits\n"); + mexPrintf(" Valid modes: version, use_cuda, credits, set_gpu_index, info, delete\n"); } //----------------------------------------------------------------------------------------- @@ -178,6 +223,10 @@ void mexFunction(int nlhs, mxArray* plhs[], astra_mex_credits(nlhs, plhs, nrhs, prhs); } else if (sMode == std::string("set_gpu_index")) { astra_mex_set_gpu_index(nlhs, plhs, nrhs, prhs); + } else if (sMode == std::string("info")) { + astra_mex_info(nlhs, plhs, nrhs, prhs); + } else if (sMode == std::string("delete")) { + astra_mex_delete(nlhs, plhs, nrhs, prhs); } else { printHelp(); } diff --git a/matlab/mex/astra_mex_plugin_c.cpp b/matlab/mex/astra_mex_plugin_c.cpp index 177fcf4..4ed534e 100644 --- a/matlab/mex/astra_mex_plugin_c.cpp +++ b/matlab/mex/astra_mex_plugin_c.cpp @@ -37,9 +37,63 @@ $Id$ #include "astra/PluginAlgorithm.h" +#include <Python.h> + using namespace std; using namespace astra; +static void fixLapackLoading() +{ + // When running in Matlab, we need to force numpy + // to use its internal lapack library instead of + // Matlab's MKL library to avoid errors. To do this, + // we set Python's dlopen flags to RTLD_NOW|RTLD_DEEPBIND + // and import 'numpy.linalg.lapack_lite' here. We reset + // Python's dlopen flags afterwards. + PyObject *sys = PyImport_ImportModule("sys"); + if (sys != NULL) { + PyObject *curFlags = PyObject_CallMethod(sys, "getdlopenflags", NULL); + if (curFlags != NULL) { + PyObject *retVal = PyObject_CallMethod(sys, "setdlopenflags", "i", 10); // RTLD_NOW|RTLD_DEEPBIND + if (retVal != NULL) { + PyObject *lapack = PyImport_ImportModule("numpy.linalg.lapack_lite"); + if (lapack != NULL) { + Py_DECREF(lapack); + } + PyObject *retVal2 = PyObject_CallMethod(sys, "setdlopenflags", "O",curFlags); + if (retVal2 != NULL) { + Py_DECREF(retVal2); + } + Py_DECREF(retVal); + } + Py_DECREF(curFlags); + } + Py_DECREF(sys); + } +} + +//----------------------------------------------------------------------------------------- +/** astra_mex_plugin('init'); + * + * Initialize plugin support by initializing python and importing astra + */ +void astra_mex_plugin_init() +{ + if(!Py_IsInitialized()){ + Py_Initialize(); + PyEval_InitThreads(); + } + +#ifndef _MSC_VER + fixLapackLoading(); +#endif + + // Importing astra may be overkill, since we only need to initialize + // PythonPluginAlgorithmFactory from astra.plugin_c. + PyObject *mod = PyImport_ImportModule("astra"); + Py_XDECREF(mod); +} + //----------------------------------------------------------------------------------------- /** astra_mex_plugin('get_registered'); @@ -48,7 +102,11 @@ using namespace astra; */ void astra_mex_plugin_get_registered(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { - astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); + astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getFactory(); + if (!fact) { + mexPrintf("Plugin support not initialized."); + return; + } std::map<std::string, std::string> mp = fact->getRegisteredMap(); for(std::map<std::string,std::string>::iterator it=mp.begin();it!=mp.end();it++){ mexPrintf("%s: %s\n",it->first.c_str(), it->second.c_str()); @@ -62,9 +120,13 @@ void astra_mex_plugin_get_registered(int nlhs, mxArray* plhs[], int nrhs, const */ void astra_mex_plugin_register(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { + astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getFactory(); + if (!fact) { + mexPrintf("Plugin support not initialized."); + return; + } if (2 <= nrhs) { string class_name = mexToString(prhs[1]); - astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); fact->registerPlugin(class_name); }else{ mexPrintf("astra_mex_plugin('register', class_name);\n"); @@ -78,9 +140,13 @@ void astra_mex_plugin_register(int nlhs, mxArray* plhs[], int nrhs, const mxArra */ void astra_mex_plugin_get_help(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { + astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getFactory(); + if (!fact) { + mexPrintf("Plugin support not initialized."); + return; + } if (2 <= nrhs) { string name = mexToString(prhs[1]); - astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr(); mexPrintf((fact->getHelp(name)+"\n").c_str()); }else{ mexPrintf("astra_mex_plugin('get_help', name);\n"); @@ -116,12 +182,14 @@ void mexFunction(int nlhs, mxArray* plhs[], initASTRAMex(); // SWITCH (MODE) - if (sMode == std::string("get_registered")) { - astra_mex_plugin_get_registered(nlhs, plhs, nrhs, prhs); - }else if (sMode == std::string("get_help")) { - astra_mex_plugin_get_help(nlhs, plhs, nrhs, prhs); - }else if (sMode == std::string("register")) { - astra_mex_plugin_register(nlhs, plhs, nrhs, prhs); + if (sMode == "init") { + astra_mex_plugin_init(); + } else if (sMode == std::string("get_registered")) { + astra_mex_plugin_get_registered(nlhs, plhs, nrhs, prhs); + }else if (sMode == std::string("get_help")) { + astra_mex_plugin_get_help(nlhs, plhs, nrhs, prhs); + }else if (sMode == std::string("register")) { + astra_mex_plugin_register(nlhs, plhs, nrhs, prhs); } else { printHelp(); } diff --git a/matlab/mex/mexInitFunctions.cpp b/matlab/mex/mexInitFunctions.cpp index bd3df2c..7245af2 100644 --- a/matlab/mex/mexInitFunctions.cpp +++ b/matlab/mex/mexInitFunctions.cpp @@ -23,5 +23,13 @@ void initASTRAMex(){ if(!astra::CLogger::setCallbackScreen(&logCallBack)){ mexErrMsgTxt("Error initializing mex functions."); } + mexIsInitialized=true; + + + // If we have support for plugins, initialize them. + // (NB: Call this after setting mexIsInitialized, to avoid recursively + // calling initASTRAMex) + mexEvalString("if exist('astra_mex_plugin_c') == 3; astra_mex_plugin_c('init'); end"); + } |