Skip to content

Commit

Permalink
Merge pull request #4 from tansey-lab/jq_segmentation
Browse files Browse the repository at this point in the history
Add Prediction Top Level CLI and Method
  • Loading branch information
jeffquinn-msk authored Feb 13, 2024
2 parents 69efc90 + e82cc73 commit ebca995
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 90 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ test = ["pytest", "pytest-mock", "tox", "coverage"]

[project.scripts]
preprocess = "nuc2seg.cli.preprocess:main"
train_nuc2seg = "nuc2seg.cli.train:main"
train = "nuc2seg.cli.train:main"
predict = "nuc2seg.cli.predict:main"

[build-system]
requires = ["setuptools>=43.0.0", "wheel"]
Expand Down
94 changes: 94 additions & 0 deletions src/nuc2seg/cli/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import argparse
import logging
import numpy as np
import torch

from nuc2seg import log_config
from nuc2seg.segment import stitch_predictions
from nuc2seg.unet_model import SparseUNet
from nuc2seg.data import Nuc2SegDataset, TiledDataset

logger = logging.getLogger(__name__)


def get_parser():
parser = argparse.ArgumentParser(
description="Evaluate a UNet model on preprocessed data."
)
log_config.add_logging_args(parser)
parser.add_argument(
"--output",
help="Model prediction output in h5 format.",
type=str,
required=True,
)
parser.add_argument(
"--dataset",
help="Path to dataset in h5 format.",
type=str,
required=True,
)
parser.add_argument(
"--model-weights",
help="File to read model weights from.",
type=str,
required=True,
)
parser.add_argument(
"--tile-height",
help="Height of the tiles.",
type=int,
default=64,
)
parser.add_argument(
"--tile-width",
help="Width of the tiles.",
type=int,
default=64,
)
parser.add_argument(
"--overlap-percentage",
help="What percent of each tile dimension overlaps with the next tile.",
type=float,
default=0.25,
)
parser.add_argument(
"--num-dataloader-workers",
help="Number of workers to use for the data loader.",
type=int,
default=0,
)
return parser


def get_args():
parser = get_parser()

args = parser.parse_args()

return args


def main():
args = get_args()

log_config.configure_logging(args)

logger.info(f"Loading dataset from {args.dataset}")

ds = Nuc2SegDataset.load_h5(args.dataset)

tiled_dataset = TiledDataset(
ds,
tile_height=args.tile_height,
tile_width=args.tile_width,
tile_overlap=args.overlap_percentage,
)

model = SparseUNet(600, ds.n_classes + 2, (args.tile_height, args.tile_width))

model.load_state_dict(torch.load(args.model_weights))

model_predictions = stitch_predictions(model=model, dataloader=tiled_dataset)

