Skip to content

Commit

Permalink
Merge pull request #23 from hstewart93/pre-processing
Browse files Browse the repository at this point in the history
Pre processing
  • Loading branch information
hstewart93 authored Jun 18, 2024
2 parents 7169fba + fbfcabd commit 8ccccab
Show file tree
Hide file tree
Showing 10 changed files with 1,260 additions and 93 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Source finding package for radio continuum data powered by U-Net segmentation algorithm.

## Installation
The project is available on PyPI, to install latest stable release use:
The project is available on [PyPI](https://pypi.org/project/continunet/), to install latest stable release use:

```pip install continunet```

Expand Down
903 changes: 903 additions & 0 deletions continunet/example_image.fits

Large diffs are not rendered by default.

File renamed without changes.
47 changes: 24 additions & 23 deletions continunet/images/fits.py → continunet/image/fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
class FitsImage(ABC):
"""Abstract model for an image imported from FITS format."""

def __init__(self, path, data=None, header=None, wcs=None, beam_size=None, shape=None):
def __init__(self, path):
self.path = path
self.data = data
self.header = header
self.wcs = wcs
self.beam_size = beam_size
self.shape = shape
self.data = None
self.header = None
self.wcs = None
self.beam_size = None
self.shape = None

self.load()

Expand All @@ -27,26 +27,27 @@ def load(self):

with fits.open(self.path) as fits_object:

if self.data is None:
self.data = fits_object[0].data
# Convert byte ordering to little-endian as FITS is stored as big-endian
# and is incompatible with torch
self.data = self.data.astype(np.float32)

if self.header is None:
self.header = fits_object[0].header

if self.wcs is None:
self.wcs = WCS(self.header)

if self.beam_size is None:
self.beam_size = self.get_beam_size()

if self.shape is None:
self.shape = self.data.shape
self.data = fits_object[0].data
# Convert byte ordering to little-endian as FITS is stored as big-endian
# and is incompatible with torch
self.data = self.data.astype(np.float32)
self.header = fits_object[0].header
self.wcs = WCS(self.header)
if not self.wcs.has_celestial:
raise ValueError("WCS object does not contain celestial information.")
self.beam_size = self.get_beam_size()
self.shape = self.data.shape
self.check_header()

return self

def check_header(self):
"""Check the header contains required information."""
required_keys = ["CRPIX1", "CRPIX2"]
for key in required_keys:
if key not in self.header:
raise KeyError(f"Header does not contain '{key}' (image information).")

def get_beam_size(self):
"""Return the beam size in arcseconds."""
if "BMAJ" not in self.header:
Expand Down
61 changes: 61 additions & 0 deletions continunet/image/pre_processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Pre-processing module for images."""

import numpy as np

from astropy.nddata import Cutout2D

from continunet.image.fits import ImageSquare


class PreProcessor:
"""Pre-process image data for inference."""

def __init__(self, image: object, layers: int = 4):
if not isinstance(image, ImageSquare):
raise ValueError("Image must be an ImageSquare object.")
self.image = image
self.layers = layers
self.data = self.image.data
self.wcs = self.image.wcs

def clean_nans(self):
"""Check for NaNs in the image data."""
if np.isnan(self.data).all():
raise ValueError("Image data contains only NaNs.")
if np.isnan(self.data).any():
self.data = np.nan_to_num(self.data, False)
return self.data

def reshape(self):
"""Reshape the image data for the network. Shape must be divisible by 2 ** n layers."""

self.data = np.squeeze(self.data)
self.wcs = self.wcs.celestial
if not isinstance(self.data.shape[0] / 2 ** self.layers, int) or not isinstance(
self.data.shape[1] / 2 ** self.layers, int
):
minimum_size = self.data.shape[0] // (2 ** self.layers) * (2 ** self.layers)
print(f"Trimming image to fit network from {self.data.shape[0]} to {minimum_size}.")
trimmed_image = Cutout2D(
self.data,
(self.image.header["CRPIX1"], self.image.header["CRPIX2"]),
(minimum_size, minimum_size),
wcs=self.wcs,
)
self.data = trimmed_image.data
self.wcs = trimmed_image.wcs

self.data = self.data.reshape(1, *self.data.shape, 1)
return self.data

def normalise(self):
"""Normalise the image data."""
self.data = (self.data - np.min(self.data)) / (np.max(self.data) - np.min(self.data))
return self.data

def process(self):
"""Process the image data."""
self.clean_nans()
self.reshape()
self.normalise()
return self.data
76 changes: 57 additions & 19 deletions continunet/network/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,46 @@ class Unet:

def __init__(
self,
input_shape,
filters=16,
dropout=0.05,
batch_normalisation=True,
trained_model=None,
image=None,
layers=4,
output_activation="sigmoid",
input_shape: tuple,
filters: int = 16,
dropout: float = 0.05,
batch_normalisation: bool = True,
trained_model: str = None,
image: np.ndarray = None,
layers: int = 4,
output_activation: str = "sigmoid",
model: Model = None,
reconstructed: np.ndarray = None,
):
"""
Initialise the UNet model.
Parameters
----------
input_shape : tuple
The shape of the input image.
filters : int
The number of filters to use in the convolutional layers, default is 16.
dropout : float
The dropout rate, default is 0.05.
batch_normalisation : bool
Whether to use batch normalisation, default is True.
trained_model : str
The path to a trained model.
image : np.ndarray
The image to decode. Image must be 2D given as 4D numpy array, e.g. (1, 256, 256, 1).
Image must be grayscale, e.g. not (1, 256, 256, 3). Image array row columns must
be divisible by 2^layers, e.g. 256 % 2^4 == 0.
layers : int
The number of encoding and decoding layers, default is 4.
output_activation : str
The activation function for the output layer, either sigmoid or softmax.
Default is sigmoid.
model : keras.models.Model
A pre-built model, populated by the build_model method.
reconstructed : np.ndarray
The reconstructed image, created by the decode_image method.
"""
self.input_shape = input_shape
self.filters = filters
self.dropout = dropout
Expand All @@ -38,6 +69,10 @@ def __init__(
self.image = image
self.layers = layers
self.output_activation = output_activation
self.model = model
self.reconstructed = reconstructed

self.model = self.build_model()

def convolutional_block(self, input_tensor, filters, kernel_size=3):
"""Convolutional block for UNet."""
Expand Down Expand Up @@ -78,7 +113,7 @@ def build_model(self):
input_image = Input(self.input_shape, name="img")
current = input_image

# Encoder Path
# Encoding Path
convolutional_tensors = []
for layer in range(self.layers):
convolutional_tensor, current = self.encoding_block(
Expand All @@ -91,7 +126,7 @@ def build_model(self):
current, filters=self.filters * 2 ** self.layers
)

# Decoder Path
# Decoding Path
current = latent_convolutional_tensor
for layer in reversed(range(self.layers)):
current = self.decoding_block(
Expand All @@ -104,15 +139,13 @@ def build_model(self):

def compile_model(self):
"""Compile the UNet model."""
model = self.build_model()
model.compile(
self.model.compile(
optimizer=Adam(), loss="binary_crossentropy", metrics=["accuracy", "iou_score"]
)
return model
return self.model

def decode_image(self):
"""Returns images decoded by a trained model."""
model = self.compile_model()
if self.trained_model is None or self.image is None:
raise ValueError("Trained model and image arguments are required to decode image.")
if isinstance(self.image, np.ndarray) is False:
Expand All @@ -121,8 +154,13 @@ def decode_image(self):
raise ValueError("Image must be 4D numpy array for example (1, 256, 256, 1).")
if self.image.shape[3] != 1:
raise ValueError("Input image must be grayscale.")
if self.image.shape[0] % 256 != 0 and self.image.shape[1] % 256 != 0:
raise ValueError("Image shape should be divisible by 256.")

model.load_weights(self.trained_model)
return model.predict(self.image)
if (
self.image.shape[0] % 2 ** self.layers != 0
and self.image.shape[1] % 2 ** self.layers != 0
):
raise ValueError("Image shape should be divisible by 2^layers.")

self.model = self.compile_model()
self.model.load_weights(self.trained_model)
self.reconstructed = self.model.predict(self.image)
return self.reconstructed
Loading

0 comments on commit 8ccccab

Please sign in to comment.