diff options
Diffstat (limited to 'python')
-rw-r--r-- | python/astra/__init__.py | 8 | ||||
-rw-r--r-- | python/astra/astra.py | 4 | ||||
-rw-r--r-- | python/astra/astra_c.pyx | 21 |
3 files changed, 24 insertions, 9 deletions
diff --git a/python/astra/__init__.py b/python/astra/__init__.py index 10ed74d..515d9a2 100644 --- a/python/astra/__init__.py +++ b/python/astra/__init__.py @@ -39,7 +39,7 @@ from . import log from .optomo import OpTomo import os -try: - astra.set_gpu_index(int(os.environ['ASTRA_GPU_INDEX'])) -except KeyError: - pass + +if 'ASTRA_GPU_INDEX' in os.environ: + L = [ int(x) for x in os.environ['ASTRA_GPU_INDEX'].split(',') ] + astra.set_gpu_index(L) diff --git a/python/astra/astra.py b/python/astra/astra.py index 26b1ff0..9328b6b 100644 --- a/python/astra/astra.py +++ b/python/astra/astra.py @@ -49,10 +49,10 @@ def version(printToScreen=False): """ return a.version(printToScreen) -def set_gpu_index(idx): +def set_gpu_index(idx, memory=0): """Set default GPU index to use. :param idx: GPU index :type idx: :class:`int` """ - a.set_gpu_index(idx) + a.set_gpu_index(idx, memory) diff --git a/python/astra/astra_c.pyx b/python/astra/astra_c.pyx index 5075fed..65192b5 100644 --- a/python/astra/astra_c.pyx +++ b/python/astra/astra_c.pyx @@ -31,6 +31,7 @@ import six from .utils import wrap_from_bytes from libcpp.string cimport string +from libcpp.vector cimport vector from libcpp cimport bool cdef extern from "astra/Globals.h" namespace "astra": int getVersion() @@ -43,6 +44,12 @@ IF HAVE_CUDA==True: ELSE: def setGPUIndex(): pass +cdef extern from "astra/CompositeGeometryManager.h" namespace "astra": + cdef cppclass SGPUParams: + vector[int] GPUIndices + size_t memory +cdef extern from "astra/CompositeGeometryManager.h" namespace "astra::CCompositeGeometryManager": + void setGlobalGPUParams(SGPUParams&) def credits(): six.print_("""The ASTRA Toolbox has been developed at the University of Antwerp and CWI, Amsterdam by @@ -70,8 +77,16 @@ def version(printToScreen=False): else: return getVersion() -def set_gpu_index(idx): +def set_gpu_index(idx, memory=0): + import types + import collections + cdef SGPUParams params if use_cuda()==True: - ret = setGPUIndex(idx) + if not isinstance(idx, collections.Iterable) or isinstance(idx, types.StringTypes): + idx = (idx,) + params.memory = memory + params.GPUIndices = idx + setGlobalGPUParams(params) + ret = setGPUIndex(params.GPUIndices[0]) if not ret: - six.print_("Failed to set GPU " + str(idx)) + six.print_("Failed to set GPU " + str(params.GPUIndices[0])) |