From 0174c3226d9d540d0adc874436bf5b8f7a80f9b5 Mon Sep 17 00:00:00 2001
From: Edoardo Pasca <edo.paskino@gmail.com>
Date: Tue, 31 Oct 2017 12:49:10 +0000
Subject: Added Ordered Subsets

---
 .../ccpi/reconstruction/FISTAReconstructor.py      | 186 +++++++++++++++++++--
 1 file changed, 173 insertions(+), 13 deletions(-)

(limited to 'src')

diff --git a/src/Python/ccpi/reconstruction/FISTAReconstructor.py b/src/Python/ccpi/reconstruction/FISTAReconstructor.py
index b8e1027..4f1709c 100644
--- a/src/Python/ccpi/reconstruction/FISTAReconstructor.py
+++ b/src/Python/ccpi/reconstruction/FISTAReconstructor.py
@@ -430,7 +430,13 @@ class FISTAReconstructor():
 
     # prepareForIteration
 
-    def iterate(self, Xin=None):
+    def iterate (self, Xin=None):
+        if self.getParameter('subset') == 0:
+            return self.iterateStandard(Xin)
+        else:
+            return self.iterateOrderedSubsets(Xin)
+        
+    def iterateStandard(self, Xin=None):
         print ("FISTA Reconstructor: iterate")
         
         if Xin is None:    
@@ -613,14 +619,14 @@ class FISTAReconstructor():
     def updateLoop(self, i, X, X_old, r_old, t, t_old):
         print ("FISTA Reconstructor: update loop")
         lambdaR_L1 = self.getParameter('ring_lambda_R_L1')
-        if lambdaR_L1 > 0:
-            self.r = numpy.max(
-                numpy.abs(self.r) - lambdaR_L1 , 0) * \
-                numpy.sign(self.r)
+            
         t = (1 + numpy.sqrt(1 + 4 * t**2))/2
         X_t = X + (((t_old -1)/t) * (X - X_old))
 
         if lambdaR_L1 > 0:
+            self.r = numpy.max(
+                numpy.abs(self.r) - lambdaR_L1 , 0) * \
+                numpy.sign(self.r)
             self.r_x = self.r + \
                              (((t_old-1)/t) * (self.r - r_old))
 
@@ -636,8 +642,8 @@ class FISTAReconstructor():
             print (string.format(i,Resid_error[i], self.objective[i]))
         return (X , X_t, t)
 
-    def os_iterate(self, Xin=None):
-        print ("FISTA Reconstructor: iterate")
+    def iterateOS(self, Xin=None):
+        print ("FISTA Reconstructor: Ordered Subsets iterate")
         
         if Xin is None:    
             if self.getParameter('initialize'):
@@ -653,9 +659,163 @@ class FISTAReconstructor():
         X_t = X.copy()
 
         # some useful constants
