Add Prediction Top Level CLI and Method #4

Merged (2 commits, Feb 13, 2024)
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -52,7 +52,8 @@ test = ["pytest", "pytest-mock", "tox", "coverage"]

[project.scripts]
preprocess = "nuc2seg.cli.preprocess:main"
train_nuc2seg = "nuc2seg.cli.train:main"
train = "nuc2seg.cli.train:main"
predict = "nuc2seg.cli.predict:main"

[build-system]
requires = ["setuptools>=43.0.0", "wheel"]
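For context, a [project.scripts] entry simply maps a command name to a callable, so the new predict script is equivalent to the minimal Python sketch below (an assumption-laden illustration, not part of the PR; it presumes the nuc2seg package is installed in the active environment):

# Rough equivalent of the installed `predict` console script.
from nuc2seg.cli.predict import main

if __name__ == "__main__":
    main()  # argparse reads sys.argv, just as `predict ...` would from a shell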
94 changes: 94 additions & 0 deletions src/nuc2seg/cli/predict.py
@@ -0,0 +1,94 @@
import argparse
import logging
import numpy as np
import torch

from nuc2seg import log_config
from nuc2seg.segment import stitch_predictions
from nuc2seg.unet_model import SparseUNet
from nuc2seg.data import Nuc2SegDataset, TiledDataset

logger = logging.getLogger(__name__)


def get_parser():
parser = argparse.ArgumentParser(
description="Evaluate a UNet model on preprocessed data."
)
log_config.add_logging_args(parser)
parser.add_argument(
"--output",
help="Model prediction output in h5 format.",
type=str,
required=True,
)
parser.add_argument(
"--dataset",
help="Path to dataset in h5 format.",
type=str,
required=True,
)
parser.add_argument(
"--model-weights",
help="File to read model weights from.",
type=str,
required=True,
)
parser.add_argument(
"--tile-height",
help="Height of the tiles.",
type=int,
default=64,
)
parser.add_argument(
"--tile-width",
help="Width of the tiles.",
type=int,
default=64,
)
parser.add_argument(
"--overlap-percentage",
help="What percent of each tile dimension overlaps with the next tile.",
type=float,
default=0.25,
)
parser.add_argument(
"--num-dataloader-workers",
help="Number of workers to use for the data loader.",
type=int,
default=0,
)
return parser


def get_args():
parser = get_parser()

args = parser.parse_args()

return args


def main():
args = get_args()

log_config.configure_logging(args)

logger.info(f"Loading dataset from {args.dataset}")

ds = Nuc2SegDataset.load_h5(args.dataset)

tiled_dataset = TiledDataset(
ds,
tile_height=args.tile_height,
tile_width=args.tile_width,
tile_overlap=args.overlap_percentage,
)

model = SparseUNet(600, ds.n_classes + 2, (args.tile_height, args.tile_width))

model.load_state_dict(torch.load(args.model_weights))

model_predictions = stitch_predictions(model=model, dataloader=tiled_dataset)

model_predictions.save_h5(args.output)
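A hedged usage sketch for the CLI defined above; the file paths are placeholders, and the same flags can be handed to the parser programmatically:

# Sketch: parse the flags the `predict` console script accepts (paths are placeholders).
from nuc2seg.cli.predict import get_parser

args = get_parser().parse_args([
    "--dataset", "dataset.h5",
    "--model-weights", "weights.pt",
    "--output", "predictions.h5",
])
print(args.tile_height, args.overlap_percentage)  # defaults: 64 and 0.25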
28 changes: 27 additions & 1 deletion src/nuc2seg/data.py
@@ -139,6 +139,10 @@ def __init__(
def __len__(self):
return self._tiler.num_tiles()

@property
def tiler(self):
return self._tiler

@property
def per_tile_class_histograms(self):
class_tiles = (
@@ -160,7 +164,7 @@ def per_tile_class_histograms(self):
def __getitem__(self, idx):
x1, y1, x2, y2 = next(
generate_tiles(
tiler=self._tiler,
tiler=self.tiler,
x_extent=self.ds.x_extent_pixels,
y_extent=self.ds.y_extent_pixels,
tile_size=(self.tile_height, self.tile_width),
@@ -211,3 +215,25 @@ def __getitem__(self, idx):
"nucleus_mask": torch.as_tensor(nucleus_mask).bool().contiguous(),
"location": np.array([x1, y1]),
}


class ModelPredictions:
def __init__(self, angles, classes, foreground):

self.angles = angles
self.classes = classes
self.foreground = foreground

def save_h5(self, path):
with h5py.File(path, "w") as f:
f.create_dataset("angles", data=self.angles, compression="gzip")
f.create_dataset("classes", data=self.classes, compression="gzip")
f.create_dataset("foreground", data=self.foreground, compression="gzip")

@staticmethod
def load_h5(path):
with h5py.File(path, "r") as f:
angles = f["angles"][:]
classes = f["classes"][:]
foreground = f["foreground"][:]
return ModelPredictions(angles=angles, classes=classes, foreground=foreground)
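A minimal round-trip sketch for the new ModelPredictions container; the array shapes below are illustrative only, not prescribed by the class:

import numpy as np
from nuc2seg.data import ModelPredictions

# Toy arrays standing in for stitched model output.
preds = ModelPredictions(
    angles=np.zeros((128, 128)),
    classes=np.zeros((128, 128, 4)),
    foreground=np.zeros((128, 128)),
)
preds.save_h5("predictions.h5")  # written as gzip-compressed HDF5 datasets
restored = ModelPredictions.load_h5("predictions.h5")
assert restored.angles.shape == (128, 128)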
36 changes: 21 additions & 15 deletions src/nuc2seg/evaluate.py
@@ -6,6 +6,7 @@
import tqdm

from nuc2seg.preprocessing import pol2cart
from nuc2seg.data import collate_tiles


def dice_coeff(
@@ -88,20 +89,18 @@ def evaluate(net, dataloader, device, amp):


def plot_predictions(net, dataloader, idx=0, threshold=0.5):
for i, batch in enumerate(dataloader):
if i < idx:
continue
x, y, z, labels, angles, classes, label_mask, nucleus_mask = (
batch["X"],
batch["Y"],
batch["gene"],
batch["labels"].numpy().copy().astype(int),
batch["angles"].numpy().copy().astype(float),
batch["classes"].numpy().copy().astype(int),
batch["label_mask"].numpy().copy().astype(bool),
batch["nucleus_mask"].numpy().copy().astype(bool),
)
break

batch = collate_tiles([dataloader[idx]])
x, y, z, labels, angles, classes, label_mask, nucleus_mask = (
batch["X"],
batch["Y"],
batch["gene"],
batch["labels"].numpy().copy().astype(int),
batch["angles"].numpy().copy().astype(float),
batch["classes"].numpy().copy().astype(int),
batch["label_mask"].numpy().copy().astype(bool),
batch["nucleus_mask"].numpy().copy().astype(bool),
)

net.eval()
mask_pred = net(x, y, z).detach().numpy().copy()
@@ -134,8 +133,15 @@ def plot_predictions(net, dataloader, idx=0, threshold=0.5):
pred_plot[pred_plot == c] = (i % 16) + 4

fig, axarr = plt.subplots(
3, x.shape[0], figsize=(5 * x.shape[0], 10), sharex=True, sharey=True
nrows=3,
ncols=x.shape[0],
figsize=(5 * x.shape[0], 10),
sharex=True,
sharey=True,
)
if len(axarr.shape) == 1:
axarr = axarr[:, None]

for i in range(x.shape[0]):
axarr[0, i].set_title("Labels and transcripts")
axarr[0, i].imshow(label_plot[i], cmap="tab20b", interpolation="none")
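With this change, plot_predictions pulls a single tile by index and collates it itself, so a call now looks roughly like the following (the model and TiledDataset objects are assumed to exist):

# Sketch: plot the predictions for tile 0 of an assumed TiledDataset.
plot_predictions(net=model, dataloader=tiled_dataset, idx=0, threshold=0.5)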
8 changes: 6 additions & 2 deletions src/nuc2seg/plotting.py
@@ -3,12 +3,16 @@
from matplotlib import pyplot as plt


def plot_tiling(bboxes):
def plot_tiling(bboxes, output_path):
polygons = [box(*x) for x in bboxes]
gdf = geopandas.GeoDataFrame(geometry=polygons)

fig, ax = plt.subplots()

gdf.boundary.plot(ax=ax, color="red", alpha=0.5)

fig.savefig("/tmp/tiling.pdf")
fig.savefig(output_path)


def plot_model_predictions():
pass
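Since plot_tiling now takes an explicit output path instead of writing to /tmp, a call might look like this (the bounding boxes and path are hypothetical; each bbox is (minx, miny, maxx, maxy), as expected by shapely's box):

# Sketch: plot two overlapping tile boundaries and save the figure.
plot_tiling(
    bboxes=[(0, 0, 64, 64), (48, 0, 112, 64)],
    output_path="tiling.pdf",
)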
120 changes: 49 additions & 71 deletions src/nuc2seg/segment.py
@@ -7,40 +7,21 @@
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
from scipy.special import expit, softmax

from xenium_utils import pol2cart, create_pixel_geodf, load_nuclei
from nuc2seg.preprocessing import pol2cart
from nuc2seg.data import collate_tiles, ModelPredictions

logger = logging.getLogger(__name__)


def temp_forward(model, x, y, z):
mask = z > -1
b = torch.as_tensor(
np.tile(np.arange(z.shape[0]), (z.shape[1], 1)).T[mask.numpy().astype(bool)]
)
W = model.filters(z[mask])
t_input = torch.Tensor(np.zeros((z.shape[0],) + model.img_shape))
t_input.index_put_(
(b, torch.LongTensor(x[mask]), torch.LongTensor(y[mask])), W, accumulate=True
)
t_input = torch.Tensor.permute(
t_input, (0, 3, 1, 2)
) # Needs to be Batch x Channels x ImageX x ImageY
return torch.Tensor.permute(
model.unet(t_input), (0, 2, 3, 1)
) # Map back to Batch x ImageX x Image Y x Classes


def stitch_tile_predictions(model, dataset, tile_buffer=8):
"""TODO: all of the metadata info should be available in dataset."""
def stitch_predictions(model, dataloader):
model.eval()
foreground_list = []
class_list = []

x_max, y_max = dataset.locations.max(axis=0)
tile_width, tile_height = dataset[0]["labels"].numpy().shape

results = np.zeros((x_max + tile_width, y_max + tile_height, dataset.n_classes + 2))
for idx in tqdm.trange(len(dataset), desc="Stitching tiles"):
tile = dataset[idx]
vector_x_list = []
vector_y_list = []
for batch in tqdm.tqdm(dataloader, desc="Stitching predictions", unit="batch"):
tile = collate_tiles([batch])

x, y, z, labels, angles, classes, label_mask, nucleus_mask, location = (
tile["X"],
@@ -53,53 +34,50 @@ def stitch_tile_predictions(model, dataset, tile_buffer=8):
tile["nucleus_mask"].numpy().copy().astype(bool),
tile["location"],
)

# mask_pred = model(x,y,z).detach().numpy().copy()
mask_pred = (
temp_forward(model, x[None], y[None], z[None])
.squeeze(0)
.detach()
.numpy()
.copy()
) # TEMP: code is fixed but i am currently using a pretrained model
mask_pred = model(x, y, z).detach().numpy().copy()
foreground_pred = expit(mask_pred[..., 0])
angles_pred = expit(mask_pred[..., 1]) * 2 * np.pi - np.pi
class_pred = softmax(mask_pred[..., 2:], axis=-1)
vector_x_list.append(0.5 * np.cos(angles_pred.squeeze()))
vector_y_list.append(0.5 * np.sin(angles_pred.squeeze()))
foreground_list.append(foreground_pred.squeeze())
class_list.append(class_pred.squeeze())

all_vector_x = torch.tensor(np.stack(vector_x_list, axis=0))
all_vector_y = torch.tensor(np.stack(vector_y_list, axis=0))

# Get the location of this tile in the whole slide
x_start, y_start = location

# Figure out which parts of the tile to use since the tiles overlap
x_end_offset, y_end_offset = foreground_pred.shape[:2]
x_start_offset, y_start_offset = 0, 0
if x_start > 0:
x_start_offset += tile_buffer
if y_start > 0:
y_start_offset += tile_buffer
if x_start < x_max:
x_end_offset -= tile_buffer
if y_start < y_max:
y_end_offset -= tile_buffer

results[
x_start + x_start_offset : x_start + x_end_offset,
y_start + y_start_offset : y_start + y_end_offset,
0,
] = foreground_pred[x_start_offset:x_end_offset, y_start_offset:y_end_offset]
results[
x_start + x_start_offset : x_start + x_end_offset,
y_start + y_start_offset : y_start + y_end_offset,
1,
] = angles_pred[x_start_offset:x_end_offset, y_start_offset:y_end_offset]
results[
x_start + x_start_offset : x_start + x_end_offset,
y_start + y_start_offset : y_start + y_end_offset,
2:,
] = class_pred[x_start_offset:x_end_offset, y_start_offset:y_end_offset]

idx += 1

return results
all_foreground = torch.tensor(np.stack(foreground_list, axis=0))
all_classes = torch.tensor(np.stack(class_list, axis=0))

tile_mask = dataloader.tiler.get_tile_masks()[:, 0, :, :]

vector_x_tiles = all_vector_x * tile_mask
vector_x_stitched = dataloader.tiler.rebuild(
vector_x_tiles[:, None, :, :]
).squeeze()

vector_y_tiles = all_vector_y * tile_mask
vector_y_stitched = dataloader.tiler.rebuild(
vector_y_tiles[:, None, :, :]
).squeeze()

angles_stitched = torch.atan2(vector_y_stitched, vector_x_stitched)

foreground_tiles = all_foreground * tile_mask
foreground_stitched = dataloader.tiler.rebuild(
foreground_tiles[:, None, :, :]
).squeeze()

class_tiles = all_classes * tile_mask[..., None]
class_stitched = dataloader.tiler.rebuild(
class_tiles.permute((0, 3, 1, 2))
).squeeze()

return ModelPredictions(
angles=angles_stitched.detach().numpy(),
foreground=foreground_stitched.detach().numpy(),
classes=class_stitched.detach().numpy(),
)


def greedy_expansion(
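For reference, the rewritten stitch_predictions returns a ModelPredictions object rather than a raw results array; downstream code might consume it roughly as follows (model and tiled_dataset are assumed to exist, and the 0.5 cutoff mirrors the default threshold in evaluate.plot_predictions rather than anything mandated here):

# Sketch: stitch tile-level outputs and derive a boolean foreground mask.
preds = stitch_predictions(model=model, dataloader=tiled_dataset)
foreground_mask = preds.foreground > 0.5  # (H, W) boolean map
angles = preds.angles                     # stitched per-pixel angles in radians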