diff options
-rw-r--r-- | python/astra/optomo.py | 96 | ||||
-rw-r--r-- | samples/python/s018_plugin.py | 34 |
2 files changed, 87 insertions, 43 deletions
diff --git a/python/astra/optomo.py b/python/astra/optomo.py index 5a92998..dde719e 100644 --- a/python/astra/optomo.py +++ b/python/astra/optomo.py @@ -111,21 +111,7 @@ class OpTomo(scipy.sparse.linalg.LinearOperator): :param v: Volume to forward project. :type v: :class:`numpy.ndarray` """ - v = self.__checkArray(v, self.vshape) - vid = self.data_mod.link('-vol',self.vg,v) - s = np.zeros(self.sshape,dtype=np.float32) - sid = self.data_mod.link('-sino',self.pg,s) - - cfg = creators.astra_dict('FP'+self.appendString) - cfg['ProjectionDataId'] = sid - cfg['VolumeDataId'] = vid - cfg['ProjectorId'] = self.proj_id - fp_id = algorithm.create(cfg) - algorithm.run(fp_id) - - algorithm.delete(fp_id) - self.data_mod.delete([vid,sid]) - return s.ravel() + return self.FP(v, out=None).ravel() def rmatvec(self,s): """Implements the transpose operator. @@ -133,21 +119,7 @@ class OpTomo(scipy.sparse.linalg.LinearOperator): :param s: The projection data. :type s: :class:`numpy.ndarray` """ - s = self.__checkArray(s, self.sshape) - sid = self.data_mod.link('-sino',self.pg,s) - v = np.zeros(self.vshape,dtype=np.float32) - vid = self.data_mod.link('-vol',self.vg,v) - - cfg = creators.astra_dict('BP'+self.appendString) - cfg['ProjectionDataId'] = sid - cfg['ReconstructionDataId'] = vid - cfg['ProjectorId'] = self.proj_id - bp_id = algorithm.create(cfg) - algorithm.run(bp_id) - - algorithm.delete(bp_id) - self.data_mod.delete([vid,sid]) - return v.ravel() + return self.BP(s, out=None).ravel() def __mul__(self,v): """Provides easy forward operator by *. @@ -189,6 +161,70 @@ class OpTomo(scipy.sparse.linalg.LinearOperator): self.data_mod.delete([vid,sid]) return v + def FP(self,v,out=None): + """Perform forward projection. + + Output must have the right 2D/3D shape. Input may also be flattened. + + Output must also be contiguous and float32. This isn't required for the + input, but it is more efficient if it is. + + :param v: Volume to forward project. + :type v: :class:`numpy.ndarray` + :param out: Array to store result in. + :type out: :class:`numpy.ndarray` + """ + + v = self.__checkArray(v, self.vshape) + vid = self.data_mod.link('-vol',self.vg,v) + if out is None: + out = np.zeros(self.sshape,dtype=np.float32) + sid = self.data_mod.link('-sino',self.pg,out) + + cfg = creators.astra_dict('FP'+self.appendString) + cfg['ProjectionDataId'] = sid + cfg['VolumeDataId'] = vid + cfg['ProjectorId'] = self.proj_id + fp_id = algorithm.create(cfg) + algorithm.run(fp_id) + + algorithm.delete(fp_id) + self.data_mod.delete([vid,sid]) + return out + + def BP(self,s,out=None): + """Perform backprojection. + + Output must have the right 2D/3D shape. Input may also be flattened. + + Output must also be contiguous and float32. This isn't required for the + input, but it is more efficient if it is. + + :param : The projection data. + :type s: :class:`numpy.ndarray` + :param out: Array to store result in. + :type out: :class:`numpy.ndarray` + """ + s = self.__checkArray(s, self.sshape) + sid = self.data_mod.link('-sino',self.pg,s) + if out is None: + out = np.zeros(self.vshape,dtype=np.float32) + vid = self.data_mod.link('-vol',self.vg,out) + + cfg = creators.astra_dict('BP'+self.appendString) + cfg['ProjectionDataId'] = sid + cfg['ReconstructionDataId'] = vid + cfg['ProjectorId'] = self.proj_id + bp_id = algorithm.create(cfg) + algorithm.run(bp_id) + + algorithm.delete(bp_id) + self.data_mod.delete([vid,sid]) + return out + + + + class OpTomoTranspose(scipy.sparse.linalg.LinearOperator): """This object provides the transpose operation (``.T``) of the OpTomo object. diff --git a/samples/python/s018_plugin.py b/samples/python/s018_plugin.py index 31cca95..85b5486 100644 --- a/samples/python/s018_plugin.py +++ b/samples/python/s018_plugin.py @@ -30,30 +30,38 @@ import six # Define the plugin class (has to subclass astra.plugin.base) # Note that usually, these will be defined in a separate package/module -class SIRTPlugin(astra.plugin.base): - """Example of an ASTRA plugin class, implementing a simple 2D SIRT algorithm. +class LandweberPlugin(astra.plugin.base): + """Example of an ASTRA plugin class, implementing a simple 2D Landweber algorithm. Options: - 'rel_factor': relaxation factor (optional) + 'Relaxation': relaxation factor (optional) """ # The astra_name variable defines the name to use to # call the plugin from ASTRA - astra_name = "SIRT-PLUGIN" + astra_name = "LANDWEBER-PLUGIN" - def initialize(self,cfg, rel_factor = 1): + def initialize(self,cfg, Relaxation = 1): self.W = astra.OpTomo(cfg['ProjectorId']) self.vid = cfg['ReconstructionDataId'] self.sid = cfg['ProjectionDataId'] - self.rel = rel_factor + self.rel = Relaxation def run(self, its): v = astra.data2d.get_shared(self.vid) s = astra.data2d.get_shared(self.sid) + tv = np.zeros(v.shape, dtype=np.float32) + ts = np.zeros(s.shape, dtype=np.float32) W = self.W for i in range(its): - v[:] += self.rel*(W.T*(s - (W*v).reshape(s.shape))).reshape(v.shape)/s.size + W.FP(v,out=ts) + ts -= s # ts = W*v - s + + W.BP(ts,out=tv) + tv *= self.rel / s.size + + v -= tv # v = v - rel * W'*(W*v-s) / s.size if __name__=='__main__': @@ -75,20 +83,20 @@ if __name__=='__main__': # First we import the package that contains the plugin import s018_plugin # Then, we register the plugin class with ASTRA - astra.plugin.register(s018_plugin.SIRTPlugin) + astra.plugin.register(s018_plugin.LandweberPlugin) # Get a list of registered plugins six.print_(astra.plugin.get_registered()) # To get help on a registered plugin, use get_help - six.print_(astra.plugin.get_help('SIRT-PLUGIN')) + six.print_(astra.plugin.get_help('LANDWEBER-PLUGIN')) # Create data structures sid = astra.data2d.create('-sino', proj_geom, sinogram) vid = astra.data2d.create('-vol', vol_geom) # Create config using plugin name - cfg = astra.astra_dict('SIRT-PLUGIN') + cfg = astra.astra_dict('LANDWEBER-PLUGIN') cfg['ProjectorId'] = proj_id cfg['ProjectionDataId'] = sid cfg['ReconstructionDataId'] = vid @@ -103,18 +111,18 @@ if __name__=='__main__': rec = astra.data2d.get(vid) # Options for the plugin go in cfg['option'] - cfg = astra.astra_dict('SIRT-PLUGIN') + cfg = astra.astra_dict('LANDWEBER-PLUGIN') cfg['ProjectorId'] = proj_id cfg['ProjectionDataId'] = sid cfg['ReconstructionDataId'] = vid cfg['option'] = {} - cfg['option']['rel_factor'] = 1.5 + cfg['option']['Relaxation'] = 1.5 alg_id_rel = astra.algorithm.create(cfg) astra.algorithm.run(alg_id_rel, 100) rec_rel = astra.data2d.get(vid) # We can also use OpTomo to call the plugin - rec_op = W.reconstruct('SIRT-PLUGIN', sinogram, 100, extraOptions={'rel_factor':1.5}) + rec_op = W.reconstruct('LANDWEBER-PLUGIN', sinogram, 100, extraOptions={'Relaxation':1.5}) import pylab as pl pl.gray() |