Skip to content

Commit

Permalink
Merge pull request #15 from hstewart93/network-models
Browse files Browse the repository at this point in the history
Network models
  • Loading branch information
hstewart93 authored Jun 5, 2024
2 parents 9e4b08e + 78ba132 commit 7169fba
Show file tree
Hide file tree
Showing 7 changed files with 272 additions and 1 deletion.
Empty file added continunet/network/__init__.py
Empty file.
Binary file added continunet/network/trained_model.h5
Binary file not shown.
128 changes: 128 additions & 0 deletions continunet/network/unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""UNet model for image segmentation in keras."""

import numpy as np

from keras.layers import (
Activation,
BatchNormalization,
Concatenate,
Conv2D,
Conv2DTranspose,
Dropout,
Input,
MaxPooling2D,
)
from keras.models import Model
from tensorflow.keras.optimizers import Adam


class Unet:
"""UNet model for image segmentation."""

def __init__(
self,
input_shape,
filters=16,
dropout=0.05,
batch_normalisation=True,
trained_model=None,
image=None,
layers=4,
output_activation="sigmoid",
):
self.input_shape = input_shape
self.filters = filters
self.dropout = dropout
self.batch_normalisation = batch_normalisation
self.trained_model = trained_model
self.image = image
self.layers = layers
self.output_activation = output_activation

def convolutional_block(self, input_tensor, filters, kernel_size=3):
"""Convolutional block for UNet."""
convolutional_layer = Conv2D(
filters=filters,
kernel_size=(kernel_size, kernel_size),
kernel_initializer="he_normal",
padding="same",
)
batch_normalisation_layer = BatchNormalization()
relu_layer = Activation("relu")

if self.batch_normalisation:
return relu_layer(batch_normalisation_layer(convolutional_layer(input_tensor)))
return relu_layer(convolutional_layer(input_tensor))

def encoding_block(self, input_tensor, filters, kernel_size=3):
"""Encoding block for UNet."""
convolutional_block = self.convolutional_block(input_tensor, filters, kernel_size)
max_pooling_layer = MaxPooling2D((2, 2), padding="same")
dropout_layer = Dropout(self.dropout)

return convolutional_block, dropout_layer(max_pooling_layer(convolutional_block))

def decoding_block(self, input_tensor, concat_tensor, filters, kernel_size=3):
"""Decoding block for UNet."""
transpose_convolutional_layer = Conv2DTranspose(
filters, (3, 3), strides=(2, 2), padding="same"
)
skip_connection = Concatenate()(
[transpose_convolutional_layer(input_tensor), concat_tensor]
)
dropout_layer = Dropout(self.dropout)
return self.convolutional_block(dropout_layer(skip_connection), filters, kernel_size)

def build_model(self):
"""Build the UNet model."""
input_image = Input(self.input_shape, name="img")
current = input_image

# Encoder Path
convolutional_tensors = []
for layer in range(self.layers):
convolutional_tensor, current = self.encoding_block(
current, self.filters * (2 ** layer)
)
convolutional_tensors.append((convolutional_tensor))

# Latent Convolutional Block
latent_convolutional_tensor = self.convolutional_block(
current, filters=self.filters * 2 ** self.layers
)

# Decoder Path
current = latent_convolutional_tensor
for layer in reversed(range(self.layers)):
current = self.decoding_block(
current, convolutional_tensors[layer], self.filters * (2 ** layer)
)

outputs = Conv2D(1, (1, 1), activation=self.output_activation)(current)
model = Model(inputs=[input_image], outputs=[outputs])
return model

def compile_model(self):
"""Compile the UNet model."""
model = self.build_model()
model.compile(
optimizer=Adam(), loss="binary_crossentropy", metrics=["accuracy", "iou_score"]
)
return model

def decode_image(self):
"""Returns images decoded by a trained model."""
model = self.compile_model()
if self.trained_model is None or self.image is None:
raise ValueError("Trained model and image arguments are required to decode image.")
if isinstance(self.image, np.ndarray) is False:
raise TypeError("Image must be a numpy array.")
if len(self.image.shape) != 4:
raise ValueError("Image must be 4D numpy array for example (1, 256, 256, 1).")
if self.image.shape[3] != 1:
raise ValueError("Input image must be grayscale.")
if self.image.shape[0] % 256 != 0 and self.image.shape[1] % 256 != 0:
raise ValueError("Image shape should be divisible by 256.")

model.load_weights(self.trained_model)
return model.predict(self.image)
51 changes: 51 additions & 0 deletions continunet/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,54 @@ def fits_file(tmp_path):

yield path
path.unlink()


@pytest.fixture
def trained_model():
"""Fixture for a trained model."""
return "continunet/network/trained_model.h5"


@pytest.fixture
def grayscale_image():
"""Generate a random 256x256x1 image array."""
image = np.random.randint(0, 1, size=(256, 256, 1), dtype=np.uint8)
return image.reshape((1, 256, 256, 1))


@pytest.fixture
def colour_image():
"""Generate a random 256x256x3 image array."""
image = np.random.randint(0, 1, size=(256, 256, 3), dtype=np.uint8)
return image.reshape((1, 256, 256, 3))


@pytest.fixture
def invalid_image():
"""Generate an invalid shape image array, not divisble by 256."""
image = np.random.randint(0, 1, size=(255, 255, 1), dtype=np.uint8)
return image.reshape((1, 255, 255, 1))


@pytest.fixture
def input_shape():
"""Fixture for the input shape."""
return (256, 256, 1)


@pytest.fixture
def grayscale_image_input_shape(grayscale_image):
"""Fixture for the input shape."""
return grayscale_image.shape[1:]


@pytest.fixture
def colour_image_input_shape(colour_image):
"""Fixture for the input shape."""
return colour_image.shape[1:]


@pytest.fixture
def invalid_image_input_shape(invalid_image):
"""Fixture for the input shape."""
return invalid_image.shape[1:]
File renamed without changes.
91 changes: 91 additions & 0 deletions continunet/tests/test_network_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Tests for the UNet model."""

