From 4235e71a4de64130228097907134f78bd9e22798 Mon Sep 17 00:00:00 2001
From: Gemma Fardell <gfardell@stfc.ac.uk>
Date: Mon, 24 Jun 2019 14:44:06 +0100
Subject: Unit tests on test image loading and image size

---
 Wrappers/Python/test/test_TestData.py | 107 ++++++++++++++++++++++------------
 1 file changed, 69 insertions(+), 38 deletions(-)

(limited to 'Wrappers/Python')

diff --git a/Wrappers/Python/test/test_TestData.py b/Wrappers/Python/test/test_TestData.py
index c4a0a70..fa0b98f 100755
--- a/Wrappers/Python/test/test_TestData.py
+++ b/Wrappers/Python/test/test_TestData.py
@@ -1,10 +1,9 @@
 import numpy
 from ccpi.framework import TestData
 import os, sys
-
 sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 from testclass import CCPiTestClass
-
+import unittest
 
 class TestTestData(CCPiTestClass):
     def test_random_noise(self):
@@ -17,67 +16,99 @@ class TestTestData(CCPiTestClass):
         self.assertAlmostEqual(norm, 48.881268, places=4)
 
     def test_load_CAMERA(self):
-        loader = TestData()
-        image = loader.load(TestData.CAMERA)
 
-        if image:
-            res = True
-        else:
-            res = False
+        loader = TestData()
+        res = False
+        try:
+            image = loader.load(TestData.CAMERA)
+            if (image.shape[0] == 512) and (image.shape[1] == 512):
+                res = True
+            else:
+                print("Image dimension mismatch")
+        except FileNotFoundError:
+            print("File not found")
+        except:
+            print("Failed to load file")
 
         self.assertTrue(res)
 
+
     def test_load_BOAT(self):
         loader = TestData()
-        image = loader.load(TestData.BOAT)
-
-        if image:
-            res = True
-        else:
-            res = False
+        res = False
+        try:
+            image = loader.load(TestData.BOAT)
+            if (image.shape[0] == 512) and (image.shape[1] == 512):
+                res = True
+            else:
+                print("Image dimension mismatch")
+        except FileNotFoundError:
+            print("File not found")
+        except:
+            print("Failed to load file")
 
         self.assertTrue(res)
 
     def test_load_PEPPERS(self):
         loader = TestData()
-        image = loader.load(TestData.PEPPERS)
-
-        if image:
-            res = True
-        else:
-            res = False
+        res = False
+        try:
+            image = loader.load(TestData.PEPPERS)
+            if (image.shape[0] == 512) and (image.shape[1] == 512) and (image.shape[2] == 3):
+                res = True
+            else:
+                print("Image dimension mismatch")
+        except FileNotFoundError:
+            print("File not found")
+        except:
+            print("Failed to load file")
 
         self.assertTrue(res)
 
     def test_load_RESOLUTION_CHART(self):
         loader = TestData()
-        image = loader.load(TestData.RESOLUTION_CHART)
-
-        if image:
-            res = True
-        else:
-            res = False
+        res = False
+        try:
+            image = loader.load(TestData.RESOLUTION_CHART)
+            if (image.shape[0] == 512) and (image.shape[1] == 512):
+                res = True
+            else:
+                print("Image dimension mismatch")
+        except FileNotFoundError:
+            print("File not found")
+        except:
+            print("Failed to load file")
 
         self.assertTrue(res)
 
     def test_load_SIMPLE_PHANTOM_2D(self):
         loader = TestData()
-        image = loader.load(TestData.SIMPLE_PHANTOM_2D)
-
-        if image:
-            res = True
-        else:
-            res = False
+        res = False
+        try:
+            image = loader.load(TestData.SIMPLE_PHANTOM_2D)
+            if (image.shape[0] == 512) and (image.shape[1] == 512):
+                res = True
+            else:
+                print("Image dimension mismatch")
+        except FileNotFoundError:
+            print("File not found")
+        except:
+            print("Failed to load file")
 
         self.assertTrue(res)
 
     def test_load_SHAPES(self):
         loader = TestData()
-        image = loader.load(TestData.SHAPES)
-
-        if image:
-            res = True
-        else:
-            res = False
+        res = False
+        try:
+            image = loader.load(TestData.SHAPES)
+            if (image.shape[0] == 200) and (image.shape[1] == 300):
+                res = True
+            else:
+                print("Image dimension mismatch")
+        except FileNotFoundError:
+            print("File not found")
+        except:
+            print("Failed to load file")
 
         self.assertTrue(res)
-- 
cgit v1.2.3