Skip to content

Commit

Permalink
add method for prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffquinn-msk committed Feb 13, 2024
1 parent 69efc90 commit 321ab94
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 91 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ test = ["pytest", "pytest-mock", "tox", "coverage"]

[project.scripts]
preprocess = "nuc2seg.cli.preprocess:main"
train_nuc2seg = "nuc2seg.cli.train:main"
train = "nuc2seg.cli.train:main"
predict = "nuc2seg.cli.predict:main"

[build-system]
requires = ["setuptools>=43.0.0", "wheel"]
Expand Down
96 changes: 96 additions & 0 deletions src/nuc2seg/cli/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import argparse
import logging
import numpy as np
import torch

from nuc2seg import log_config
from nuc2seg.segment import stitch_predictions
from nuc2seg.unet_model import SparseUNet
from nuc2seg.data import Nuc2SegDataset, TiledDataset

logger = logging.getLogger(__name__)


def get_parser():
    """Build the argument parser for the prediction command.

    The parser first receives whatever options ``log_config.add_logging_args``
    registers, then the prediction-specific options below: three required
    paths (``--output``, ``--dataset``, ``--model-weights``) and four optional
    numeric settings controlling tiling and data loading.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser(
        description="Evaluate a UNet model on preprocessed data."
    )
    log_config.add_logging_args(cli)

    # (flag, add_argument keyword options). Options are registered in this
    # order so the generated --help listing keeps the declaration order.
    option_specs = (
        (
            "--output",
            {
                "help": "Model prediction output in h5 format.",
                "type": str,
                "required": True,
            },
        ),
        (
            "--dataset",
            {
                "help": "Path to dataset in h5 format.",
                "type": str,
                "required": True,
            },
        ),
        (
            "--model-weights",
            {
                "help": "File to read model weights from.",
                "type": str,
                "required": True,
            },
        ),
        (
            "--tile-height",
            {
                "help": "Height of the tiles.",
                "type": int,
                "default": 64,
            },
        ),
        (
            "--tile-width",
            {
                "help": "Width of the tiles.",
                "type": int,
                "default": 64,
            },
        ),
        (
            "--overlap-percentage",
            {
                "help": "What percent of each tile dimension overlaps with the next tile.",
                "type": float,
                "default": 0.25,
            },
        ),
        (
            "--num-dataloader-workers",
            {
                "help": "Number of workers to use for the data loader.",
                "type": int,
                "default": 0,
            },
        ),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli


def get_args():
    """Parse the process command line with the parser from ``get_parser``.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    return get_parser().parse_args()


def main():
    """Run tiled inference with a trained ``SparseUNet`` and save the result.

    Steps, in order: parse the command line, configure logging, load the
    dataset from ``--dataset``, wrap it in a ``TiledDataset`` using the
    requested tile geometry, restore the model weights from
    ``--model-weights``, stitch the per-tile predictions with
    ``stitch_predictions`` and write them to ``--output`` via ``save_h5``.
    """
    args = get_args()

    log_config.configure_logging(args)

    # get_parser() itself declares no --seed option, so reading args.seed
    # unconditionally raises AttributeError unless the logging helper happens
    # to add one. Seed NumPy only when a seed is actually available.
    seed = getattr(args, "seed", None)
    if seed is not None:
        np.random.seed(seed)

    logger.info(f"Loading dataset from {args.dataset}")

    ds = Nuc2SegDataset.load_h5(args.dataset)

    tiled_dataset = TiledDataset(
        ds,
        tile_height=args.tile_height,
        tile_width=args.tile_width,
        tile_overlap=args.overlap_percentage,
    )

    # NOTE(review): 600 is a hard-coded first argument to SparseUNet;
    # presumably the number of input gene channels -- confirm whether
    # ds.n_genes is the intended value.
    model = SparseUNet(600, ds.n_classes + 2, (args.tile_height, args.tile_width))

    # map_location="cpu" lets weights that were saved from a GPU be restored
    # on a host without CUDA; the model above is constructed on the CPU.
    model.load_state_dict(torch.load(args.model_weights, map_location="cpu"))

    # Inference only: switch layers such as dropout/batch-norm to eval mode.
    model.eval()

    model_predictions = stitch_predictions(model=model, dataloader=tiled_dataset)

    model_predictions.save_h5(args.output)
47 changes: 45 additions & 2 deletions src/nuc2seg/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,18 @@

