Skip to content

Commit

Permalink
set trained model path using importlib.resources
Browse files Browse the repository at this point in the history
  • Loading branch information
hstewart93 committed Jun 18, 2024
1 parent 5b4ab57 commit 56d78b7
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 4 additions & 2 deletions continunet/finder.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Compile ContinUNet modules into Finder class for source finding."""

import importlib.resources
import time

from astropy.table import Table

from continunet.constants import GREEN, RESET, TRAINED_MODEL
from continunet.constants import GREEN, RESET
from continunet.image.fits import ImageSquare
from continunet.image.processing import PreProcessor, PostProcessor
from continunet.network.unet import Unet
Expand Down Expand Up @@ -39,7 +40,8 @@ def find(self, generate_maps=False, use_raw=False):
data = pre_processor.process()

# Run U-Net
unet = Unet(data.shape[1:4], trained_model=TRAINED_MODEL, image=data, layers=self.layers)
with importlib.resources.path("continunet/network", "trained_model.h5") as path:
unet = Unet(data.shape[1:4], trained_model=path, image=data, layers=self.layers)
self.reconstructed_image = unet.decode_image()

# Post-process reconstructed image
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ version = "0.0.3"
authors = [
{ name="Hattie Stewart" }
]
requires-python = ">=3.8"
requires-python = ">=3.10"
dependencies = [
"astropy>=6.0",
"numpy>=1.26",
Expand Down

0 comments on commit 56d78b7

Please sign in to comment.