import pytest

from continunet.network.unet import Unet


class TestUnet:
"""Tests for the UNet model."""

model = Unet

def test_build_model(self, input_shape):
"""Test the compile_model method"""

test_model = self.model(input_shape).build_model()

assert test_model is not None
assert test_model.input_shape == (None, input_shape[0], input_shape[1], input_shape[2])
assert test_model.output_shape == (None, input_shape[0], input_shape[1], input_shape[2])
assert len(test_model.layers) == 49

def test_load_weights(self, trained_model, input_shape):
"""Test the load_weights method"""

test_model = self.model(input_shape).build_model()
test_model.load_weights(trained_model)

assert test_model is not None
assert test_model.input_shape == (None, input_shape[0], input_shape[1], input_shape[2])
assert test_model.output_shape == (None, input_shape[0], input_shape[1], input_shape[2])
assert len(test_model.layers) == 49
assert test_model.get_weights() is not None

def test_decode_image(self, grayscale_image, trained_model, input_shape):
"""Test the decode_image method"""

test_model = self.model(input_shape, image=grayscale_image, trained_model=trained_model)

decoded_image = test_model.decode_image()
assert decoded_image is not None
assert decoded_image.shape == (1, input_shape[0], input_shape[1], input_shape[2])

assert decoded_image.min() >= 0
assert decoded_image.max() <= 1

def test_decode_image_invalid_image_type(self, trained_model, input_shape):
"""Test the decode_image method with invalid image type"""

test_model = self.model(input_shape, image="invalid", trained_model=trained_model)

with pytest.raises(TypeError):
test_model.decode_image()

def test_decode_image_invalid_input_shape(
self, invalid_image, trained_model, invalid_image_input_shape
):
"""Test the decode_image method with invalid input shape"""

test_model = self.model(
invalid_image_input_shape, image=invalid_image, trained_model=trained_model
)

with pytest.raises(ValueError):
test_model.decode_image()

def test_decode_image_no_trained_model(self, grayscale_image, grayscale_image_input_shape):
"""Test the decode_image method with no trained model"""

test_model = self.model(grayscale_image_input_shape, image=grayscale_image)

with pytest.raises(ValueError):
test_model.decode_image()

def test_decode_image_no_image(self, trained_model, input_shape):
"""Test the decode_image method with no image"""

test_model = self.model(input_shape, trained_model=trained_model)

with pytest.raises(ValueError):
test_model.decode_image()

def test_decode_image_colour_image(self, trained_model, colour_image, colour_image_input_shape):
"""Test the decode_image method with a colour image"""

test_model = self.model(
colour_image_input_shape, image=colour_image, trained_model=trained_model
)

with pytest.raises(ValueError):
test_model.decode_image()
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ requires-python = ">=3.8"
dependencies = [
"astropy>=6.0",
"numpy>=1.26",
"keras>=3.3",
"tensorflow>=2.16",
]

[project.optional-dependencies]
Expand All @@ -26,7 +28,6 @@ dev = [
"flake8",
"pre-commit",
"pytest",
"ipython",
]
ci = [
"twine",
Expand Down

0 comments on commit 7169fba

Please sign in to comment.