From 35052afe6c27119b7d1d58a035d17be043d686f2 Mon Sep 17 00:00:00 2001 From: Willem Jan Palenstijn Date: Tue, 2 Apr 2019 15:58:14 +0200 Subject: Add test for reconstruction scaling --- tests/python/test_rec_scaling.py | 79 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 tests/python/test_rec_scaling.py (limited to 'tests') diff --git a/tests/python/test_rec_scaling.py b/tests/python/test_rec_scaling.py new file mode 100644 index 0000000..33d09f9 --- /dev/null +++ b/tests/python/test_rec_scaling.py @@ -0,0 +1,79 @@ +import numpy as np +import unittest +import astra +import math +import pylab + +DISPLAY=False + +def VolumeGeometries(): + for s in [0.8, 1.0, 1.25]: + yield astra.create_vol_geom(128, 128, -64*s, 64*s, -64*s, 64*s) + +def ProjectionGeometries(type): + if type == 'parallel': + for dU in [0.8, 1.0, 1.25]: + yield astra.create_proj_geom('parallel', dU, 256, np.linspace(0,np.pi,180,False)) + elif type == 'fanflat': + for dU in [0.8, 1.0, 1.25]: + for src in [500, 1000]: + for det in [0, 250, 500]: + yield astra.create_proj_geom('fanflat', dU, 256, np.linspace(0,2*np.pi,180,False), src, det) + + +class Test2DRecScale(unittest.TestCase): + def single_test(self, geom_type, proj_type, alg, iters): + for vg in VolumeGeometries(): + for pg in ProjectionGeometries(geom_type): + vol = np.zeros((128,128)) + vol[50:70,50:70] = 1 + proj_id = astra.create_projector(proj_type, pg, vg) + sino_id, sinogram = astra.create_sino(vol, proj_id) + rec_id = astra.data2d.create('-vol', vg, 0.0 if 'EM' not in alg else 1.0) + + cfg = astra.astra_dict(alg) + cfg['ReconstructionDataId'] = rec_id + cfg['ProjectionDataId'] = sino_id + cfg['ProjectorId'] = proj_id + alg_id = astra.algorithm.create(cfg) + + astra.algorithm.run(alg_id, iters) + rec = astra.data2d.get(rec_id) + astra.astra.delete([sino_id, alg_id, alg_id, proj_id]) + val = np.sum(rec[55:65,55:65]) / 100. + TOL = 5e-2 + if DISPLAY and abs(val-1.0) >= TOL: + print(geom_type, proj_type, alg, vg, pg) + print(val) + pylab.gray() + pylab.imshow(rec) + pylab.show() + self.assertTrue(abs(val-1.0) < TOL) + + +__combinations = { + 'parallel': [ 'line', 'linear', 'distance_driven', 'strip', 'cuda' ], + 'fanflat': [ 'line_fanflat', 'strip_fanflat', 'cuda' ], +# 'fanflat': [ 'cuda' ], + } + +__algs = { + 'SIRT': 50, 'SART': 10*180, 'CGLS': 30, 'FBP': 1 +} + +__algs_CUDA = { + 'SIRT_CUDA': 50, 'SART_CUDA': 10*180, 'CGLS_CUDA': 30, 'EM_CUDA': 50, + 'FBP_CUDA': 1 +} + +for k, l in __combinations.items(): + for v in l: + A = __algs if v != 'cuda' else __algs_CUDA + for a, i in A.items(): + def f(k, v, a, i): + return lambda self: self.single_test(k, v, a, i) + setattr(Test2DRecScale, 'test_' + a + '_' + k + '_' + v, f(k,v,a,i)) + +if __name__ == '__main__': + unittest.main() + -- cgit v1.2.3