From ae33f713a2dea236e28145dcd6007589feb618ed Mon Sep 17 00:00:00 2001
From: Willem Jan Palenstijn <Willem.Jan.Palenstijn@cwi.nl>
Date: Wed, 8 Feb 2017 10:47:22 +0100
Subject: Make typechecks in data3d.create more robust

---
 python/astra/data3d_c.pyx | 28 ++++++++++++++++++----------
 1 file changed, 18 insertions(+), 10 deletions(-)

diff --git a/python/astra/data3d_c.pyx b/python/astra/data3d_c.pyx
index 56247de..73e75b9 100644
--- a/python/astra/data3d_c.pyx
+++ b/python/astra/data3d_c.pyx
@@ -77,12 +77,15 @@ def create(datatype,geometry,data=None, link=False):
         cdef MemHandle3D hnd
 
     if link:
-        if isinstance(data, GPULink):
+        if isinstance(data, np.ndarray):
+            if data.shape != geom_size(geometry):
+                raise Exception("The dimensions of the data do not match those specified in the geometry.")
+        elif isinstance(data, GPULink):
             s = geom_size(geometry)
             if geom_size(geometry) != ( data.z, data.y, data.x ):
                 raise Exception("The dimensions of the data do not match those specified in the geometry.")
-        elif data.shape!=geom_size(geometry):
-            raise Exception("The dimensions of the data do not match those specified in the geometry.")
+        else:
+            raise TypeError("data should be a numpy.ndarray or a GPULink object")
 
     if datatype == '-vol':
         cfg = utils.dictToConfig(six.b('VolumeGeometry'), geometry)
@@ -92,7 +95,10 @@ def create(datatype,geometry,data=None, link=False):
             del pGeometry
             raise Exception('Geometry class not initialized.')
         if link:
-            if isinstance(data, GPULink):
+            if isinstance(data, np.ndarray):
+                pCustom = <CFloat32CustomMemory*> new CFloat32CustomPython(data)
+                pDataObject3D = <CFloat32Data3D * > new CFloat32VolumeData3DMemory(pGeometry, pCustom)
+            elif isinstance(data, GPULink):
                 IF HAVE_CUDA==True:
                     s = geom_size(geometry)
                     hnd = wrapHandle(<float*>PyLong_AsVoidPtr(data.ptr), data.x, data.y, data.z, data.pitch/4)
@@ -100,8 +106,7 @@ def create(datatype,geometry,data=None, link=False):
                 ELSE:
                     raise NotImplementedError("CUDA support is not enabled in ASTRA")
             else:
-                pCustom = <CFloat32CustomMemory*> new CFloat32CustomPython(data)
-                pDataObject3D = <CFloat32Data3D * > new CFloat32VolumeData3DMemory(pGeometry, pCustom)
+                raise TypeError("data should be a numpy.ndarray or a GPULink object")
         else:
             pDataObject3D = <CFloat32Data3D * > new CFloat32VolumeData3DMemory(pGeometry)
         del cfg
@@ -125,7 +130,10 @@ def create(datatype,geometry,data=None, link=False):
             del ppGeometry
             raise Exception('Geometry class not initialized.')
         if link:
-            if isinstance(data, GPULink):
+            if isinstance(data, np.ndarray):
+                pCustom = <CFloat32CustomMemory*> new CFloat32CustomPython(data)
+                pDataObject3D = <CFloat32Data3D * > new CFloat32ProjectionData3DMemory(ppGeometry, pCustom)
+            elif isinstance(data, GPULink):
                 IF HAVE_CUDA==True:
                     s = geom_size(geometry)
                     hnd = wrapHandle(<float*>PyLong_AsVoidPtr(data.ptr), data.x, data.y, data.z, data.pitch/4)
@@ -133,8 +141,7 @@ def create(datatype,geometry,data=None, link=False):
                 ELSE:
                     raise NotImplementedError("CUDA support is not enabled in ASTRA")
             else:
-                pCustom = <CFloat32CustomMemory*> new CFloat32CustomPython(data)
-                pDataObject3D = <CFloat32Data3D * > new CFloat32ProjectionData3DMemory(ppGeometry, pCustom)
+                raise TypeError("data should be a numpy.ndarray or a GPULink object")
         else:
             pDataObject3D = <CFloat32Data3DMemory * > new CFloat32ProjectionData3DMemory(ppGeometry)
         del ppGeometry
@@ -146,7 +153,8 @@ def create(datatype,geometry,data=None, link=False):
         del pDataObject3D
         raise Exception("Couldn't initialize data object.")
 
-    if not link: fillDataObject(dynamic_cast_mem(pDataObject3D), data)
+    if not link:
+        fillDataObject(dynamic_cast_mem(pDataObject3D), data)
 
     return man3d.store(<CFloat32Data3D*>pDataObject3D)
 
-- 
cgit v1.2.3