model_predictions.save_h5(args.output)
28 changes: 27 additions & 1 deletion src/nuc2seg/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ def __init__(
def __len__(self):
return self._tiler.num_tiles()

@property
def tiler(self):
return self._tiler

@property
def per_tile_class_histograms(self):
class_tiles = (
Expand All @@ -160,7 +164,7 @@ def per_tile_class_histograms(self):
def __getitem__(self, idx):
x1, y1, x2, y2 = next(
generate_tiles(
tiler=self._tiler,
tiler=self.tiler,
x_extent=self.ds.x_extent_pixels,
y_extent=self.ds.y_extent_pixels,
tile_size=(self.tile_height, self.tile_width),
Expand Down Expand Up @@ -211,3 +215,25 @@ def __getitem__(self, idx):
"nucleus_mask": torch.as_tensor(nucleus_mask).bool().contiguous(),
"location": np.array([x1, y1]),
}


class ModelPredictions:
def __init__(self, angles, classes, foreground):

self.angles = angles
self.classes = classes
self.foreground = foreground

def save_h5(self, path):
with h5py.File(path, "w") as f:
f.create_dataset("angles", data=self.angles, compression="gzip")
f.create_dataset("classes", data=self.classes, compression="gzip")
f.create_dataset("foreground", data=self.foreground, compression="gzip")

@staticmethod
def load_h5(path):
with h5py.File(path, "r") as f:
angles = f["angles"][:]
classes = f["classes"][:]
foreground = f["foreground"][:]
return ModelPredictions(angles=angles, classes=classes, foreground=foreground)
36 changes: 21 additions & 15 deletions src/nuc2seg/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tqdm

from nuc2seg.preprocessing import pol2cart
from nuc2seg.data import collate_tiles


def dice_coeff(
Expand Down Expand Up @@ -88,20 +89,18 @@ def evaluate(net, dataloader, device, amp):


def plot_predictions(net, dataloader, idx=0, threshold=0.5):
for i, batch in enumerate(dataloader):
if i < idx:
continue
x, y, z, labels, angles, classes, label_mask, nucleus_mask = (
batch["X"],
batch["Y"],
batch["gene"],
batch["labels"].numpy().copy().astype(int),
batch["angles"].numpy().copy().astype(float),
batch["classes"].numpy().copy().astype(int),
batch["label_mask"].numpy().copy().astype(bool),
batch["nucleus_mask"].numpy().copy().astype(bool),
)
break

batch = collate_tiles([dataloader[idx]])
x, y, z, labels, angles, classes, label_mask, nucleus_mask = (
batch["X"],
batch["Y"],
batch["gene"],
batch["labels"].numpy().copy().astype(int),
batch["angles"].numpy().copy().astype(float),
batch["classes"].numpy().copy().astype(int),
batch["label_mask"].numpy().copy().astype(bool),
batch["nucleus_mask"].numpy().copy().astype(bool),
)

net.eval()
mask_pred = net(x, y, z).detach().numpy().copy()
Expand Down Expand Up @@ -134,8 +133,15 @@ def plot_predictions(net, dataloader, idx=0, threshold=0.5):
pred_plot[pred_plot == c] = (i % 16) + 4

fig, axarr = plt.subplots(
3, x.shape[0], figsize=(5 * x.shape[0], 10), sharex=True, sharey=True
nrows=3,
ncols=x.shape[0],
figsize=(5 * x.shape[0], 10),
sharex=True,
sharey=True,
)
if len(axarr.shape) == 1:
axarr = axarr[:, None]

for i in range(x.shape[0]):
axarr[0, i].set_title("Labels and transcripts")
axarr[0, i].imshow(label_plot[i], cmap="tab20b", interpolation="none")
Expand Down
8 changes: 6 additions & 2 deletions src/nuc2seg/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
from matplotlib import pyplot as plt


def plot_tiling(bboxes):
def plot_tiling(bboxes, output_path):
polygons = [box(*x) for x in bboxes]
gdf = geopandas.GeoDataFrame(geometry=polygons)

fig, ax = plt.subplots()

gdf.boundary.plot(ax=ax, color="red", alpha=0.5)

fig.savefig("/tmp/tiling.pdf")
fig.savefig(output_path)


def plot_model_predictions():
pass
120 changes: 49 additions & 71 deletions src/nuc2seg/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,21 @@
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
from scipy.special import expit, softmax

from xenium_utils import pol2cart, create_pixel_geodf, load_nuclei
from nuc2seg.preprocessing import pol2cart
from nuc2seg.data import collate_tiles, ModelPredictions

logger = logging.getLogger(__name__)


def temp_forward(model, x, y, z):
mask = z > -1
b = torch.as_tensor(
np.tile(np.arange(z.shape[0]), (z.shape[1], 1)).T[mask.numpy().astype(bool)]
)
W = model.filters(z[mask])
t_input = torch.Tensor(np.zeros((z.shape[0],) + model.img_shape))
t_input.index_put_(
(b, torch.LongTensor(x[mask]), torch.LongTensor(y[mask])), W, accumulate=True
)
t_input = torch.Tensor.permute(
t_input, (0, 3, 1, 2)
) # Needs to be Batch x Channels x ImageX x ImageY
return torch.Tensor.permute(
model.unet(t_input), (0, 2, 3, 1)
) # Map back to Batch x ImageX x Image Y x Classes


def stitch_tile_predictions(model, dataset, tile_buffer=8):
"""TODO: all of the metadata info should be available in dataset."""
def stitch_predictions(model, dataloader):
model.eval()
foreground_list = []
class_list = []

x_max, y_max = dataset.locations.max(axis=0)
tile_width, tile_height = dataset[0]["labels"].numpy().shape

results = np.zeros((x_max + tile_width, y_max + tile_height, dataset.n_classes + 2))
for idx in tqdm.trange(len(dataset), desc="Stitching tiles"):
tile = dataset[idx]
vector_x_list = []
vector_y_list = []
for batch in tqdm.tqdm(dataloader, desc="Stitching predictions", unit="batch"):
tile = collate_tiles([batch])

x, y, z, labels, angles, classes, label_mask, nucleus_mask, location = (
tile["X"],
Expand All @@ -53,53 +34,50 @@ def stitch_tile_predictions(model, dataset, tile_buffer=8):
tile["nucleus_mask"].numpy().copy().astype(bool),
tile["location"],
)

# mask_pred = model(x,y,z).detach().numpy().copy()
mask_pred = (
temp_forward(model, x[None], y[None], z[None])
.squeeze(0)
.detach()
.numpy()
.copy()
) # TEMP: code is fixed but i am currently using a pretrained model
mask_pred = model(x, y, z).detach().numpy().copy()
foreground_pred = expit(mask_pred[..., 0])
angles_pred = expit(mask_pred[..., 1]) * 2 * np.pi - np.pi
class_pred = softmax(mask_pred[..., 2:], axis=-1)
vector_x_list.append(0.5 * np.cos(angles_pred.squeeze()))
vector_y_list.append(0.5 * np.sin(angles_pred.squeeze()))
foreground_list.append(foreground_pred.squeeze())
class_list.append(class_pred.squeeze())

all_vector_x = torch.tensor(np.stack(vector_x_list, axis=0))
all_vector_y = torch.tensor(np.stack(vector_y_list, axis=0))

# Get the location of this tile in the whole slide
x_start, y_start = location

# Figure out which parts of the tile to use since the tiles overlap
x_end_offset, y_end_offset = foreground_pred.shape[:2]
x_start_offset, y_start_offset = 0, 0
if x_start > 0:
x_start_offset += tile_buffer
if y_start > 0:
y_start_offset += tile_buffer
if x_start < x_max:
x_end_offset -= tile_buffer
if y_start < y_max:
y_end_offset -= tile_buffer

results[
x_start + x_start_offset : x_start + x_end_offset,
y_start + y_start_offset : y_start + y_end_offset,
0,
] = foreground_pred[x_start_offset:x_end_offset, y_start_offset:y_end_offset]
results[
x_start + x_start_offset : x_start + x_end_offset,
y_start + y_start_offset : y_start + y_end_offset,
1,
] = angles_pred[x_start_offset:x_end_offset, y_start_offset:y_end_offset]
results[
x_start + x_start_offset : x_start + x_end_offset,
y_start + y_start_offset : y_start + y_end_offset,
2:,
] = class_pred[x_start_offset:x_end_offset, y_start_offset:y_end_offset]

idx += 1

return results
all_foreground = torch.tensor(np.stack(foreground_list, axis=0))
all_classes = torch.tensor(np.stack(class_list, axis=0))

tile_mask = dataloader.tiler.get_tile_masks()[:, 0, :, :]

vector_x_tiles = all_vector_x * tile_mask
vector_x_stitched = dataloader.tiler.rebuild(
vector_x_tiles[:, None, :, :]
).squeeze()

vector_y_tiles = all_vector_y * tile_mask
vector_y_stitched = dataloader.tiler.rebuild(
vector_y_tiles[:, None, :, :]
).squeeze()

angles_stitched = torch.atan2(vector_y_stitched, vector_x_stitched)

foreground_tiles = all_foreground * tile_mask
foreground_stitched = dataloader.tiler.rebuild(
foreground_tiles[:, None, :, :]
).squeeze()

class_tiles = all_classes * tile_mask[..., None]
class_stitched = dataloader.tiler.rebuild(
class_tiles.permute((0, 3, 1, 2))
).squeeze()

return ModelPredictions(
angles=angles_stitched.detach().numpy(),
foreground=foreground_stitched.detach().numpy(),
classes=class_stitched.detach().numpy(),
)


def greedy_expansion(
Expand Down

0 comments on commit ebca995

Please sign in to comment.