summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--include/astra/PluginAlgorithm.h3
-rw-r--r--matlab/mex/astra_mex_plugin_c.cpp23
-rw-r--r--python/astra/plugin.py71
-rw-r--r--python/astra/plugin_c.pyx14
-rw-r--r--samples/python/s018_plugin.py23
-rw-r--r--src/PluginAlgorithm.cpp95
6 files changed, 138 insertions, 91 deletions
diff --git a/include/astra/PluginAlgorithm.h b/include/astra/PluginAlgorithm.h
index a82c579..b56228e 100644
--- a/include/astra/PluginAlgorithm.h
+++ b/include/astra/PluginAlgorithm.h
@@ -64,9 +64,12 @@ public:
CPluginAlgorithm * getPlugin(std::string name);
bool registerPlugin(std::string name, std::string className);
+ bool registerPlugin(std::string className);
bool registerPluginClass(std::string name, PyObject * className);
+ bool registerPluginClass(PyObject * className);
PyObject * getRegistered();
+ std::map<std::string, std::string> getRegisteredMap();
std::string getHelp(std::string name);
diff --git a/matlab/mex/astra_mex_plugin_c.cpp b/matlab/mex/astra_mex_plugin_c.cpp
index 2d9b9a0..177fcf4 100644
--- a/matlab/mex/astra_mex_plugin_c.cpp
+++ b/matlab/mex/astra_mex_plugin_c.cpp
@@ -37,9 +37,6 @@ $Id$
#include "astra/PluginAlgorithm.h"
-#include "Python.h"
-#include "bytesobject.h"
-
using namespace std;
using namespace astra;
@@ -52,29 +49,25 @@ using namespace astra;
void astra_mex_plugin_get_registered(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr();
- PyObject *dict = fact->getRegistered();
- PyObject *key, *value;
- Py_ssize_t pos = 0;
- while (PyDict_Next(dict, &pos, &key, &value)) {
- mexPrintf("%s: %s\n",PyBytes_AsString(key),PyBytes_AsString(value));
+ std::map<std::string, std::string> mp = fact->getRegisteredMap();
+ for(std::map<std::string,std::string>::iterator it=mp.begin();it!=mp.end();it++){
+ mexPrintf("%s: %s\n",it->first.c_str(), it->second.c_str());
}
- Py_DECREF(dict);
}
//-----------------------------------------------------------------------------------------
-/** astra_mex_plugin('register', name, class_name);
+/** astra_mex_plugin('register', class_name);
*
* Register plugin.
*/
void astra_mex_plugin_register(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[])
{
- if (3 <= nrhs) {
- string name = mexToString(prhs[1]);
- string class_name = mexToString(prhs[2]);
+ if (2 <= nrhs) {
+ string class_name = mexToString(prhs[1]);
astra::CPluginAlgorithmFactory *fact = astra::CPluginAlgorithmFactory::getSingletonPtr();
- fact->registerPlugin(name, class_name);
+ fact->registerPlugin(class_name);
}else{
- mexPrintf("astra_mex_plugin('register', name, class_name);\n");
+ mexPrintf("astra_mex_plugin('register', class_name);\n");
}
}
diff --git a/python/astra/plugin.py b/python/astra/plugin.py
index f8fc3bd..4b32e6e 100644
--- a/python/astra/plugin.py
+++ b/python/astra/plugin.py
@@ -32,60 +32,47 @@ import traceback
class base(object):
def astra_init(self, cfg):
- try:
- args, varargs, varkw, defaults = inspect.getargspec(self.initialize)
- if not defaults is None:
- nopt = len(defaults)
- else:
- nopt = 0
- if nopt>0:
- req = args[2:-nopt]
- opt = args[-nopt:]
- else:
- req = args[2:]
- opt = []
+ args, varargs, varkw, defaults = inspect.getargspec(self.initialize)
+ if not defaults is None:
+ nopt = len(defaults)
+ else:
+ nopt = 0
+ if nopt>0:
+ req = args[2:-nopt]
+ opt = args[-nopt:]
+ else:
+ req = args[2:]
+ opt = []
- try:
- optDict = cfg['options']
- except KeyError:
- optDict = {}
+ try:
+ optDict = cfg['options']
+ except KeyError:
+ optDict = {}
- cfgKeys = set(optDict.keys())
- reqKeys = set(req)
- optKeys = set(opt)
+ cfgKeys = set(optDict.keys())
+ reqKeys = set(req)
+ optKeys = set(opt)
- if not reqKeys.issubset(cfgKeys):
- for key in reqKeys.difference(cfgKeys):
- log.error("Required option '" + key + "' for plugin '" + self.__class__.__name__ + "' not specified")
- raise ValueError("Missing required options")
+ if not reqKeys.issubset(cfgKeys):
+ for key in reqKeys.difference(cfgKeys):
+ log.error("Required option '" + key + "' for plugin '" + self.__class__.__name__ + "' not specified")
+ raise ValueError("Missing required options")
- if not cfgKeys.issubset(reqKeys | optKeys):
- log.warn(self.__class__.__name__ + ": unused configuration option: " + str(list(cfgKeys.difference(reqKeys | optKeys))))
+ if not cfgKeys.issubset(reqKeys | optKeys):
+ log.warn(self.__class__.__name__ + ": unused configuration option: " + str(list(cfgKeys.difference(reqKeys | optKeys))))
- args = [optDict[k] for k in req]
- kwargs = dict((k,optDict[k]) for k in opt if k in optDict)
- self.initialize(cfg, *args, **kwargs)
- except Exception:
- log.error(traceback.format_exc().replace("%","%%"))
- raise
+ args = [optDict[k] for k in req]
+ kwargs = dict((k,optDict[k]) for k in opt if k in optDict)
+ self.initialize(cfg, *args, **kwargs)
- def astra_run(self, its):
- try:
- self.run(its)
- except Exception:
- log.error(traceback.format_exc().replace("%","%%"))
- raise
-
-def register(name, className):
+def register(className):
"""Register plugin with ASTRA.
- :param name: Plugin name to register
- :type name: :class:`str`
:param className: Class name or class object to register
:type className: :class:`str` or :class:`class`
"""
- p.register(name,className)
+ p.register(className)
def get_registered():
"""Get dictionary of registered plugins.
diff --git a/python/astra/plugin_c.pyx b/python/astra/plugin_c.pyx
index 91b3cd5..8d6816b 100644
--- a/python/astra/plugin_c.pyx
+++ b/python/astra/plugin_c.pyx
@@ -38,7 +38,9 @@ from . import utils
cdef extern from "astra/PluginAlgorithm.h" namespace "astra":
cdef cppclass CPluginAlgorithmFactory:
+ bool registerPlugin(string className)
bool registerPlugin(string name, string className)
+ bool registerPluginClass(object className)
bool registerPluginClass(string name, object className)
object getRegistered()
string getHelp(string name)
@@ -46,11 +48,17 @@ cdef extern from "astra/PluginAlgorithm.h" namespace "astra":
cdef extern from "astra/PluginAlgorithm.h" namespace "astra::CPluginAlgorithmFactory":
cdef CPluginAlgorithmFactory* getSingletonPtr()
-def register(name, className):
+def register(className, name=None):
if inspect.isclass(className):
- fact.registerPluginClass(six.b(name), className)
+ if name==None:
+ fact.registerPluginClass(className)
+ else:
+ fact.registerPluginClass(six.b(name), className)
else:
- fact.registerPlugin(six.b(name), six.b(className))
+ if name==None:
+ fact.registerPlugin(six.b(className))
+ else:
+ fact.registerPlugin(six.b(name), six.b(className))
def get_registered():
return fact.getRegistered()
diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py
index 90e09ac..31cca95 100644
--- a/samples/python/s018_plugin.py
+++ b/samples/python/s018_plugin.py
@@ -38,6 +38,10 @@ class SIRTPlugin(astra.plugin.base):
'rel_factor': relaxation factor (optional)
"""
+ # The astra_name variable defines the name to use to
+ # call the plugin from ASTRA
+ astra_name = "SIRT-PLUGIN"
+
def initialize(self,cfg, rel_factor = 1):
self.W = astra.OpTomo(cfg['ProjectorId'])
self.vid = cfg['ReconstructionDataId']
@@ -68,18 +72,13 @@ if __name__=='__main__':
sinogram = sinogram.reshape([180, 384])
# Register the plugin with ASTRA
- # A default set of plugins to load can be defined in:
- # - /etc/astra-toolbox/plugins.txt
- # - [ASTRA_INSTALL_PATH]/python/astra/plugins.txt
- # - [USER_HOME_PATH]/.astra-toolbox/plugins.txt
- # - [ASTRA_PLUGIN_PATH environment variable]/plugins.txt
- # In these files, create a separate line for each plugin with:
- # [PLUGIN_ASTRA_NAME] [FULL_PLUGIN_CLASS]
- #
- # So in this case, it would be a line:
- # SIRT-PLUGIN s018_plugin.SIRTPlugin
- #
- astra.plugin.register('SIRT-PLUGIN','s018_plugin.SIRTPlugin')
+ # First we import the package that contains the plugin
+ import s018_plugin
+ # Then, we register the plugin class with ASTRA
+ astra.plugin.register(s018_plugin.SIRTPlugin)
+
+ # Get a list of registered plugins
+ six.print_(astra.plugin.get_registered())
# To get help on a registered plugin, use get_help
six.print_(astra.plugin.get_help('SIRT-PLUGIN'))
diff --git a/src/PluginAlgorithm.cpp b/src/PluginAlgorithm.cpp
index d6cf731..7f7ff61 100644
--- a/src/PluginAlgorithm.cpp
+++ b/src/PluginAlgorithm.cpp
@@ -100,7 +100,10 @@ bool CPluginAlgorithm::initialize(const Config& _cfg){
PyObject *cfgDict = XMLNode2dict(_cfg.self);
PyObject *retVal = PyObject_CallMethod(instance, "astra_init", "O",cfgDict);
Py_DECREF(cfgDict);
- if(retVal==NULL) return false;
+ if(retVal==NULL){
+ logPythonError();
+ return false;
+ }
m_bIsInitialized = true;
Py_DECREF(retVal);
return m_bIsInitialized;
@@ -108,8 +111,11 @@ bool CPluginAlgorithm::initialize(const Config& _cfg){
void CPluginAlgorithm::run(int _iNrIterations){
if(instance==NULL) return;
- PyObject *retVal = PyObject_CallMethod(instance, "astra_run", "i",_iNrIterations);
- if(retVal==NULL) return;
+ PyObject *retVal = PyObject_CallMethod(instance, "run", "i",_iNrIterations);
+ if(retVal==NULL){
+ logPythonError();
+ return;
+ }
Py_DECREF(retVal);
}
@@ -157,18 +163,6 @@ CPluginAlgorithmFactory::~CPluginAlgorithmFactory(){
if(six!=NULL) Py_DECREF(six);
}
-bool CPluginAlgorithmFactory::registerPlugin(std::string name, std::string className){
- PyObject *str = PyBytes_FromString(className.c_str());
- PyDict_SetItemString(pluginDict, name.c_str(), str);
- Py_DECREF(str);
- return true;
-}
-
-bool CPluginAlgorithmFactory::registerPluginClass(std::string name, PyObject * className){
- PyDict_SetItemString(pluginDict, name.c_str(), className);
- return true;
-}
-
PyObject * getClassFromString(std::string str){
std::vector<std::string> items;
boost::split(items, str, boost::is_any_of("."));
@@ -190,6 +184,43 @@ PyObject * getClassFromString(std::string str){
return pyclass;
}
+bool CPluginAlgorithmFactory::registerPlugin(std::string name, std::string className){
+ PyObject *str = PyBytes_FromString(className.c_str());
+ PyDict_SetItemString(pluginDict, name.c_str(), str);
+ Py_DECREF(str);
+ return true;
+}
+
+bool CPluginAlgorithmFactory::registerPlugin(std::string className){
+ PyObject *pyclass = getClassFromString(className);
+ if(pyclass==NULL) return false;
+ bool ret = registerPluginClass(pyclass);
+ Py_DECREF(pyclass);
+ return ret;
+}
+
+bool CPluginAlgorithmFactory::registerPluginClass(std::string name, PyObject * className){
+ PyDict_SetItemString(pluginDict, name.c_str(), className);
+ return true;
+}
+
+bool CPluginAlgorithmFactory::registerPluginClass(PyObject * className){
+ PyObject *astra_name = PyObject_GetAttrString(className,"astra_name");
+ if(astra_name==NULL){
+ logPythonError();
+ return false;
+ }
+ PyObject *retb = PyObject_CallMethod(six,"b","O",astra_name);
+ if(retb!=NULL){
+ PyDict_SetItemString(pluginDict,PyBytes_AsString(retb),className);
+ Py_DECREF(retb);
+ }else{
+ logPythonError();
+ }
+ Py_DECREF(astra_name);
+ return true;
+}
+
CPluginAlgorithm * CPluginAlgorithmFactory::getPlugin(std::string name){
PyObject *className = PyDict_GetItemString(pluginDict, name.c_str());
if(className==NULL) return NULL;
@@ -212,12 +243,34 @@ PyObject * CPluginAlgorithmFactory::getRegistered(){
return pluginDict;
}
+std::map<std::string, std::string> CPluginAlgorithmFactory::getRegisteredMap(){
+ std::map<std::string, std::string> ret;
+ PyObject *key, *value;
+ Py_ssize_t pos = 0;
+ while (PyDict_Next(pluginDict, &pos, &key, &value)) {
+ PyObject * keyb = PyObject_Bytes(key);
+ PyObject * valb = PyObject_Bytes(value);
+ ret[PyBytes_AsString(keyb)] = PyBytes_AsString(valb);
+ Py_DECREF(keyb);
+ Py_DECREF(valb);
+ }
+ return ret;
+}
+
std::string CPluginAlgorithmFactory::getHelp(std::string name){
PyObject *className = PyDict_GetItemString(pluginDict, name.c_str());
- if(className==NULL) return "";
- std::string str = std::string(PyBytes_AsString(className));
+ if(className==NULL){
+ ASTRA_ERROR("Plugin %s not found!",name.c_str());
+ return "";
+ }
std::string ret = "";
- PyObject *pyclass = getClassFromString(str);
+ PyObject *pyclass;
+ if(PyBytes_Check(className)){
+ std::string str = std::string(PyBytes_AsString(className));
+ pyclass = getClassFromString(str);
+ }else{
+ pyclass = className;
+ }
if(pyclass==NULL) return "";
if(inspect!=NULL && six!=NULL){
PyObject *retVal = PyObject_CallMethod(inspect,"getdoc","O",pyclass);
@@ -228,9 +281,13 @@ std::string CPluginAlgorithmFactory::getHelp(std::string name){
ret = std::string(PyBytes_AsString(retb));
Py_DECREF(retb);
}
+ }else{
+ logPythonError();
}
}
- Py_DECREF(pyclass);
+ if(PyBytes_Check(className)){
+ Py_DECREF(pyclass);
+ }
return ret;
}