Merge branch 'main' into unet-pytorch
hstewart93 committed Jul 5, 2024
2 parents 34e2de0 + 409314e commit 2b524b6
Showing 17 changed files with 1,861 additions and 50 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/publish.yml
@@ -37,5 +37,3 @@ jobs:
files: dist/*
name: ${{ steps.branch_info.outputs.TAG }}
tag_name: ${{ steps.branch_info.outputs.TAG }}


24 changes: 24 additions & 0 deletions .github/workflows/pytest.yml
@@ -0,0 +1,24 @@
name: Pytest

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
pytest:
name: Run tests
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[ci]
- name: Run tests
run: |
pytest
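The same checks can be reproduced locally before pushing; a minimal sketch, assuming a clone of the repository and a recent Python (CI uses 3.12):

```bash
# Mirror the CI job locally
python -m venv venv
source venv/bin/activate
python -m pip install --upgrade pip
pip install -e .[ci]
pytest
```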
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -0,0 +1 @@
include continunet/network/trained_model.h5
112 changes: 104 additions & 8 deletions README.md
@@ -1,27 +1,44 @@
# ContinUNet
[![Pytest](https://github.com/hstewart93/continunet/actions/workflows/pytest.yml/badge.svg)](https://github.com/hstewart93/continunet/actions/workflows/pytest.yml)

Source finding package for radio continuum data, powered by a U-Net segmentation algorithm.

- [Paper](https://academic.oup.com/rasti/article/3/1/315/7685538?utm_source=advanceaccess&utm_campaign=rasti&utm_medium=email#supplementary-data)
- [Installation](#installation)
- [Developer Installation](#developer-installation)
- [Quickstart](#quickstart)
- [Example Notebook](https://github.com/hstewart93/continunet/tree/finder/continunet/user_example.ipynb)
- [Training Dataset](https://www.kaggle.com/datasets/harrietstewart/continunet)
- [Next Release](#development)

## Installation
The project is available on PyPI, to install latest stable release use:
The project is available on [PyPI](https://pypi.org/project/continunet/). To install the latest stable release, use:

```pip install continunet```
```bash
pip install continunet
```

To install the version in development, use:

```pip install git+https://github.com/hstewart93/continunet```
```bash
pip install git+https://github.com/hstewart93/continunet
```

**ContinUNet has a `Python 3.9` minimum requirement.**

## Developer Installation
If you want to contribute to the repository, install as follows:

Once you have cloned down this repository using `git clone` cd into the app directory eg.
Once you have cloned this repository using `git clone`, cd into the app directory:

```
```bash
git clone [email protected]:hstewart93/continunet.git
cd continunet
```

Create a virtual environment for development, if you are using bash:

```
```bash
python3 -m venv venv
source venv/bin/activate
pip install -e .[dev,ci]
@@ -31,10 +48,89 @@ To exit the virtual environment use `deactivate`.

This project uses the black auto-formatter, which can be run on `git commit` along with flake8 if you install pre-commit. To do this, run the following in your terminal from within your virtual environment.

```
```bash
pre-commit install
```

Now pre-commit hooks should run on `git commit`.

The run the test suite use `pytest`.
To run the test suite use `pytest`.

## Quickstart
The package currently supports `.FITS` images. To perform source finding, you can import the `Finder` class from the `finder` module:

```python
from continunet.finder import Finder
```

Load your image file:

```python
finder = Finder("<filepath>")
```
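For instance, the example image added in this commit can be used (a sketch; the path assumes you are working from a cloned checkout rather than an installed package):

```python
finder = Finder("continunet/example_image.fits")
```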

To produce a source catalogue and populate the `Finder` instance:

```python
sources = finder.find()
```

If you want to calculate the model map and residuals image as part of source finding, use the `Finder.find()` method with `generate_maps=True`:

```python
sources = finder.find(generate_maps=True)
model_map = finder.model_map
residuals = finder.residuals
```

Alternatively, manually calculate the model map and residuals image using:

```python
model_map = finder.get_model_map()
residuals = finder.get_residuals()
```

Useful available attributes of the `Finder` object are:
```python
finder.sources # cleaned source catalogue
finder.reconstructed_image # predicted image reconstructed by unet module
finder.segmentation_map # predicted segmentation map
finder.model_map # model map of cleaned predicted sources
finder.residuals # residual image as numpy array
finder.raw_sources # sources from labelled segmentation map before cleaning
```
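For a quick visual check of these outputs after running `finder.find(generate_maps=True)`, something like the following works (a sketch, assuming matplotlib is installed; `np.squeeze` guards against any singleton axes):

```python
import matplotlib.pyplot as plt
import numpy as np

panels = [
    ("Segmentation map", finder.segmentation_map),
    ("Model map", finder.model_map),
    ("Residuals", finder.residuals),
]
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for ax, (title, image) in zip(axes, panels):
    ax.imshow(np.squeeze(image), origin="lower")  # drop singleton axes before plotting
    ax.set_title(title)
plt.show()
```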

Export the source catalogue using `finder.export_sources`, as `.csv` by default or as `.FITS` by setting `export_fits=True`:

```python
finder.export_sources("<filepath>", export_fits=<Boolean>)
```

Source parameters extracted are:

| **Parameter** | **Description** |
|----------------------------|--------------------------------------------------------------------------------------------------------|
| `x_location_original` | x coordinate of the source from the cutout used for inference |
| `y_location_original` | y coordinate of the source from the cutout used for inference |
| `orientation` | orientation of source ellipse in radians |
| `major_axis` | major axis of source ellipse |
| `minor_axis` | minor axis of source ellipse |
| `flux_density_uncorrected` | total intensity of the segmented source region before beam corrections applied |
| `label` | class label in predicted segmentation map |
| `x_location` | x coordinate of the source in the original input image dimensions |
| `y_location` | y coordinate of the source in the original input image dimensions |
| `right_ascension` | RA coordinate of the source in the original input image dimensions |
| `declination` | Dec coordinate of the source in the original input image dimensions |
| `area` | area of source ellipse |
| `position_angle` | position angle of source ellipse in degrees |
| `correction_factor` | correction factor applied to flux density measurement to account for undersampling of synthesised beam |
| `flux_density` | corrected flux density |
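
The catalogue appears to be held as a pandas DataFrame (`export_sources` calls `Table.from_pandas` and `to_csv`), so the usual DataFrame operations apply. A sketch, using the column names listed above; the RA/Dec limits are placeholder values:

```python
# Ten brightest sources by corrected flux density
brightest = finder.sources.sort_values("flux_density", ascending=False).head(10)
print(brightest[["x_location", "y_location", "flux_density"]])

# Sources inside a small sky region
region = finder.sources[
    finder.sources["right_ascension"].between(150.0, 151.0)
    & finder.sources["declination"].between(2.0, 3.0)
]
```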

## Development
ContinUNet is subject to ongoing development. To see the backlog of features and bug fixes, please go to the [project board](https://github.com/users/hstewart93/projects/4/views/1). Please raise any feature requests or bugs as [issues](https://github.com/hstewart93/continunet/issues).

The following features will be added in the next release:

1. Exporting processed images to `.npy` and `.FITS` [(#33)](https://github.com/hstewart93/continunet/issues/33)
2. Inference for non-square images [(#27)](https://github.com/hstewart93/continunet/issues/27)
3. Taking cutout of `ImageSquare` object before inference [(#28)](https://github.com/hstewart93/continunet/issues/28)
13 changes: 13 additions & 0 deletions continunet/constants.py
@@ -0,0 +1,13 @@
"""Constants for ContinuNet."""


TRAINED_MODEL = "continunet/network/trained_model.h5"

# ANSI escape sequences
RED = "\033[31m"
GREEN = "\033[32m"
YELLOW = "\033[33m"
BLUE = "\033[34m"
MAGENTA = "\033[35m"
CYAN = "\033[36m"
RESET = "\033[0m"
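These are plain ANSI escape sequences used to colour terminal output elsewhere in the package (for example the `GREEN`/`RESET` pair in `finder.py`); a minimal usage sketch:

```python
from continunet.constants import GREEN, RESET

# Prints in green on terminals that support ANSI escape codes.
print(f"{GREEN}ContinUNet finished successfully.{RESET}")
```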
Binary file added continunet/example_image.fits
87 changes: 87 additions & 0 deletions continunet/finder.py
@@ -0,0 +1,87 @@
"""Compile ContinUNet modules into Finder class for source finding."""

import importlib.resources
import time

from astropy.table import Table

from continunet.constants import GREEN, RESET
from continunet.image.fits import ImageSquare
from continunet.image.processing import PreProcessor, PostProcessor
from continunet.network.unet import Unet


class Finder:
"""Class for source finding in radio continuum images."""

def __init__(self, image: str, layers: int = 4):
"""Initialise the Finder class.
Parameters
----------
image : str
The path to the FITS image.
layers : int
The number of encoding and decoding layers in the U-Net model.
Layers is set by default to 4, and cannot currently be changed.
"""
if not image.endswith(".fits"):
raise ValueError("File must be a .fits file.")
self.image = image
if layers != 4:
raise ValueError("Number of layers must be 4.")
self.layers = layers
self.image_object = None
self.sources = None
self.reconstructed_image = None
self.post_processor = None
self.segmentation_map = None
self.model_map = None
self.residuals = None
self.raw_sources = None

def find(self, generate_maps=False, threshold="default"):
"""Find sources in a continuum image."""
start_time = time.time()
# Load image
self.image_object = ImageSquare(self.image)

# Pre-process image
pre_processor = PreProcessor(self.image_object, self.layers)
data = pre_processor.process()

# Run U-Net
with importlib.resources.path("continunet.network", "trained_model.h5") as path:
unet = Unet(data.shape[1:4], trained_model=path, image=data, layers=self.layers)
self.reconstructed_image = unet.decode_image()

# Post-process reconstructed image
self.post_processor = PostProcessor(unet.reconstructed, pre_processor, threshold=threshold)
self.sources = self.post_processor.get_sources()
self.segmentation_map = self.post_processor.segmentation_map
self.raw_sources = self.post_processor.raw_sources

end_time = time.time()
print(
f"{GREEN}ContinUNet found {len(self.sources)} sources "
f"in {(end_time - start_time):.2f} seconds.{RESET}"
)

if generate_maps:
self.model_map = self.post_processor.get_model_map()
self.residuals = self.post_processor.get_residuals()
self.segmentation_map = self.post_processor.segmentation_map

return self.sources

def export_sources(self, path: str, export_fits=False):
"""Export source catalogue to a directory. Use export_fits=True to save as FITS."""
if self.sources is None:
raise ValueError("No sources to export.")
if export_fits:
table = Table.from_pandas(self.sources)
table.write(path, format="fits", overwrite=True)
return self

self.sources.to_csv(path)
return self
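
Taken together, a minimal end-to-end run of the class above might look like this (a sketch based only on the interface shown here; the file paths are placeholders):

```python
from continunet.finder import Finder

finder = Finder("observation.fits")        # input must be a .fits file
sources = finder.find(generate_maps=True)  # also populates model_map and residuals
finder.export_sources("sources.csv")       # CSV by default; pass export_fits=True for FITS
```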
File renamed without changes.
50 changes: 27 additions & 23 deletions continunet/images/fits.py → continunet/image/fits.py
Expand Up @@ -6,47 +6,51 @@
from astropy.io import fits
from astropy.wcs import WCS

from continunet.constants import CYAN, RESET


class FitsImage(ABC):
"""Abstract model for an image imported from FITS format."""

def __init__(self, path, data=None, header=None, wcs=None, beam_size=None, shape=None):
def __init__(self, path):
self.path = path
self.data = data
self.header = header
self.wcs = wcs
self.beam_size = beam_size
self.shape = shape
self.data = None
self.header = None
self.wcs = None
self.beam_size = None
self.shape = None

self.load()

def load(self):
"""Load fits image from file and populate model args."""
print(f"{CYAN}Loading FITS image from {self.path}...{RESET}")
if not self.path:
raise ValueError("Path to FITS file not provided.")

with fits.open(self.path) as fits_object:

if self.data is None:
self.data = fits_object[0].data
# Convert byte ordering to little-endian as FITS is stored as big-endian
# and is incompatible with torch
self.data = self.data.astype(np.float32)

if self.header is None:
self.header = fits_object[0].header

if self.wcs is None:
self.wcs = WCS(self.header)

if self.beam_size is None:
self.beam_size = self.get_beam_size()

if self.shape is None:
self.shape = self.data.shape
self.data = fits_object[0].data
# Convert byte ordering to little-endian as FITS is stored as big-endian
# and is incompatible with torch
self.data = self.data.astype(np.float32)
self.header = fits_object[0].header
self.wcs = WCS(self.header)
if not self.wcs.has_celestial:
raise ValueError("WCS object does not contain celestial information.")
self.beam_size = self.get_beam_size()
self.shape = self.data.shape
self.check_header()

return self

def check_header(self):
"""Check the header contains required information."""
required_keys = ["CRPIX1", "CRPIX2"]
for key in required_keys:
if key not in self.header:
raise KeyError(f"Header does not contain '{key}' (image information).")

def get_beam_size(self):
"""Return the beam size in arcseconds."""
if "BMAJ" not in self.header: