From 78ba1323be04075ae2be1c06daec4dec54a6d660 Mon Sep 17 00:00:00 2001 From: hstewart93 Date: Fri, 31 May 2024 14:10:39 +0100 Subject: [PATCH] refactor unit tests for unet. --- continunet/tests/test_network_unet.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/continunet/tests/test_network_unet.py b/continunet/tests/test_network_unet.py index 842319c..83b22f0 100644 --- a/continunet/tests/test_network_unet.py +++ b/continunet/tests/test_network_unet.py @@ -16,8 +16,8 @@ def test_build_model(self, input_shape): test_model = self.model(input_shape).build_model() assert test_model is not None - assert test_model.input_shape == (None, 256, 256, 1) - assert test_model.output_shape == (None, 256, 256, 1) + assert test_model.input_shape == (None, input_shape[0], input_shape[1], input_shape[2]) + assert test_model.output_shape == (None, input_shape[0], input_shape[1], input_shape[2]) assert len(test_model.layers) == 49 def test_load_weights(self, trained_model, input_shape): @@ -27,8 +27,8 @@ def test_load_weights(self, trained_model, input_shape): test_model.load_weights(trained_model) assert test_model is not None - assert test_model.input_shape == (None, 256, 256, 1) - assert test_model.output_shape == (None, 256, 256, 1) + assert test_model.input_shape == (None, input_shape[0], input_shape[1], input_shape[2]) + assert test_model.output_shape == (None, input_shape[0], input_shape[1], input_shape[2]) assert len(test_model.layers) == 49 assert test_model.get_weights() is not None @@ -39,7 +39,7 @@ def test_decode_image(self, grayscale_image, trained_model, input_shape): decoded_image = test_model.decode_image() assert decoded_image is not None - assert decoded_image.shape == (1, 256, 256, 1) + assert decoded_image.shape == (1, input_shape[0], input_shape[1], input_shape[2]) assert decoded_image.min() >= 0 assert decoded_image.max() <= 1