-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into unet-pytorch
- Loading branch information
Showing
17 changed files
with
1,861 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,5 +37,3 @@ jobs: | |
files: dist/* | ||
name: ${{ steps.branch_info.outputs.TAG }} | ||
tag_name: ${{ steps.branch_info.outputs.TAG }} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
name: Pytest | ||
|
||
on: | ||
push: | ||
branches: [ main ] | ||
pull_request: | ||
branches: [ main ] | ||
|
||
jobs: | ||
pytest: | ||
name: Run tests | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: actions/setup-python@v5 | ||
with: | ||
python-version: '3.12' | ||
- name: Install dependencies | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install -e .[ci] | ||
- name: Run tests | ||
run: | | ||
pytest |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
include continunet/network/trained_model.h5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1,44 @@ | ||
# ContinUNet | ||
[![Pytest](https://github.com/hstewart93/continunet/actions/workflows/pytest.yml/badge.svg)](https://github.com/hstewart93/continunet/actions/workflows/pytest.yml) | ||
|
||
Source finding package for radio continuum data powered by U-Net segmentation algorithm. | ||
|
||
- [Paper](https://academic.oup.com/rasti/article/3/1/315/7685538?utm_source=advanceaccess&utm_campaign=rasti&utm_medium=email#supplementary-data) | ||
- [Installation](#installation) | ||
- [Developer Installation](#developer-installation) | ||
- [Quickstart](#quickstart) | ||
- [Example Notebook](https://github.com/hstewart93/continunet/tree/finder/continunet/user_example.ipynb) | ||
- [Training Dataset](https://www.kaggle.com/datasets/harrietstewart/continunet) | ||
- [Next Release](#development) | ||
|
||
## Installation | ||
The project is available on PyPI, to install latest stable release use: | ||
The project is available on [PyPI](https://pypi.org/project/continunet/), to install latest stable release use: | ||
|
||
```pip install continunet``` | ||
```bash | ||
pip install continunet | ||
``` | ||
|
||
To install version in development, use: | ||
|
||
```pip install git+https://github.com/hstewart93/continunet``` | ||
```bash | ||
pip install git+https://github.com/hstewart93/continunet | ||
``` | ||
|
||
**ContinUNet has a `Python 3.9` minimum requirement.** | ||
|
||
## Developer Installation | ||
If you want to contribute to the repository, install as follows: | ||
|
||
Once you have cloned down this repository using `git clone` cd into the app directory eg. | ||
Once you have cloned down this repository using `git clone`, cd into the app directory: | ||
|
||
``` | ||
```bash | ||
git clone [email protected]:hstewart93/continunet.git | ||
cd continunet | ||
``` | ||
|
||
Create a virtual environment for development, if you are using bash: | ||
|
||
``` | ||
```bash | ||
python3 -m venv venv | ||
source venv/bin/activate | ||
pip install -e .[dev,ci] | ||
|
@@ -31,10 +48,89 @@ To exit the virtual environment use `deactivate`. | |
|
||
This project used the black auto formatter which can be run on git commit along with flake8 if you install pre-commit. To do this run the following in your terminal from within your virtual environment. | ||
|
||
``` | ||
```bash | ||
pre-commit install | ||
``` | ||
|
||
Now pre-commit hooks should run on `git commit`. | ||
|
||
The run the test suite use `pytest`. | ||
To run the test suite use `pytest`. | ||
|
||
## Quickstart | ||
The package currently support `.FITS` type images. To perform source finding you can import the `finder` module: | ||
|
||
```python | ||
from continunet.finder import Finder | ||
``` | ||
|
||
Load your image file: | ||
|
||
```python | ||
finder = Finder("<filepath>") | ||
``` | ||
|
||
To produce a source catalogue and populate the `Finder` instance: | ||
|
||
```python | ||
sources = finder.find() | ||
``` | ||
|
||
If you want to calculate the model map and residuals image as part of source finding, use the `Finder.find()` method with `generate_maps=True`: | ||
|
||
```python | ||
sources = finder.find(generate_maps=True) | ||
model_map = finder.model_map | ||
residuals = finder.residuals | ||
``` | ||
|
||
Alternatively, manually calculate model map and residual images using: | ||
|
||
```python | ||
model_map = finder.get_model_map() | ||
residuals = finder.get_residuals() | ||
``` | ||
|
||
Useful available attributes of the `Finder` object are: | ||
```python | ||
finder.sources # cleaned source catalogue | ||
finder.reconstructed_image # predicted image reconstructed by unet module | ||
finder.segmentation_map # predicted segmentation map | ||
finder.model_map # model map of cleaned predicted sources | ||
finder.residuals # residual image as numpy array | ||
finder.raw_sources # sources from labelled segmentation map before cleaning | ||
``` | ||
|
||
Export source catalogue using `finder.export_sources` as `.csv` by default or `.FITS` by setting `export_fits=True`: | ||
|
||
```python | ||
finder.export_sources("<filepath>", export_fits=<Boolean>) | ||
``` | ||
|
||
Source parameters extracted are: | ||
|
||
| **Parameter** | **Description** | | ||
|----------------------------|--------------------------------------------------------------------------------------------------------| | ||
| `x_location_original` | x coordinate of the source from the cutout used for inference | | ||
| `y_location_original` | y coordinate of the source from the cutout used for inference | | ||
| `orientation` | orientation of source ellipse in radians | | ||
| `major_axis` | major axis of source ellipse | | ||
| `minor_axis` | minor axis of source ellipse | | ||
| `flux_density_uncorrected` | total intensity of the segmented source region before beam corrections applied | | ||
| `label` | class label in predicted segmentation map | | ||
| `x_location` | x coordinate of the source in the original input image dimensions | | ||
| `y_location` | y coordinate of the source in the original input image dimensions | | ||
| `right_ascension` | RA coordinate of the source in the original input image dimensions | | ||
| `declination` | Dec coordinate of the source in the original input image dimensions | | ||
| `area` | area of source ellipse | | ||
| `position_angle` | position angle of source ellipse in degrees | | ||
| `correction_factor` | correction factor applied to flux density measurement to account for undersampling of synthesised beam | | ||
| `flux_density` | corrected flux density | | ||
|
||
## Development | ||
ContinUNet is subject to ongoing development. To see the backlog of features and bug fixes please go to the [project board](https://github.com/users/hstewart93/projects/4/views/1). Please raise any feature requests or bugs as [issues](https://github.com/hstewart93/continunet/issues). | ||
|
||
The following features will be added in the next release: | ||
|
||
1. Exporting processed images to `.npy` and `.FTIS` [(#33)](https://github.com/hstewart93/continunet/issues/33) | ||
2. Inference for non-square images [(#27)](https://github.com/hstewart93/continunet/issues/27) | ||
3. Taking cutout of `ImageSquare` object before inference [(#28)](https://github.com/hstewart93/continunet/issues/28) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
"""Constants for ContinuNet.""" | ||
|
||
|
||
TRAINED_MODEL = "continunet/network/trained_model.h5" | ||
|
||
# ANSI escape sequences | ||
RED = "\033[31m" | ||
GREEN = "\033[32m" | ||
YELLOW = "\033[33m" | ||
BLUE = "\033[34m" | ||
MAGENTA = "\033[35m" | ||
CYAN = "\033[36m" | ||
RESET = "\033[0m" |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
"""Compile ContinUNet modules into Finder class for source finding.""" | ||
|
||
import importlib.resources | ||
import time | ||
|
||
from astropy.table import Table | ||
|
||
from continunet.constants import GREEN, RESET | ||
from continunet.image.fits import ImageSquare | ||
from continunet.image.processing import PreProcessor, PostProcessor | ||
from continunet.network.unet import Unet | ||
|
||
|
||
class Finder: | ||
"""Class for source finding in radio continuum images.""" | ||
|
||
def __init__(self, image: str, layers: int = 4): | ||
"""Initialise the Finder class. | ||
Parameters | ||
---------- | ||
image : str | ||
The path to the FITS image. | ||
layers : int | ||
The number of encoding and decoding layers in the U-Net model. | ||
Layers is set by default to 4, and cannot currently be changed. | ||
""" | ||
if not image.endswith(".fits"): | ||
raise ValueError("File must be a .fits file.") | ||
self.image = image | ||
if layers != 4: | ||
raise ValueError("Number of layers must be 4.") | ||
self.layers = layers | ||
self.image_object = None | ||
self.sources = None | ||
self.reconstructed_image = None | ||
self.post_processor = None | ||
self.segmentation_map = None | ||
self.model_map = None | ||
self.residuals = None | ||
self.raw_sources = None | ||
|
||
def find(self, generate_maps=False, threshold="default"): | ||
"""Find sources in a continuum image.""" | ||
start_time = time.time() | ||
# Load image | ||
self.image_object = ImageSquare(self.image) | ||
|
||
# Pre-process image | ||
pre_processor = PreProcessor(self.image_object, self.layers) | ||
data = pre_processor.process() | ||
|
||
# Run U-Net | ||
with importlib.resources.path("continunet.network", "trained_model.h5") as path: | ||
unet = Unet(data.shape[1:4], trained_model=path, image=data, layers=self.layers) | ||
self.reconstructed_image = unet.decode_image() | ||
|
||
# Post-process reconstructed image | ||
self.post_processor = PostProcessor(unet.reconstructed, pre_processor, threshold=threshold) | ||
self.sources = self.post_processor.get_sources() | ||
self.segmentation_map = self.post_processor.segmentation_map | ||
self.raw_sources = self.post_processor.raw_sources | ||
|
||
end_time = time.time() | ||
print( | ||
f"{GREEN}ContinUNet found {len(self.sources)} sources " | ||
f"in {(end_time - start_time):.2f} seconds.{RESET}" | ||
) | ||
|
||
if generate_maps: | ||
self.model_map = self.post_processor.get_model_map() | ||
self.residuals = self.post_processor.get_residuals() | ||
self.segmentation_map = self.post_processor.segmentation_map | ||
|
||
return self.sources | ||
|
||
def export_sources(self, path: str, export_fits=False): | ||
"""Export source catalogue to a directory. Use export_fits=True to save as FITS.""" | ||
if self.sources is None: | ||
raise ValueError("No sources to export.") | ||
if export_fits: | ||
table = Table.from_pandas(self.sources) | ||
table.write(path, format="fits", overwrite=True) | ||
return self | ||
|
||
self.sources.to_csv(path) | ||
return self |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.