diff options
-rw-r--r-- | tests/python/test_rec_scaling.py | 81 |
1 files changed, 76 insertions, 5 deletions
diff --git a/tests/python/test_rec_scaling.py b/tests/python/test_rec_scaling.py index 1bd3267..621fd8a 100644 --- a/tests/python/test_rec_scaling.py +++ b/tests/python/test_rec_scaling.py @@ -33,12 +33,46 @@ def ProjectionGeometries(type): for dU in [0.8, 1.0]: for dV in [0.8, 1.0]: yield astra.create_proj_geom('parallel3d', dU, dV, 128, 128, np.linspace(0,np.pi,180,False)) + elif type == 'parallel3d_vec': + for j in range(10): + Vectors = np.zeros([180,12]) + wu = 0.6 + 0.8 * np.random.random() + wv = 0.6 + 0.8 * np.random.random() + for i in range(Vectors.shape[0]): + l = 0.6 + 0.8 * np.random.random() + angle1 = 2*np.pi*np.random.random() + angle2 = angle1 + 0.5 * np.random.random() + angle3 = 0.1*np.pi*np.random.random() + detc = 10 * np.random.random(size=3) + detu = [ math.cos(angle1) * wu, math.sin(angle1) * wu, 0 ] + detv = [ -math.sin(angle1) * math.sin(angle3) * wv, math.cos(angle1) * math.sin(angle3) * wv, math.cos(angle3) * wv ] + ray = [ math.sin(angle2) * l, -math.cos(angle2) * l, 0 ] + Vectors[i, :] = [ ray[0], ray[1], ray[2], detc[0], detc[1], detc[2], detu[0], detu[1], detu[2], detv[0], detv[1], detv[2] ] + pg = astra.create_proj_geom('parallel3d_vec', 128, 128, Vectors) + yield pg elif type == 'cone': for dU in [0.8, 1.0]: for dV in [0.8, 1.0]: for src in [500, 1000]: for det in [0, 250]: yield astra.create_proj_geom('cone', dU, dV, 128, 128, np.linspace(0,2*np.pi,180,False), src, det) + elif type == 'cone_vec': + for j in range(10): + Vectors = np.zeros([180,12]) + wu = 0.6 + 0.8 * np.random.random() + wv = 0.6 + 0.8 * np.random.random() + for i in range(Vectors.shape[0]): + l = 256 * (0.5 * np.random.random()) + angle1 = 2*np.pi*np.random.random() + angle2 = angle1 + 0.5 * np.random.random() + angle3 = 0.1*np.pi*np.random.random() + detc = 10 * np.random.random(size=3) + detu = [ math.cos(angle1) * wu, math.sin(angle1) * wu, 0 ] + detv = [ -math.sin(angle1) * math.sin(angle3) * wv, math.cos(angle1) * math.sin(angle3) * wv, math.cos(angle3) * wv ] + src = [ math.sin(angle2) * l, -math.cos(angle2) * l, 0 ] + Vectors[i, :] = [ src[0], src[1], src[2], detc[0], detc[1], detc[2], detu[0], detu[1], detu[2], detv[0], detv[1], detv[2] ] + pg = astra.create_proj_geom('parallel3d_vec', 128, 128, Vectors) + yield pg class TestRecScale(unittest.TestCase): @@ -92,13 +126,44 @@ class TestRecScale(unittest.TestCase): pylab.show() self.assertTrue(abs(val-1.0) < TOL) + def single_test_adjoint3D(self, geom_type, proj_type): + for vg in VolumeGeometries(True, True): + for pg in ProjectionGeometries(geom_type): + for i in range(5): + X = np.random.random(astra.geom_size(vg)) + Y = np.random.random(astra.geom_size(pg)) + proj_id, fX = astra.create_sino3d_gpu(X, pg, vg) + bp_id, fTY = astra.create_backprojection3d_gpu(Y, pg, vg) + + astra.data3d.delete([proj_id, bp_id]) + + da = np.dot(fX.ravel(), Y.ravel()) + db = np.dot(X.ravel(), fTY.ravel()) + m = np.abs(da - db) + TOL = 1e-1 + if m / da >= TOL: + print(vg) + print(pg) + print(m/da, da/db, da, db) + self.assertTrue(m / da < TOL) + + + + __combinations = { - 'parallel': [ 'line', 'linear', 'distance_driven', 'strip', 'cuda' ], - 'fanflat': [ 'line_fanflat', 'strip_fanflat', 'cuda' ], - 'parallel3d': [ 'cuda3d' ], - 'cone': [ 'cuda3d' ], - } + 'parallel': [ 'line', 'linear', 'distance_driven', 'strip', 'cuda' ], + 'fanflat': [ 'line_fanflat', 'strip_fanflat', 'cuda' ], + 'parallel3d': [ 'cuda3d' ], + 'cone': [ 'cuda3d' ], +} + +__combinations_adjoint = { + 'parallel3d': [ 'cuda3d' ], + 'cone': [ 'cuda3d' ], + 'parallel3d_vec': [ 'cuda3d' ], + 'cone_vec': [ 'cuda3d' ], +} __algs = { 'SIRT': 50, 'SART': 10*180, 'CGLS': 30, @@ -137,6 +202,12 @@ for k, l in __combinations.items(): return lambda self: self.single_test(k, v, a, i) setattr(TestRecScale, 'test_' + a + '_' + k + '_' + v, f(k,v,a,i)) +for k, l in __combinations_adjoint.items(): + for v in l: + def g(k, v): + return lambda self: self.single_test_adjoint3D(k, v) + setattr(TestRecScale, 'test_adjoint_' + k + '_' + v, g(k,v)) + if __name__ == '__main__': unittest.main() |