summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDaniel M. Pelt <D.M.Pelt@cwi.nl>2015-06-19 22:28:06 +0200
committerWillem Jan Palenstijn <Willem.Jan.Palenstijn@cwi.nl>2015-07-23 11:57:15 +0200
commit18b6d25f7e4f0943b3592f3bb4f6ca5ed9c285d3 (patch)
tree8919012d1c610eaf6b2e8c157082a85fb71137a4
parent9e077994b382b2df63e4b79dd2afebc50366d419 (diff)
downloadastra-18b6d25f7e4f0943b3592f3bb4f6ca5ed9c285d3.tar.gz
astra-18b6d25f7e4f0943b3592f3bb4f6ca5ed9c285d3.tar.bz2
astra-18b6d25f7e4f0943b3592f3bb4f6ca5ed9c285d3.tar.xz
astra-18b6d25f7e4f0943b3592f3bb4f6ca5ed9c285d3.zip
Add support for Python algorithm plugins
-rw-r--r--build/linux/Makefile.in16
-rw-r--r--include/astra/AstraObjectFactory.h13
-rw-r--r--include/astra/PluginAlgorithm.h85
-rw-r--r--matlab/mex/astra_mex_plugin_c.cpp139
-rw-r--r--python/astra/__init__.py1
-rw-r--r--python/astra/plugin.py95
-rw-r--r--python/astra/plugin_c.pyx59
-rw-r--r--python/astra/utils.pyx72
-rw-r--r--python/docSRC/index.rst1
-rw-r--r--python/docSRC/plugins.rst8
-rw-r--r--samples/python/s018_plugin.py138
-rw-r--r--src/PluginAlgorithm.cpp294
12 files changed, 851 insertions, 70 deletions
diff --git a/build/linux/Makefile.in b/build/linux/Makefile.in
index 2d862f2..e209fa7 100644
--- a/build/linux/Makefile.in
+++ b/build/linux/Makefile.in
@@ -50,11 +50,17 @@ LDFLAGS+=-fopenmp
endif
ifeq ($(python),yes)
-PYCPPFLAGS = ${CPPFLAGS}
+PYTHON = @PYTHON@
+PYLIBDIR = $(shell $(PYTHON) -c 'from distutils.sysconfig import get_config_var; import six; six.print_(get_config_var("LIBDIR"))')
+PYINCDIR = $(shell $(PYTHON) -c 'from distutils.sysconfig import get_python_inc; import six; six.print_(get_python_inc())')
+PYLIBVER = `basename $(PYINCDIR)`
+CPPFLAGS += -DASTRA_PYTHON -I$(PYINCDIR)
+PYCPPFLAGS = $(CPPFLAGS)
PYCPPFLAGS += -I../include
-PYLDFLAGS = ${LDFLAGS}
+PYLDFLAGS = $(LDFLAGS)
PYLDFLAGS += -L../build/linux/.libs
-PYTHON = @PYTHON@
+LIBS += -l$(PYLIBVER)
+LDFLAGS += -L$(PYLIBDIR)
endif
BOOST_CPPFLAGS=
@@ -234,6 +240,10 @@ MATLAB_MEX=\
matlab/mex/astra_mex_log_c.$(MEXSUFFIX) \
matlab/mex/astra_mex_data3d_c.$(MEXSUFFIX)
+ifeq ($(python),yes)
+ALL_OBJECTS+=src/PluginAlgorithm.lo
+MATLAB_MEX+=matlab/mex/astra_mex_plugin_c.$(MEXSUFFIX)
+endif
OBJECT_DIRS = src/ tests/ cuda/2d/ cuda/3d/ matlab/mex/ ./
DEPDIRS = $(addsuffix $(DEPDIR),$(OBJECT_DIRS))
diff --git a/include/astra/AstraObjectFactory.h b/include/astra/AstraObjectFactory.h
index 356acf9..325989e 100644
--- a/include/astra/AstraObjectFactory.h
+++ b/include/astra/AstraObjectFactory.h
@@ -40,6 +40,10 @@ $Id$
#include "AlgorithmTypelist.h"
+#ifdef ASTRA_PYTHON
+#include "PluginAlgorithm.h"
+#endif
+
namespace astra {
@@ -147,6 +151,15 @@ T* CAstraObjectFactory<T, TypeList>::create(const Config& _cfg)
*/
class _AstraExport CAlgorithmFactory : public CAstraObjectFactory<CAlgorithm, AlgorithmTypeList> {};
+#ifdef ASTRA_PYTHON
+template <>
+inline CAlgorithm* CAstraObjectFactory<CAlgorithm, AlgorithmTypeList>::findPlugin(std::string _sType)
+ {
+ CPluginAlgorithmFactory *fac = CPluginAlgorithmFactory::getSingletonPtr();
+ return fac->getPlugin(_sType);
+ }
+#endif
+
/**
* Class used to create 2D projectors from a string or a config object
*/
diff --git a/include/astra/PluginAlgorithm.h b/include/astra/PluginAlgorithm.h
new file mode 100644
index 0000000..7d6c64a
--- /dev/null
+++ b/include/astra/PluginAlgorithm.h
@@ -0,0 +1,85 @@
+/*
+-----------------------------------------------------------------------
+Copyright: 2010-2015, iMinds-Vision Lab, University of Antwerp
+ 2014-2015, CWI, Amsterdam
+
+Contact: astra@uantwerpen.be
+Website: http://sf.net/projects/astra-toolbox
+
+This file is part of the ASTRA Toolbox.
+
+
+The ASTRA Toolbox is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+The ASTRA Toolbox is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>.
+
+-----------------------------------------------------------------------
+$Id$
+*/
+
+#ifndef _INC_ASTRA_PLUGINALGORITHM
+#define _INC_ASTRA_PLUGINALGORITHM
+
+#ifdef ASTRA_PYTHON
+
+#include <Python.h>
+#include "bytesobject.h"
+#include "astra/Algorithm.h"
+#include "astra/Singleton.h"
+#include "astra/XMLDocument.h"
+#include "astra/XMLNode.h"
+
+namespace astra {
+class _AstraExport CPluginAlgorithm : public CAlgorithm {
+
+public:
+
+ CPluginAlgorithm(PyObject* pyclass);
+ ~CPluginAlgorithm();
+
+ bool initialize(const Config& _cfg);
+ void run(int _iNrIterations);
+
+private:
+ PyObject * instance;
+
+};
+
+class _AstraExport CPluginAlgorithmFactory : public Singleton<CPluginAlgorithmFactory> {
+
+public:
+
+ CPluginAlgorithmFactory();
+ ~CPluginAlgorithmFactory();
+
+ CPluginAlgorithm * getPlugin(std::string name);
+
+ bool registerPlugin(std::string name, std::string className);
+ bool registerPluginClass(std::string name, PyObject * className);
+
+ PyObject * getRegistered();
+
+ std::string getHelp(std::string name);
+
+private:
+ PyObject * pluginDict;
+ PyObject *ospath, *inspect, *six, *astra;
+ std::vector<std::string> getPluginPathList();
+};
+
+PyObject* XMLNode2dict(XMLNode node);
+
+}
+
+#endif
+
+#endif \ No newline at end of file
diff --git a/matlab/mex/astra_mex_plugin_c.cpp b/matlab/mex/astra_mex_plugin_c.cpp
new file mode 100644
index 0000000..2d9b9a0
--- /dev/null
+++ b/matlab/mex/astra_mex_plugin_c.cpp
@@ -0,0 +1,139 @@
+/*
+-----------------------------------------------------------------------
+Copyright: 2010-2015, iMinds-Vision Lab, University of Antwerp
+ 2014-2015, CWI, Amsterdam
+
+Contact: astra@uantwerpen.be
+Website: http://sf.net/projects/astra-toolbox
+
+This file is part of the ASTRA Toolbox.
+
+
+The ASTRA Toolbox is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+The ASTRA Toolbox is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>.
+
+-----------------------------------------------------------------------
+$Id$
+*/
+
+/** \file astra_mex_plugin_c.cpp
+ *
+ * \brief Manages Python plugins.
+ */
+
+#include <mex.h>
+#include "mexHelpFunctions.h"
+#include "mexInitFunctions.h"
+
+#include "astra/PluginAlgorithm.h"
+
+#include "Python.h"
+#include "bytesobject.h"
+
+using namespace std;
+using namespace astra;
+
+
+//-----------------------------------------------------------------------------------------
+/** astra_mex_plugin('get_registered');
+ *
+ * Print registered plugins.
+ */
+void astra_mex_plugin_get_registered(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
+{
+ astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr();
+ PyObject *dict = fact->getRegistered();
+ PyObject *key, *value;
+ Py_ssize_t pos = 0;
+ while (PyDict_Next(dict, &pos, &key, &value)) {
+ mexPrintf("%s: %s\n",PyBytes_AsString(key),PyBytes_AsString(value));
+ }
+ Py_DECREF(dict);
+}
+
+//-----------------------------------------------------------------------------------------
+/** astra_mex_plugin('register', name, class_name);
+ *
+ * Register plugin.
+ */
+void astra_mex_plugin_register(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
+{
+ if (3 <= nrhs) {
+ string name = mexToString(prhs[1]);
+ string class_name = mexToString(prhs[2]);
+ astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr();
+ fact->registerPlugin(name, class_name);
+ }else{
+ mexPrintf("astra_mex_plugin('register', name, class_name);\n");
+ }
+}
+
+//-----------------------------------------------------------------------------------------
+/** astra_mex_plugin('get_help', name);
+ *
+ * Get help about plugin.
+ */
+void astra_mex_plugin_get_help(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
+{
+ if (2 <= nrhs) {
+ string name = mexToString(prhs[1]);
+ astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr();
+ mexPrintf((fact->getHelp(name)+"\n").c_str());
+ }else{
+ mexPrintf("astra_mex_plugin('get_help', name);\n");
+ }
+}
+
+
+//-----------------------------------------------------------------------------------------
+
+static void printHelp()
+{
+ mexPrintf("Please specify a mode of operation.\n");
+ mexPrintf(" Valid modes: register, get_registered, get_help\n");
+}
+
+//-----------------------------------------------------------------------------------------
+/**
+ * ... = astra_mex(type,...);
+ */
+void mexFunction(int nlhs, mxArray* plhs[],
+ int nrhs, const mxArray* prhs[])
+{
+
+ // INPUT0: Mode
+ string sMode = "";
+ if (1 <= nrhs) {
+ sMode = mexToString(prhs[0]);
+ } else {
+ printHelp();
+ return;
+ }
+
+ initASTRAMex();
+
+ // SWITCH (MODE)
+ if (sMode == std::string("get_registered")) {
+ astra_mex_plugin_get_registered(nlhs, plhs, nrhs, prhs);
+ }else if (sMode == std::string("get_help")) {
+ astra_mex_plugin_get_help(nlhs, plhs, nrhs, prhs);
+ }else if (sMode == std::string("register")) {
+ astra_mex_plugin_register(nlhs, plhs, nrhs, prhs);
+ } else {
+ printHelp();
+ }
+
+ return;
+}
+
+
diff --git a/python/astra/__init__.py b/python/astra/__init__.py
index 6c15d30..10ed74d 100644
--- a/python/astra/__init__.py
+++ b/python/astra/__init__.py
@@ -34,6 +34,7 @@ from . import algorithm
from . import projector
from . import projector3d
from . import matrix
+from . import plugin
from . import log
from .optomo import OpTomo
diff --git a/python/astra/plugin.py b/python/astra/plugin.py
new file mode 100644
index 0000000..ccdb2cb
--- /dev/null
+++ b/python/astra/plugin.py
@@ -0,0 +1,95 @@
+#-----------------------------------------------------------------------
+#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam
+#
+#Author: Daniel M. Pelt
+#Contact: D.M.Pelt@cwi.nl
+#Website: http://dmpelt.github.io/pyastratoolbox/
+#
+#
+#This file is part of the Python interface to the
+#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox").
+#
+#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify
+#it under the terms of the GNU General Public License as published by
+#the Free Software Foundation, either version 3 of the License, or
+#(at your option) any later version.
+#
+#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful,
+#but WITHOUT ANY WARRANTY; without even the implied warranty of
+#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+#GNU General Public License for more details.
+#
+#You should have received a copy of the GNU General Public License
+#along with the Python interface to the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>.
+#
+#-----------------------------------------------------------------------
+
+from . import plugin_c as p
+from . import log
+
+class base(object):
+
+ def astra_init(self, cfg):
+ try:
+ try:
+ req = self.required_options
+ except AttributeError:
+ log.warn("Plugin '" + self.__class__.__name__ + "' does not specify required options")
+ req = {}
+
+ try:
+ opt = self.optional_options
+ except AttributeError:
+ log.warn("Plugin '" + self.__class__.__name__ + "' does not specify optional options")
+ opt = {}
+
+ try:
+ optDict = cfg['options']
+ except KeyError:
+ optDict = {}
+
+ cfgKeys = set(optDict.keys())
+ reqKeys = set(req)
+ optKeys = set(opt)
+
+ if not reqKeys.issubset(cfgKeys):
+ for key in reqKeys.difference(cfgKeys):
+ log.error("Required option '" + key + "' for plugin '" + self.__class__.__name__ + "' not specified")
+ raise ValueError("Missing required options")
+
+ if not cfgKeys.issubset(reqKeys | optKeys):
+ log.warn(self.__class__.__name__ + ": unused configuration option: " + str(list(cfgKeys.difference(reqKeys | optKeys))))
+
+ self.initialize(cfg)
+ except Exception as e:
+ log.error(str(e))
+ raise
+
+def register(name, className):
+ """Register plugin with ASTRA.
+
+ :param name: Plugin name to register
+ :type name: :class:`str`
+ :param className: Class name or class object to register
+ :type className: :class:`str` or :class:`class`
+
+ """
+ p.register(name,className)
+
+def get_registered():
+ """Get dictionary of registered plugins.
+
+ :returns: :class:`dict` -- Registered plugins.
+
+ """
+ return p.get_registered()
+
+def get_help(name):
+ """Get help for registered plugin.
+
+ :param name: Plugin name to get help for
+ :type name: :class:`str`
+ :returns: :class:`str` -- Help string (docstring).
+
+ """
+ return p.get_help(name) \ No newline at end of file
diff --git a/python/astra/plugin_c.pyx b/python/astra/plugin_c.pyx
new file mode 100644
index 0000000..91b3cd5
--- /dev/null
+++ b/python/astra/plugin_c.pyx
@@ -0,0 +1,59 @@
+#-----------------------------------------------------------------------
+#Copyright 2013 Centrum Wiskunde & Informatica, Amsterdam
+#
+#Author: Daniel M. Pelt
+#Contact: D.M.Pelt@cwi.nl
+#Website: http://dmpelt.github.io/pyastratoolbox/
+#
+#
+#This file is part of the Python interface to the
+#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox").
+#
+#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify
+#it under the terms of the GNU General Public License as published by
+#the Free Software Foundation, either version 3 of the License, or
+#(at your option) any later version.
+#
+#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful,
+#but WITHOUT ANY WARRANTY; without even the implied warranty of
+#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+#GNU General Public License for more details.
+#
+#You should have received a copy of the GNU General Public License
+#along with the Python interface to the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>.
+#
+#-----------------------------------------------------------------------
+# distutils: language = c++
+# distutils: libraries = astra
+
+import six
+import inspect
+
+from libcpp.string cimport string
+from libcpp cimport bool
+
+cdef CPluginAlgorithmFactory *fact = getSingletonPtr()
+
+from . import utils
+
+cdef extern from "astra/PluginAlgorithm.h" namespace "astra":
+ cdef cppclass CPluginAlgorithmFactory:
+ bool registerPlugin(string name, string className)
+ bool registerPluginClass(string name, object className)
+ object getRegistered()
+ string getHelp(string name)
+
+cdef extern from "astra/PluginAlgorithm.h" namespace "astra::CPluginAlgorithmFactory":
+ cdef CPluginAlgorithmFactory* getSingletonPtr()
+
+def register(name, className):
+ if inspect.isclass(className):
+ fact.registerPluginClass(six.b(name), className)
+ else:
+ fact.registerPlugin(six.b(name), six.b(className))
+
+def get_registered():
+ return fact.getRegistered()
+
+def get_help(name):
+ return utils.wrap_from_bytes(fact.getHelp(six.b(name)))
diff --git a/python/astra/utils.pyx b/python/astra/utils.pyx
index ddb37aa..3746b8e 100644
--- a/python/astra/utils.pyx
+++ b/python/astra/utils.pyx
@@ -30,7 +30,6 @@ cimport numpy as np
import numpy as np
import six
from libcpp.string cimport string
-from libcpp.list cimport list
from libcpp.vector cimport vector
from cython.operator cimport dereference as deref, preincrement as inc
from cpython.version cimport PY_MAJOR_VERSION
@@ -40,6 +39,9 @@ from .PyXMLDocument cimport XMLDocument
from .PyXMLDocument cimport XMLNode
from .PyIncludes cimport *
+cdef extern from "astra/PluginAlgorithm.h" namespace "astra":
+ object XMLNode2dict(XMLNode)
+
cdef Config * dictToConfig(string rootname, dc):
cdef Config * cfg = new Config()
@@ -91,6 +93,8 @@ cdef void readDict(XMLNode root, _dc):
dc = convert_item(_dc)
for item in dc:
val = dc[item]
+ if isinstance(val, list):
+ val = np.array(val,dtype=np.float64)
if isinstance(val, np.ndarray):
if val.size == 0:
break
@@ -142,69 +146,3 @@ cdef void readOptions(XMLNode node, dc):
cdef configToDict(Config *cfg):
return XMLNode2dict(cfg.self)
-def castString3(input):
- return input.decode('utf-8')
-
-def castString2(input):
- return input
-
-if six.PY3:
- castString = castString3
-else:
- castString = castString2
-
-def stringToPythonValue(inputIn):
- input = castString(inputIn)
- # matrix
- if ';' in input:
- row_strings = input.split(';')
- col_strings = row_strings[0].split(',')
- nRows = len(row_strings)
- nCols = len(col_strings)
-
- out = np.empty((nRows,nCols))
- for ridx, row in enumerate(row_strings):
- col_strings = row.split(',')
- for cidx, col in enumerate(col_strings):
- out[ridx,cidx] = float(col)
- return out
-
- # vector
- if ',' in input:
- items = input.split(',')
- out = np.empty(len(items))
- for idx,item in enumerate(items):
- out[idx] = float(item)
- return out
-
- try:
- # integer
- return int(input)
- except ValueError:
- try:
- #float
- return float(input)
- except ValueError:
- # string
- return str(input)
-
-
-cdef XMLNode2dict(XMLNode node):
- cdef XMLNode subnode
- cdef list[XMLNode] nodes
- cdef list[XMLNode].iterator it
- dct = {}
- opts = {}
- if node.hasAttribute(six.b('type')):
- dct['type'] = castString(node.getAttribute(six.b('type')))
- nodes = node.getNodes()
- it = nodes.begin()
- while it != nodes.end():
- subnode = deref(it)
- if castString(subnode.getName())=="Option":
- opts[castString(subnode.getAttribute('key'))] = stringToPythonValue(subnode.getAttribute('value'))
- else:
- dct[castString(subnode.getName())] = stringToPythonValue(subnode.getContent())
- inc(it)
- if len(opts)>0: dct['options'] = opts
- return dct
diff --git a/python/docSRC/index.rst b/python/docSRC/index.rst
index b7cc6d6..dcc6590 100644
--- a/python/docSRC/index.rst
+++ b/python/docSRC/index.rst
@@ -19,6 +19,7 @@ Contents:
creators
functions
operator
+ plugins
matlab
astra
.. astra
diff --git a/python/docSRC/plugins.rst b/python/docSRC/plugins.rst
new file mode 100644
index 0000000..dc7c607
--- /dev/null
+++ b/python/docSRC/plugins.rst
@@ -0,0 +1,8 @@
+Plugins: the :mod:`plugin` module
+=========================================
+
+.. automodule:: astra.plugin
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py
new file mode 100644
index 0000000..6677930
--- /dev/null
+++ b/samples/python/s018_plugin.py
@@ -0,0 +1,138 @@
+#-----------------------------------------------------------------------
+#Copyright 2015 Centrum Wiskunde & Informatica, Amsterdam
+#
+#Author: Daniel M. Pelt
+#Contact: D.M.Pelt@cwi.nl
+#Website: http://dmpelt.github.io/pyastratoolbox/
+#
+#
+#This file is part of the Python interface to the
+#All Scale Tomographic Reconstruction Antwerp Toolbox ("ASTRA Toolbox").
+#
+#The Python interface to the ASTRA Toolbox is free software: you can redistribute it and/or modify
+#it under the terms of the GNU General Public License as published by
+#the Free Software Foundation, either version 3 of the License, or
+#(at your option) any later version.
+#
+#The Python interface to the ASTRA Toolbox is distributed in the hope that it will be useful,
+#but WITHOUT ANY WARRANTY; without even the implied warranty of
+#MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+#GNU General Public License for more details.
+#
+#You should have received a copy of the GNU General Public License
+#along with the Python interface to the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>.
+#
+#-----------------------------------------------------------------------
+
+import astra
+import numpy as np
+import six
+
+# Define the plugin class (has to subclass astra.plugin.base)
+# Note that usually, these will be defined in a separate package/module
+class SIRTPlugin(astra.plugin.base):
+ """Example of an ASTRA plugin class, implementing a simple 2D SIRT algorithm.
+
+ Optional options:
+
+ 'rel_factor': relaxation factor
+ """
+ required_options=[]
+ optional_options=['rel_factor']
+
+ def initialize(self,cfg):
+ self.W = astra.OpTomo(cfg['ProjectorId'])
+ self.vid = cfg['ReconstructionDataId']
+ self.sid = cfg['ProjectionDataId']
+ try:
+ self.rel = cfg['option']['rel_factor']
+ except KeyError:
+ self.rel = 1
+
+ def run(self, its):
+ v = astra.data2d.get_shared(self.vid)
+ s = astra.data2d.get_shared(self.sid)
+ W = self.W
+ for i in range(its):
+ v[:] += self.rel*(W.T*(s - (W*v).reshape(s.shape))).reshape(v.shape)/s.size
+
+if __name__=='__main__':
+
+ vol_geom = astra.create_vol_geom(256, 256)
+ proj_geom = astra.create_proj_geom('parallel', 1.0, 384, np.linspace(0,np.pi,180,False))
+
+ # As before, create a sinogram from a phantom
+ import scipy.io
+ P = scipy.io.loadmat('phantom.mat')['phantom256']
+ proj_id = astra.create_projector('cuda',proj_geom,vol_geom)
+
+ # construct the OpTomo object
+ W = astra.OpTomo(proj_id)
+
+ sinogram = W * P
+ sinogram = sinogram.reshape([180, 384])
+
+ # Register the plugin with ASTRA
+ # A default set of plugins to load can be defined in:
+ # - /etc/astra-toolbox/plugins.txt
+ # - [ASTRA_INSTALL_PATH]/python/astra/plugins.txt
+ # - [USER_HOME_PATH]/.astra-toolbox/plugins.txt
+ # - [ASTRA_PLUGIN_PATH environment variable]/plugins.txt
+ # In these files, create a separate line for each plugin with:
+ # [PLUGIN_ASTRA_NAME] [FULL_PLUGIN_CLASS]
+ #
+ # So in this case, it would be a line:
+ # SIRT-PLUGIN s018_plugin.SIRTPlugin
+ #
+ astra.plugin.register('SIRT-PLUGIN','s018_plugin.SIRTPlugin')
+
+ # To get help on a registered plugin, use get_help
+ six.print_(astra.plugin.get_help('SIRT-PLUGIN'))
+
+ # Create data structures
+ sid = astra.data2d.create('-sino', proj_geom, sinogram)
+ vid = astra.data2d.create('-vol', vol_geom)
+
+ # Create config using plugin name
+ cfg = astra.astra_dict('SIRT-PLUGIN')
+ cfg['ProjectorId'] = proj_id
+ cfg['ProjectionDataId'] = sid
+ cfg['ReconstructionDataId'] = vid
+
+ # Create algorithm object
+ alg_id = astra.algorithm.create(cfg)
+
+ # Run algorithm for 100 iterations
+ astra.algorithm.run(alg_id, 100)
+
+ # Get reconstruction
+ rec = astra.data2d.get(vid)
+
+ # Options for the plugin go in cfg['option']
+ cfg = astra.astra_dict('SIRT-PLUGIN')
+ cfg['ProjectorId'] = proj_id
+ cfg['ProjectionDataId'] = sid
+ cfg['ReconstructionDataId'] = vid
+ cfg['option'] = {}
+ cfg['option']['rel_factor'] = 1.5
+ alg_id_rel = astra.algorithm.create(cfg)
+ astra.algorithm.run(alg_id_rel, 100)
+ rec_rel = astra.data2d.get(vid)
+
+ # We can also use OpTomo to call the plugin
+ rec_op = W.reconstruct('SIRT-PLUGIN', sinogram, 100, extraOptions={'rel_factor':1.5})
+
+ import pylab as pl
+ pl.gray()
+ pl.figure(1)
+ pl.imshow(rec,vmin=0,vmax=1)
+ pl.figure(2)
+ pl.imshow(rec_rel,vmin=0,vmax=1)
+ pl.figure(3)
+ pl.imshow(rec_op,vmin=0,vmax=1)
+ pl.show()
+
+ # Clean up.
+ astra.projector.delete(proj_id)
+ astra.algorithm.delete([alg_id, alg_id_rel])
+ astra.data2d.delete([vid, sid])
diff --git a/src/PluginAlgorithm.cpp b/src/PluginAlgorithm.cpp
new file mode 100644
index 0000000..df13f31
--- /dev/null
+++ b/src/PluginAlgorithm.cpp
@@ -0,0 +1,294 @@
+/*
+-----------------------------------------------------------------------
+Copyright: 2010-2015, iMinds-Vision Lab, University of Antwerp
+ 2014-2015, CWI, Amsterdam
+
+Contact: astra@uantwerpen.be
+Website: http://sf.net/projects/astra-toolbox
+
+This file is part of the ASTRA Toolbox.
+
+
+The ASTRA Toolbox is free software: you can redistribute it and/or modify
+it under the terms of the GNU General Public License as published by
+the Free Software Foundation, either version 3 of the License, or
+(at your option) any later version.
+
+The ASTRA Toolbox is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License
+along with the ASTRA Toolbox. If not, see <http://www.gnu.org/licenses/>.
+
+-----------------------------------------------------------------------
+$Id$
+*/
+
+#ifdef ASTRA_PYTHON
+
+#include "astra/PluginAlgorithm.h"
+#include <boost/algorithm/string.hpp>
+#include <boost/algorithm/string/split.hpp>
+#include <boost/lexical_cast.hpp>
+#include <iostream>
+#include <fstream>
+#include <string>
+
+namespace astra {
+
+CPluginAlgorithm::CPluginAlgorithm(PyObject* pyclass){
+ instance = PyObject_CallObject(pyclass, NULL);
+}
+
+CPluginAlgorithm::~CPluginAlgorithm(){
+ if(instance!=NULL){
+ Py_DECREF(instance);
+ instance = NULL;
+ }
+}
+
+bool CPluginAlgorithm::initialize(const Config& _cfg){
+ if(instance==NULL) return false;
+ PyObject *cfgDict = XMLNode2dict(_cfg.self);
+ PyObject *retVal = PyObject_CallMethod(instance, "astra_init", "O",cfgDict);
+ Py_DECREF(cfgDict);
+ if(retVal==NULL) return false;
+ m_bIsInitialized = true;
+ Py_DECREF(retVal);
+ return m_bIsInitialized;
+}
+
+void CPluginAlgorithm::run(int _iNrIterations){
+ if(instance==NULL) return;
+ PyObject *retVal = PyObject_CallMethod(instance, "run", "i",_iNrIterations);
+ if(retVal==NULL) return;
+ Py_DECREF(retVal);
+}
+
+const char ps =
+#ifdef _WIN32
+ '\\';
+#else
+ '/';
+#endif
+
+std::vector<std::string> CPluginAlgorithmFactory::getPluginPathList(){
+ std::vector<std::string> list;
+ list.push_back("/etc/astra-toolbox");
+ PyObject *ret, *retb;
+ ret = PyObject_CallMethod(inspect,"getfile","O",astra);
+ if(ret!=NULL){
+ retb = PyObject_CallMethod(six,"b","O",ret);
+ Py_DECREF(ret);
+ if(retb!=NULL){
+ std::string astra_inst (PyBytes_AsString(retb));
+ Py_DECREF(retb);
+ ret = PyObject_CallMethod(ospath,"dirname","s",astra_inst.c_str());
+ if(ret!=NULL){
+ retb = PyObject_CallMethod(six,"b","O",ret);
+ Py_DECREF(ret);
+ if(retb!=NULL){
+ list.push_back(std::string(PyBytes_AsString(retb)));
+ Py_DECREF(retb);
+ }
+ }
+ }
+ }
+ ret = PyObject_CallMethod(ospath,"expanduser","s","~");
+ if(ret!=NULL){
+ retb = PyObject_CallMethod(six,"b","O",ret);
+ Py_DECREF(ret);
+ if(retb!=NULL){
+ list.push_back(std::string(PyBytes_AsString(retb)) + ps + ".astra-toolbox");
+ Py_DECREF(retb);
+ }
+ }
+ const char *envval = getenv("ASTRA_PLUGIN_PATH");
+ if(envval!=NULL){
+ list.push_back(std::string(envval));
+ }
+ return list;
+}
+
+CPluginAlgorithmFactory::CPluginAlgorithmFactory(){
+ Py_Initialize();
+ pluginDict = PyDict_New();
+ ospath = PyImport_ImportModule("os.path");
+ inspect = PyImport_ImportModule("inspect");
+ six = PyImport_ImportModule("six");
+ astra = PyImport_ImportModule("astra");
+ std::vector<std::string> fls = getPluginPathList();
+ std::vector<std::string> items;
+ for(unsigned int i=0;i<fls.size();i++){
+ std::ifstream fs ((fls[i]+ps+"plugins.txt").c_str());
+ if(!fs.is_open()) continue;
+ std::string line;
+ while (std::getline(fs,line)){
+ boost::split(items, line, boost::is_any_of(" "));
+ if(items.size()<2) continue;
+ PyObject *str = PyBytes_FromString(items[1].c_str());
+ PyDict_SetItemString(pluginDict,items[0].c_str(),str);
+ Py_DECREF(str);
+ }
+ fs.close();
+ }
+}
+
+CPluginAlgorithmFactory::~CPluginAlgorithmFactory(){
+ if(pluginDict!=NULL){
+ Py_DECREF(pluginDict);
+ }
+}
+
+bool CPluginAlgorithmFactory::registerPlugin(std::string name, std::string className){
+ PyObject *str = PyBytes_FromString(className.c_str());
+ PyDict_SetItemString(pluginDict, name.c_str(), str);
+ Py_DECREF(str);
+ return true;
+}
+
+bool CPluginAlgorithmFactory::registerPluginClass(std::string name, PyObject * className){
+ PyDict_SetItemString(pluginDict, name.c_str(), className);
+ return true;
+}
+
+PyObject * getClassFromString(std::string str){
+ std::vector<std::string> items;
+ boost::split(items, str, boost::is_any_of("."));
+ PyObject *pyclass = PyImport_ImportModule(items[0].c_str());
+ if(pyclass==NULL) return NULL;
+ PyObject *submod = pyclass;
+ for(unsigned int i=1;i<items.size();i++){
+ submod = PyObject_GetAttrString(submod,items[i].c_str());
+ Py_DECREF(pyclass);
+ pyclass = submod;
+ if(pyclass==NULL) return NULL;
+ }
+ return pyclass;
+}
+
+CPluginAlgorithm * CPluginAlgorithmFactory::getPlugin(std::string name){
+ PyObject *className = PyDict_GetItemString(pluginDict, name.c_str());
+ if(className==NULL) return NULL;
+ CPluginAlgorithm *alg = NULL;
+ if(PyBytes_Check(className)){
+ std::string str = std::string(PyBytes_AsString(className));
+ PyObject *pyclass = getClassFromString(str);
+ if(pyclass!=NULL){
+ alg = new CPluginAlgorithm(pyclass);
+ Py_DECREF(pyclass);
+ }
+ }else{
+ alg = new CPluginAlgorithm(className);
+ }
+ return alg;
+}
+
+PyObject * CPluginAlgorithmFactory::getRegistered(){
+ Py_INCREF(pluginDict);
+ return pluginDict;
+}
+
+std::string CPluginAlgorithmFactory::getHelp(std::string name){
+ PyObject *className = PyDict_GetItemString(pluginDict, name.c_str());
+ if(className==NULL) return "";
+ std::string str = std::string(PyBytes_AsString(className));
+ std::string ret = "";
+ PyObject *pyclass = getClassFromString(str);
+ if(pyclass==NULL) return "";
+ PyObject *module = PyImport_ImportModule("inspect");
+ if(module!=NULL){
+ PyObject *retVal = PyObject_CallMethod(module,"getdoc","O",pyclass);
+ if(retVal!=NULL){
+ PyObject *retb = PyObject_CallMethod(six,"b","O",retVal);
+ Py_DECREF(retVal);
+ if(retVal!=NULL){
+ ret = std::string(PyBytes_AsString(retb));
+ Py_DECREF(retb);
+ }
+ }
+ Py_DECREF(module);
+ }
+ Py_DECREF(pyclass);
+ return ret;
+}
+
+DEFINE_SINGLETON(CPluginAlgorithmFactory);
+
+#if PY_MAJOR_VERSION >= 3
+PyObject * pyStringFromString(std::string str){
+ return PyUnicode_FromString(str.c_str());
+}
+#else
+PyObject * pyStringFromString(std::string str){
+ return PyBytes_FromString(str.c_str());
+}
+#endif
+
+PyObject* stringToPythonValue(std::string str){
+ if(str.find(";")!=std::string::npos){
+ std::vector<std::string> rows, row;
+ boost::split(rows, str, boost::is_any_of(";"));
+ PyObject *mat = PyList_New(rows.size());
+ for(unsigned int i=0; i<rows.size(); i++){
+ boost::split(row, rows[i], boost::is_any_of(","));
+ PyObject *rowlist = PyList_New(row.size());
+ for(unsigned int j=0;j<row.size();j++){
+ PyList_SetItem(rowlist, j, PyFloat_FromDouble(boost::lexical_cast<double>(row[j])));
+ }
+ PyList_SetItem(mat, i, rowlist);
+ }
+ return mat;
+ }
+ if(str.find(",")!=std::string::npos){
+ std::vector<std::string> vec;
+ boost::split(vec, str, boost::is_any_of(","));
+ PyObject *veclist = PyList_New(vec.size());
+ for(unsigned int i=0;i<vec.size();i++){
+ PyList_SetItem(veclist, i, PyFloat_FromDouble(boost::lexical_cast<double>(vec[i])));
+ }
+ return veclist;
+ }
+ try{
+ return PyLong_FromLong(boost::lexical_cast<long>(str));
+ }catch(const boost::bad_lexical_cast &){
+ try{
+ return PyFloat_FromDouble(boost::lexical_cast<double>(str));
+ }catch(const boost::bad_lexical_cast &){
+ return pyStringFromString(str);
+ }
+ }
+}
+
+PyObject* XMLNode2dict(XMLNode node){
+ PyObject *dct = PyDict_New();
+ PyObject *opts = PyDict_New();
+ if(node.hasAttribute("type")){
+ PyObject *obj = pyStringFromString(node.getAttribute("type").c_str());
+ PyDict_SetItemString(dct, "type", obj);
+ Py_DECREF(obj);
+ }
+ std::list<XMLNode> nodes = node.getNodes();
+ std::list<XMLNode>::iterator it = nodes.begin();
+ while(it!=nodes.end()){
+ XMLNode subnode = *it;
+ if(subnode.getName()=="Option"){
+ PyObject *obj = stringToPythonValue(subnode.getAttribute("value"));
+ PyDict_SetItemString(opts, subnode.getAttribute("key").c_str(), obj);
+ Py_DECREF(obj);
+ }else{
+ PyObject *obj = stringToPythonValue(subnode.getContent());
+ PyDict_SetItemString(dct, subnode.getName().c_str(), obj);
+ Py_DECREF(obj);
+ }
+ ++it;
+ }
+ PyDict_SetItemString(dct, "options", opts);
+ Py_DECREF(opts);
+ return dct;
+}
+
+}
+#endif \ No newline at end of file