Skip to content

Commit

Permalink
refactor unit tests for unet.
Browse files Browse the repository at this point in the history
  • Loading branch information
hstewart93 committed May 31, 2024
1 parent 5c2095b commit 78ba132
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions continunet/tests/test_network_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def test_build_model(self, input_shape):
test_model = self.model(input_shape).build_model()

assert test_model is not None
assert test_model.input_shape == (None, 256, 256, 1)
assert test_model.output_shape == (None, 256, 256, 1)
assert test_model.input_shape == (None, input_shape[0], input_shape[1], input_shape[2])
assert test_model.output_shape == (None, input_shape[0], input_shape[1], input_shape[2])
assert len(test_model.layers) == 49

def test_load_weights(self, trained_model, input_shape):
Expand All @@ -27,8 +27,8 @@ def test_load_weights(self, trained_model, input_shape):
test_model.load_weights(trained_model)

assert test_model is not None
assert test_model.input_shape == (None, 256, 256, 1)
assert test_model.output_shape == (None, 256, 256, 1)
assert test_model.input_shape == (None, input_shape[0], input_shape[1], input_shape[2])
assert test_model.output_shape == (None, input_shape[0], input_shape[1], input_shape[2])
assert len(test_model.layers) == 49
assert test_model.get_weights() is not None

Expand All @@ -39,7 +39,7 @@ def test_decode_image(self, grayscale_image, trained_model, input_shape):

decoded_image = test_model.decode_image()
assert decoded_image is not None
assert decoded_image.shape == (1, 256, 256, 1)
assert decoded_image.shape == (1, input_shape[0], input_shape[1], input_shape[2])

assert decoded_image.min() >= 0
assert decoded_image.max() <= 1
Expand Down

0 comments on commit 78ba132

Please sign in to comment.