-        proj_geom , vol_geom, sino , \
-          SlicesZ, weights , alpha_ring ,\
-          lambdaR_L1 , L_const = self.getParameter(
-            ['projector_geometry' , 'output_geometry',
-             'input_sinogram', 'SlicesZ' ,  'weights', 'ring_alpha' ,
-             'ring_lambda_R_L1', 'Lipschitz_constant'])
+        proj_geom ,    vol_geom, sino , \
+          SlicesZ,     weights , alpha_ring ,\
+          lambdaR_L1 , L_const , iterFISTA         = self.getParameter(
+            ['projector_geometry' , 'output_geometry', 'input_sinogram',
+             'SlicesZ' ,            'weights',         'ring_alpha' ,
+             'ring_lambda_R_L1',    'Lipschitz_constant',
+             'number_of_iterations'])
+
+        
+        # errors vector (if the ground truth is given)
+        Resid_error = numpy.zeros((iterFISTA));
+        # objective function values vector
+        objective = numpy.zeros((iterFISTA)); 
+
+          
+        t = 1
+
+        ## additional for 
+        proj_geomSUB = proj_geom.copy()
+        self.residual2 = numpy.zeros(numpy.shape(sino))
+        residual2 = self.residual2
+        sino_updt_FULL = self.residual.copy()
+        r_x = self.r.copy()
+
+        print ("starting iterations")
+        ##    % Outer FISTA iterations loop
+        for i in range(fistaRecon.getParameter('number_of_iterations')):
+            # With OS approach it becomes trickier to correlate independent
+            # subsets, hence additional work is required one solution is to
+            # work with a full sinogram at times
+
+            r_old = self.r.copy()
+            t_old = t
+            SlicesZ, anglesNumb, Detectors = \
+                        numpy.shape(self.getParameter('input_sinogram'))        ## https://github.com/vais-ral/CCPi-FISTA_Reconstruction/issues/4
+            if (i > 1 and lambdaR_L1 > 0) :
+                for kkk in range(anglesNumb):
+                     
+                     residual2[:,kkk,:] = (weights[:,kkk,:]).squeeze() * \
+                                           ((sino_updt_FULL[:,kkk,:]).squeeze() - \
+                                            (sino[:,kkk,:]).squeeze() -\
+                                            (alpha_ring * r_x)
+                                            )
+                
+                vec = self.residual.sum(axis = 1)
+                #if SlicesZ > 1:
+                #    vec = vec[:,1,:] # 1 or 0?
+                r_x = self.r_x
+                # update ring variable
+                self.r = (r_x - (1./L_const) * vec).copy()
+
+            # subset loop
+            counterInd = 1
+            geometry_type = self.getParameter('projector_geometry')['type']
+            angles = self.getParameter('projector_geometry')['ProjectionAngles']
+
+            for ss in range(self.getParameter('subsets')):
+                print ("Subset {0}".format(ss))
+                X_old = X.copy()
+                t_old = t
+                
+                # the number of projections per subset
+                numProjSub = self.getParameter('os_bins')[ss]
+                CurrSubIndices = self.getParameter('os_indices')\
+                                 [counterInd:counterInd+numProjSub]
+                #print ("Len CurrSubIndices {0}".format(numProjSub))
+                mask = numpy.zeros(numpy.shape(angles), dtype=bool)
+                cc = 0
+                for j in range(len(CurrSubIndices)):
+                    mask[int(CurrSubIndices[j])] = True
+                proj_geomSUB['ProjectionAngles'] = angles[mask]
+
+                shape = list(numpy.shape(self.getParameter('input_sinogram')))
+                shape[1] = numProjSub
+                sino_updt_Sub = numpy.zeros(shape)
+                if geometry_type == 'parallel' or \
+                   geometry_type == 'fanflat' or \
+                   geometry_type == 'fanflat_vec' :
+
+                    for kkk in range(SlicesZ):
+                        sino_id, sinoT = astra.creators.create_sino3d_gpu (
+                            X_t[kkk:kkk+1] , proj_geomSUB, vol_geom)
+                        sino_updt_Sub[kkk] = sinoT.T.copy()
+                        
+                else:
+                    # for 3D geometry (watch the GPU memory overflow in
+                    # ASTRA < 1.8)
+                    sino_id, sino_updt_Sub = \
+                         astra.creators.create_sino3d_gpu (X_t, proj_geomSUB, vol_geom)
+                    
+                astra.matlab.data3d('delete', sino_id)
+        
+                
+                if lambdaR_L1 > 0 :
+                    ## RING REMOVAL
+                    print ("ring removal")
+                    residualSub = self.ringRemovalOrderedSubsets(sino_updt_Sub,
+                                                   sino_updt_FULL)
+                else:
+                    #PWLS model
+                    print ("PWLS model")
+                    residualSub = weights[:,CurrSubIndices,:] * \
+                                  ( sino_updt_Sub - \
+                                    sino[:,CurrSubIndices,:].squeeze() )
+                    objective[i] = 0.5 * numpy.linalg.norm(residualSub)
+
+                # projection/backprojection routine
+                if geometry_type == 'parallel' or \
+                   geometry_type == 'fanflat' or \
+                   geometry_type == 'fanflat_vec' :
+                    # if geometry is 2D use slice-by-slice projection-backprojection
+                    # routine
+                    x_temp = numpy.zeros(numpy.shape(X), dtype=numpy.float32)
+                    for kkk in range(SlicesZ):
+                        
+                        x_id, x_temp[kkk] = \
+                                 astra.creators.create_backprojection3d_gpu(
+                                     residualSub[kkk:kkk+1],
+                                     proj_geomSUB, vol_geom)
+                        
+                else:
+                    x_id, x_temp = \
+                          astra.creators.create_backprojection3d_gpu(
+                              residualSub, proj_geomSUB, vol_geom)
+
+                astra.matlab.data3d('delete', x_id)
+                X = X_t - (1/L_const) * x_temp
+                ## REGULARIZATION
+                X = self.regularize(X)
+            
+            # FINAL
+            ## Update Loop
+            X , X_t, t = self.updateLoop(i, X, X_old, r_old, t, t_old)
+            self.setParameter(output_volume=X)
+            counterInd = counterInd + numProjSub
+
+        return X
+    
+    def ringRemovalOrderedSubsets(self, sino_updt_Sub, sino_updt_FULL):
+        residual = self.residual
+        r_x = self.r_x
+        weights , alpha_ring , sino = \
+                self.getParameter( ['weights', 'ring_alpha', 'input_sinogram'])
+        numProjSub = self.getParameter('os_bins')[ss]
+        CurrSubIndices = self.getParameter('os_indices')\
+                         [counterInd:counterInd+numProjSub]
+        residualSub = numpy.zeros(shape)
+
+        for kkk in range(numProjSub):
+            #print ("ring removal indC ... {0}".format(kkk))
+            indC = int(CurrSubIndices[kkk])
+            residualSub[:,kkk,:] = weights[:,indC,:].squeeze() * \
+                (sino_updt_Sub[:,kkk,:].squeeze() - \
+                sino[:,indC,:].squeeze() - alpha_ring * r_x)
+            # filling the full sinogram
+            sino_updt_FULL[:,indC,:] = sino_updt_Sub[:,kkk,:].squeeze()
+
+        return residualSub
+
+
-- 
cgit v1.2.3