summaryrefslogtreecommitdiffstats
path: root/matlab
diff options
context:
space:
mode:
Diffstat (limited to 'matlab')
-rw-r--r--matlab/mex/astra_mex_c.cpp51
-rw-r--r--matlab/mex/astra_mex_plugin_c.cpp86
-rw-r--r--matlab/mex/mexInitFunctions.cpp8
-rw-r--r--matlab/tools/opTomo.m81
4 files changed, 163 insertions, 63 deletions
diff --git a/matlab/mex/astra_mex_c.cpp b/matlab/mex/astra_mex_c.cpp
index fdf4f33..f499528 100644
--- a/matlab/mex/astra_mex_c.cpp
+++ b/matlab/mex/astra_mex_c.cpp
@@ -36,10 +36,14 @@ $Id$
#include "mexInitFunctions.h"
#include "astra/Globals.h"
+#include "astra/AstraObjectManager.h"
+
#ifdef ASTRA_CUDA
#include "../cuda/2d/darthelper.h"
#include "astra/CompositeGeometryManager.h"
#endif
+
+
using namespace std;
using namespace astra;
@@ -144,10 +148,51 @@ void astra_mex_version(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[
//-----------------------------------------------------------------------------------------
+void astra_mex_info(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
+{
+ if (nrhs < 2) {
+ mexErrMsgTxt("Usage: astra_mex('info', index/indices);\n");
+ return;
+ }
+
+ for (int i = 1; i < nrhs; i++) {
+ int iDataID = (int)(mxGetScalar(prhs[i]));
+ CAstraObjectManagerBase *ptr;
+ ptr = CAstraIndexManager::getSingleton().get(iDataID);
+ if (ptr) {
+ mexPrintf("%s\t%s\n", ptr->getType().c_str(), ptr->getInfo(iDataID).c_str());
+ }
+ }
+
+}
+
+//-----------------------------------------------------------------------------------------
+
+void astra_mex_delete(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
+{
+ if (nrhs < 2) {
+ mexErrMsgTxt("Usage: astra_mex('delete', index/indices);\n");
+ return;
+ }
+
+ for (int i = 1; i < nrhs; i++) {
+ int iDataID = (int)(mxGetScalar(prhs[i]));
+ CAstraObjectManagerBase *ptr;
+ ptr = CAstraIndexManager::getSingleton().get(iDataID);
+ if (ptr)
+ ptr->remove(iDataID);
+ }
+
+}
+
+
+
+//-----------------------------------------------------------------------------------------
+
static void printHelp()
{
mexPrintf("Please specify a mode of operation.\n");
- mexPrintf(" Valid modes: version, use_cuda, credits\n");
+ mexPrintf(" Valid modes: version, use_cuda, credits, set_gpu_index, info, delete\n");
}
//-----------------------------------------------------------------------------------------
@@ -178,6 +223,10 @@ void mexFunction(int nlhs, mxArray* plhs[],
astra_mex_credits(nlhs, plhs, nrhs, prhs);
} else if (sMode == std::string("set_gpu_index")) {
astra_mex_set_gpu_index(nlhs, plhs, nrhs, prhs);
+ } else if (sMode == std::string("info")) {
+ astra_mex_info(nlhs, plhs, nrhs, prhs);
+ } else if (sMode == std::string("delete")) {
+ astra_mex_delete(nlhs, plhs, nrhs, prhs);
} else {
printHelp();
}
diff --git a/matlab/mex/astra_mex_plugin_c.cpp b/matlab/mex/astra_mex_plugin_c.cpp
index 177fcf4..4ed534e 100644
--- a/matlab/mex/astra_mex_plugin_c.cpp
+++ b/matlab/mex/astra_mex_plugin_c.cpp
@@ -37,9 +37,63 @@ $Id$
#include "astra/PluginAlgorithm.h"
+#include <Python.h>
+
using namespace std;
using namespace astra;
+static void fixLapackLoading()
+{
+ // When running in Matlab, we need to force numpy
+ // to use its internal lapack library instead of
+ // Matlab's MKL library to avoid errors. To do this,
+ // we set Python's dlopen flags to RTLD_NOW|RTLD_DEEPBIND
+ // and import 'numpy.linalg.lapack_lite' here. We reset
+ // Python's dlopen flags afterwards.
+ PyObject *sys = PyImport_ImportModule("sys");
+ if (sys != NULL) {
+ PyObject *curFlags = PyObject_CallMethod(sys, "getdlopenflags", NULL);
+ if (curFlags != NULL) {
+ PyObject *retVal = PyObject_CallMethod(sys, "setdlopenflags", "i", 10); // RTLD_NOW|RTLD_DEEPBIND
+ if (retVal != NULL) {
+ PyObject *lapack = PyImport_ImportModule("numpy.linalg.lapack_lite");
+ if (lapack != NULL) {
+ Py_DECREF(lapack);
+ }
+ PyObject *retVal2 = PyObject_CallMethod(sys, "setdlopenflags", "O",curFlags);
+ if (retVal2 != NULL) {
+ Py_DECREF(retVal2);
+ }
+ Py_DECREF(retVal);
+ }
+ Py_DECREF(curFlags);
+ }
+ Py_DECREF(sys);
+ }
+}
+
+//-----------------------------------------------------------------------------------------
+/** astra_mex_plugin('init');
+ *
+ * Initialize plugin support by initializing python and importing astra
+ */
+void astra_mex_plugin_init()
+{
+ if(!Py_IsInitialized()){
+ Py_Initialize();
+ PyEval_InitThreads();
+ }
+
+#ifndef _MSC_VER
+ fixLapackLoading();
+#endif
+
+ // Importing astra may be overkill, since we only need to initialize
+ // PythonPluginAlgorithmFactory from astra.plugin_c.
+ PyObject *mod = PyImport_ImportModule("astra");
+ Py_XDECREF(mod);
+}
+
//-----------------------------------------------------------------------------------------
/** astra_mex_plugin('get_registered');
@@ -48,7 +102,11 @@ using namespace astra;
*/
void astra_mex_plugin_get_registered(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
- astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr();
+ astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getFactory();
+ if (!fact) {
+ mexPrintf("Plugin support not initialized.");
+ return;
+ }
std::map<std::string, std::string> mp = fact->getRegisteredMap();
for(std::map<std::string,std::string>::iterator it=mp.begin();it!=mp.end();it++){
mexPrintf("%s: %s\n",it->first.c_str(), it->second.c_str());
@@ -62,9 +120,13 @@ void astra_mex_plugin_get_registered(int nlhs, mxArray* plhs[], int nrhs, const
*/
void astra_mex_plugin_register(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
+ astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getFactory();
+ if (!fact) {
+ mexPrintf("Plugin support not initialized.");
+ return;
+ }
if (2 <= nrhs) {
string class_name = mexToString(prhs[1]);
- astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr();
fact->registerPlugin(class_name);
}else{
mexPrintf("astra_mex_plugin('register', class_name);\n");
@@ -78,9 +140,13 @@ void astra_mex_plugin_register(int nlhs, mxArray* plhs[], int nrhs, const mxArra
*/
void astra_mex_plugin_get_help(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
+ astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getFactory();
+ if (!fact) {
+ mexPrintf("Plugin support not initialized.");
+ return;
+ }
if (2 <= nrhs) {
string name = mexToString(prhs[1]);
- astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr();
mexPrintf((fact->getHelp(name)+"\n").c_str());
}else{
mexPrintf("astra_mex_plugin('get_help', name);\n");
@@ -116,12 +182,14 @@ void mexFunction(int nlhs, mxArray* plhs[],
initASTRAMex();
// SWITCH (MODE)
- if (sMode == std::string("get_registered")) {
- astra_mex_plugin_get_registered(nlhs, plhs, nrhs, prhs);
- }else if (sMode == std::string("get_help")) {
- astra_mex_plugin_get_help(nlhs, plhs, nrhs, prhs);
- }else if (sMode == std::string("register")) {
- astra_mex_plugin_register(nlhs, plhs, nrhs, prhs);
+ if (sMode == "init") {
+ astra_mex_plugin_init();
+ } else if (sMode == std::string("get_registered")) {
+ astra_mex_plugin_get_registered(nlhs, plhs, nrhs, prhs);
+ }else if (sMode == std::string("get_help")) {
+ astra_mex_plugin_get_help(nlhs, plhs, nrhs, prhs);
+ }else if (sMode == std::string("register")) {
+ astra_mex_plugin_register(nlhs, plhs, nrhs, prhs);
} else {
printHelp();
}
diff --git a/matlab/mex/mexInitFunctions.cpp b/matlab/mex/mexInitFunctions.cpp
index bd3df2c..7245af2 100644
--- a/matlab/mex/mexInitFunctions.cpp
+++ b/matlab/mex/mexInitFunctions.cpp
@@ -23,5 +23,13 @@ void initASTRAMex(){
if(!astra::CLogger::setCallbackScreen(&logCallBack)){
mexErrMsgTxt("Error initializing mex functions.");
}
+
mexIsInitialized=true;
+
+
+ // If we have support for plugins, initialize them.
+ // (NB: Call this after setting mexIsInitialized, to avoid recursively
+ // calling initASTRAMex)
+ mexEvalString("if exist('astra_mex_plugin_c') == 3; astra_mex_plugin_c('init'); end");
+
}
diff --git a/matlab/tools/opTomo.m b/matlab/tools/opTomo.m
index 71dfb1e..04b3634 100644
--- a/matlab/tools/opTomo.m
+++ b/matlab/tools/opTomo.m
@@ -44,11 +44,9 @@ classdef opTomo < opSpot
vol_id
fp_alg_id
bp_alg_id
+ proj_id
% ASTRA IDs handle
astra_handle
- % geometries
- proj_geom;
- vol_geom;
end % properties
properties ( SetAccess = private, GetAccess = public )
@@ -139,6 +137,17 @@ classdef opTomo < opSpot
error(['Only type ' 39 'cuda' 39 ' is supported ' ...
'for 3D geometries.'])
end
+
+ % setup projector
+ cfg = astra_struct('cuda3d');
+ cfg.ProjectionGeometry = proj_geom;
+ cfg.VolumeGeometry = vol_geom;
+ cfg.option.GPUindex = gpu_index;
+
+ % create projector
+ op.proj_id = astra_mex_projector3d('create', cfg);
+ % create handle to ASTRA object, for cleaning up
+ op.astra_handle = opTomo_helper_handle(op.proj_id);
% create a function handle
op.funHandle = @opTomo_intrnl3D;
@@ -148,8 +157,6 @@ classdef opTomo < opSpot
% pass object properties
op.proj_size = proj_size;
op.vol_size = vol_size;
- op.proj_geom = proj_geom;
- op.vol_geom = vol_geom;
op.cflag = false;
op.sweepflag = false;
@@ -169,10 +176,12 @@ classdef opTomo < opSpot
if issparse(x)
x = full(x);
end
-
- % convert input to single
- if isa(x, 'single') == false
+
+ if isa(x, 'double')
+ isdouble = true;
x = single(x);
+ else
+ isdouble = false;
end
% the multiplication
@@ -180,6 +189,10 @@ classdef opTomo < opSpot
% make sure output is column vector
y = y(:);
+
+ if isdouble
+ y = double(y);
+ end
end % multiply
@@ -194,7 +207,7 @@ classdef opTomo < opSpot
function y = opTomo_intrnl2D(op,x,mode)
if mode == 1
- % X is passed as a vector, reshape it into an image.
+ % x is passed as a vector, reshape it into an image.
x = reshape(x, op.vol_size);
% Matlab data copied to ASTRA data
@@ -206,7 +219,7 @@ classdef opTomo < opSpot
% retrieve Matlab array
y = astra_mex_data2d('get_single', op.sino_id);
else
- % X is passed as a vector, reshape it into a sinogram.
+ % x is passed as a vector, reshape it into a sinogram.
x = reshape(x, op.proj_size);
% Matlab data copied to ASTRA data
@@ -218,6 +231,7 @@ classdef opTomo < opSpot
% retrieve Matlab array
y = astra_mex_data2d('get_single', op.vol_id);
end
+
end % opTomo_intrnl2D
@@ -225,55 +239,16 @@ classdef opTomo < opSpot
function y = opTomo_intrnl3D(op,x,mode)
if mode == 1
- % X is passed as a vector, reshape it into an image
+ % x is passed as a vector, reshape it into an image
x = reshape(x, op.vol_size);
- % initialize output
- y = zeros(op.proj_size, 'single');
-
- % link matlab array to ASTRA
- vol_id = astra_mex_data3d_c('link', '-vol', op.vol_geom, x, 0);
- sino_id = astra_mex_data3d_c('link', '-sino', op.proj_geom, y, 1);
-
- % initialize fp algorithm
- cfg = astra_struct('FP3D_CUDA');
- cfg.ProjectionDataId = sino_id;
- cfg.VolumeDataId = vol_id;
-
- alg_id = astra_mex_algorithm('create', cfg);
-
% forward projection
- astra_mex_algorithm('iterate', alg_id);
-
- % cleanup
- astra_mex_data3d('delete', vol_id);
- astra_mex_data3d('delete', sino_id);
- astra_mex_algorithm('delete', alg_id);
+ y = astra_mex_direct('FP3D', op.proj_id, x);
else
- % X is passed as a vector, reshape it into projection data
+ % x is passed as a vector, reshape it into projection data
x = reshape(x, op.proj_size);
-
- % initialize output
- y = zeros(op.vol_size,'single');
-
- % link matlab array to ASTRA
- vol_id = astra_mex_data3d_c('link', '-vol', op.vol_geom, y, 1);
- sino_id = astra_mex_data3d_c('link', '-sino', op.proj_geom, x, 0);
- % initialize bp algorithm
- cfg = astra_struct('BP3D_CUDA');
- cfg.ProjectionDataId = sino_id;
- cfg.ReconstructionDataId = vol_id;
-
- alg_id = astra_mex_algorithm('create', cfg);
-
- % backprojection
- astra_mex_algorithm('iterate', alg_id);
-
- % cleanup
- astra_mex_data3d('delete', vol_id);
- astra_mex_data3d('delete', sino_id);
- astra_mex_algorithm('delete', alg_id);
+ y = astra_mex_direct('BP3D', op.proj_id, x);
end
end % opTomo_intrnl3D