summaryrefslogtreecommitdiffstats
path: root/samples/python/s018_plugin.py
diff options
context:
space:
mode:
Diffstat (limited to 'samples/python/s018_plugin.py')
-rw-r--r--samples/python/s018_plugin.py34
1 files changed, 21 insertions, 13 deletions
diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py
index 31cca95..85b5486 100644
--- a/samples/python/s018_plugin.py
+++ b/samples/python/s018_plugin.py
@@ -30,30 +30,38 @@ import six
# Define the plugin class (has to subclass astra.plugin.base)
# Note that usually, these will be defined in a separate package/module
-class SIRTPlugin(astra.plugin.base):
- """Example of an ASTRA plugin class, implementing a simple 2D SIRT algorithm.
+class LandweberPlugin(astra.plugin.base):
+ """Example of an ASTRA plugin class, implementing a simple 2D Landweber algorithm.
Options:
- 'rel_factor': relaxation factor (optional)
+ 'Relaxation': relaxation factor (optional)
"""
# The astra_name variable defines the name to use to
# call the plugin from ASTRA
- astra_name = "SIRT-PLUGIN"
+ astra_name = "LANDWEBER-PLUGIN"
- def initialize(self,cfg, rel_factor = 1):
+ def initialize(self,cfg, Relaxation = 1):
self.W = astra.OpTomo(cfg['ProjectorId'])
self.vid = cfg['ReconstructionDataId']
self.sid = cfg['ProjectionDataId']
- self.rel = rel_factor
+ self.rel = Relaxation
def run(self, its):
v = astra.data2d.get_shared(self.vid)
s = astra.data2d.get_shared(self.sid)
+ tv = np.zeros(v.shape, dtype=np.float32)
+ ts = np.zeros(s.shape, dtype=np.float32)
W = self.W
for i in range(its):
- v[:] += self.rel*(W.T*(s - (W*v).reshape(s.shape))).reshape(v.shape)/s.size
+ W.FP(v,out=ts)
+ ts -= s # ts = W*v - s
+
+ W.BP(ts,out=tv)
+ tv *= self.rel / s.size
+
+ v -= tv # v = v - rel * W'*(W*v-s) / s.size
if __name__=='__main__':
@@ -75,20 +83,20 @@ if __name__=='__main__':
# First we import the package that contains the plugin
import s018_plugin
# Then, we register the plugin class with ASTRA
- astra.plugin.register(s018_plugin.SIRTPlugin)
+ astra.plugin.register(s018_plugin.LandweberPlugin)
# Get a list of registered plugins
six.print_(astra.plugin.get_registered())
# To get help on a registered plugin, use get_help
- six.print_(astra.plugin.get_help('SIRT-PLUGIN'))
+ six.print_(astra.plugin.get_help('LANDWEBER-PLUGIN'))
# Create data structures
sid = astra.data2d.create('-sino', proj_geom, sinogram)
vid = astra.data2d.create('-vol', vol_geom)
# Create config using plugin name
- cfg = astra.astra_dict('SIRT-PLUGIN')
+ cfg = astra.astra_dict('LANDWEBER-PLUGIN')
cfg['ProjectorId'] = proj_id
cfg['ProjectionDataId'] = sid
cfg['ReconstructionDataId'] = vid
@@ -103,18 +111,18 @@ if __name__=='__main__':
rec = astra.data2d.get(vid)
# Options for the plugin go in cfg['option']
- cfg = astra.astra_dict('SIRT-PLUGIN')
+ cfg = astra.astra_dict('LANDWEBER-PLUGIN')
cfg['ProjectorId'] = proj_id
cfg['ProjectionDataId'] = sid
cfg['ReconstructionDataId'] = vid
cfg['option'] = {}
- cfg['option']['rel_factor'] = 1.5
+ cfg['option']['Relaxation'] = 1.5
alg_id_rel = astra.algorithm.create(cfg)
astra.algorithm.run(alg_id_rel, 100)
rec_rel = astra.data2d.get(vid)
# We can also use OpTomo to call the plugin
- rec_op = W.reconstruct('SIRT-PLUGIN', sinogram, 100, extraOptions={'rel_factor':1.5})
+ rec_op = W.reconstruct('LANDWEBER-PLUGIN', sinogram, 100, extraOptions={'Relaxation':1.5})
import pylab as pl
pl.gray()