class Nuc2SegDataset:
def __init__(
self, labels, angles, classes, transcripts, bbox, n_classes, n_genes, resolution
self,
labels,
angles,
classes,
transcripts,
bbox,
n_classes,
n_genes,
resolution,
tile_width: int,
tile_height: int,
tile_overlap: float,
):
self.labels = labels
self.angles = angles
Expand All @@ -22,6 +33,9 @@ def __init__(
self.n_classes = n_classes
self.n_genes = n_genes
self.resolution = resolution
self.tile_width = tile_width
self.tile_height = tile_height
self.tile_overlap = tile_overlap

def save_h5(self, path):
with h5py.File(path, "w") as f:
Expand All @@ -33,6 +47,9 @@ def save_h5(self, path):
f.attrs["n_classes"] = self.n_classes
f.attrs["n_genes"] = self.n_genes
f.attrs["resolution"] = self.resolution
f.attrs["tile_width"] = self.tile_width
f.attrs["tile_height"] = self.tile_height
f.attrs["tile_overlap"] = self.tile_overlap

@property
def x_extent_pixels(self):
Expand Down Expand Up @@ -139,6 +156,10 @@ def __init__(
def __len__(self):
return self._tiler.num_tiles()

@property
def tiler(self):
return self._tiler

@property
def per_tile_class_histograms(self):
class_tiles = (
Expand All @@ -160,7 +181,7 @@ def per_tile_class_histograms(self):
def __getitem__(self, idx):
x1, y1, x2, y2 = next(
generate_tiles(
tiler=self._tiler,
tiler=self.tiler,
x_extent=self.ds.x_extent_pixels,
y_extent=self.ds.y_extent_pixels,
tile_size=(self.tile_height, self.tile_width),
Expand Down Expand Up @@ -211,3 +232,25 @@ def __getitem__(self, idx):
"nucleus_mask": torch.as_tensor(nucleus_mask).bool().contiguous(),
"location": np.array([x1, y1]),
}


class ModelPredictions:
    """Container for the ``angles``, ``classes`` and ``foreground`` outputs.

    The three values are stored as given and can be round-tripped through an
    HDF5 file that holds one gzip-compressed dataset per field.
    """

    def __init__(self, angles, classes, foreground):
        self.angles, self.classes, self.foreground = angles, classes, foreground

    def save_h5(self, path):
        """Write the three fields to ``path``, opened in ``"w"`` mode."""
        with h5py.File(path, "w") as h5_file:
            for field in ("angles", "classes", "foreground"):
                h5_file.create_dataset(
                    field, data=getattr(self, field), compression="gzip"
                )

    @staticmethod
    def load_h5(path):
        """Read the three datasets from ``path`` into a new ``ModelPredictions``."""
        with h5py.File(path, "r") as h5_file:
            # [:] reads each dataset fully while the file is still open.
            arrays = {
                field: h5_file[field][:]
                for field in ("angles", "classes", "foreground")
            }
        return ModelPredictions(**arrays)
36 changes: 21 additions & 15 deletions src/nuc2seg/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tqdm

from nuc2seg.preprocessing import pol2cart
from nuc2seg.data import collate_tiles


def dice_coeff(
Expand Down Expand Up @@ -88,20 +89,18 @@ def evaluate(net, dataloader, device, amp):


def plot_predictions(net, dataloader, idx=0, threshold=0.5):
for i, batch in enumerate(dataloader):
if i < idx:
continue
x, y, z, labels, angles, classes, label_mask, nucleus_mask = (
batch["X"],
batch["Y"],
batch["gene"],
batch["labels"].numpy().copy().astype(int),
batch["angles"].numpy().copy().astype(float),
batch["classes"].numpy().copy().astype(int),
batch["label_mask"].numpy().copy().astype(bool),
batch["nucleus_mask"].numpy().copy().astype(bool),
)
break

batch = collate_tiles([dataloader[idx]])
x, y, z, labels, angles, classes, label_mask, nucleus_mask = (
batch["X"],
batch["Y"],
batch["gene"],
batch["labels"].numpy().copy().astype(int),
batch["angles"].numpy().copy().astype(float),
batch["classes"].numpy().copy().astype(int),
batch["label_mask"].numpy().copy().astype(bool),
batch["nucleus_mask"].numpy().copy().astype(bool),
)

net.eval()
mask_pred = net(x, y, z).detach().numpy().copy()
Expand Down Expand Up @@ -134,8 +133,15 @@ def plot_predictions(net, dataloader, idx=0, threshold=0.5):
pred_plot[pred_plot == c] = (i % 16) + 4

fig, axarr = plt.subplots(
3, x.shape[0], figsize=(5 * x.shape[0], 10), sharex=True, sharey=True
nrows=3,
ncols=x.shape[0],
figsize=(5 * x.shape[0], 10),
sharex=True,
sharey=True,
)
if len(axarr.shape) == 1:
axarr = axarr[:, None]

for i in range(x.shape[0]):
axarr[0, i].set_title("Labels and transcripts")
axarr[0, i].imshow(label_plot[i], cmap="tab20b", interpolation="none")
Expand Down
8 changes: 6 additions & 2 deletions src/nuc2seg/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
from matplotlib import pyplot as plt


def plot_tiling(bboxes):
def plot_tiling(bboxes, output_path):
polygons = [box(*x) for x in bboxes]
gdf = geopandas.GeoDataFrame(geometry=polygons)

fig, ax = plt.subplots()

gdf.boundary.plot(ax=ax, color="red", alpha=0.5)

fig.savefig("/tmp/tiling.pdf")
fig.savefig(output_path)


def plot_model_predictions():
    """Placeholder with no implementation yet; does nothing and returns ``None``."""
Loading

0 comments on commit 321ab94

Please sign in to comment.