Skip to content

Commit

Permalink
Add Plotting Code (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffquinn-msk authored Feb 13, 2024
1 parent ebca995 commit d1afce8
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 2 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ test = ["pytest", "pytest-mock", "tox", "coverage"]
preprocess = "nuc2seg.cli.preprocess:main"
train = "nuc2seg.cli.train:main"
predict = "nuc2seg.cli.predict:main"
plot_predictions = "nuc2seg.cli.plot_predictions:main"

[build-system]
requires = ["setuptools>=43.0.0", "wheel"]
Expand Down
100 changes: 100 additions & 0 deletions src/nuc2seg/cli/plot_predictions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import argparse
import logging
import os.path
import tqdm

from nuc2seg import log_config
from nuc2seg.data import Nuc2SegDataset, TiledDataset, ModelPredictions, generate_tiles
from nuc2seg.plotting import plot_model_predictions

logger = logging.getLogger(__name__)


def get_parser():
parser = argparse.ArgumentParser(
description="Evaluate a UNet model on preprocessed data."
)
log_config.add_logging_args(parser)
parser.add_argument(
"--predictions",
help="Model prediction output in h5 format.",
type=str,
required=True,
)
parser.add_argument(
"--dataset",
help="Path to dataset in h5 format.",
type=str,
required=True,
)
parser.add_argument(
"--tile-height",
help="Height of the tiles.",
type=int,
default=64,
)
parser.add_argument(
"--tile-width",
help="Width of the tiles.",
type=int,
default=64,
)
parser.add_argument(
"--overlap-percentage",
help="What percent of each tile dimension overlaps with the next tile.",
type=float,
default=0.25,
)
parser.add_argument(
"--output-dir",
help="Output directory for plots.",
type=str,
required=True,
)
return parser


def get_args():
parser = get_parser()

args = parser.parse_args()

return args


def main():
args = get_args()

log_config.configure_logging(args)

logger.info(f"Loading dataset from {args.dataset}")

ds = Nuc2SegDataset.load_h5(args.dataset)

tiled_dataset = TiledDataset(
ds,
tile_height=args.tile_height,
tile_width=args.tile_width,
tile_overlap=args.overlap_percentage,
)
predictions = ModelPredictions.load_h5(args.predictions)

tile_generator = generate_tiles(
tiler=tiled_dataset.tiler,
x_extent=ds.x_extent_pixels,
y_extent=ds.y_extent_pixels,
tile_size=(args.tile_width, args.tile_height),
overlap_fraction=args.overlap_percentage,
)

os.makedirs(args.output_dir, exist_ok=True)

for bbox in tqdm.tqdm(tile_generator, total=len(tiled_dataset), unit="plot"):
plot_model_predictions(
dataset=ds,
model_predictions=predictions,
bbox=bbox,
output_path=os.path.join(
args.output_dir, "_".join([str(x) for x in bbox]) + ".pdf"
),
)
70 changes: 68 additions & 2 deletions src/nuc2seg/plotting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import geopandas
from shapely import box
from matplotlib import pyplot as plt
from nuc2seg.data import Nuc2SegDataset, ModelPredictions
import numpy as np
import math


def plot_tiling(bboxes, output_path):
Expand All @@ -14,5 +17,68 @@ def plot_tiling(bboxes, output_path):
fig.savefig(output_path)


def plot_model_predictions():
pass
def plot_labels(ax, dataset: Nuc2SegDataset, bbox=None):
label_plot = dataset.labels.copy()
transcripts = dataset.transcripts.copy()
label_plot[label_plot >= 1] = 5
label_plot[label_plot == -1] = 2

if bbox is not None:
label_plot = label_plot[bbox[0] : bbox[2], bbox[1] : bbox[3]]
mask_x = (transcripts[:, 0] >= bbox[0]) & (transcripts[:, 0] < bbox[2])
mask_y = (transcripts[:, 1] >= bbox[1]) & (transcripts[:, 1] < bbox[3])
mask = mask_x & mask_y
transcripts = transcripts[mask]

transcripts[:, 0] = transcripts[:, 0] - bbox[0]
transcripts[:, 1] = transcripts[:, 1] - bbox[1]

ax.set_title("Labels and transcripts")
ax.imshow(label_plot, cmap="copper", interpolation="none")

ax.scatter(
transcripts[:, 1], transcripts[:, 0], color="red", zorder=100, s=0.1, alpha=0.3
)


def plot_angles(ax, predictions: ModelPredictions, skip_factor=1, bbox=None):
angles = predictions.angles

if bbox is not None:
angles = angles[bbox[0] : bbox[2], bbox[1] : bbox[3]]

skip = (slice(None, None, skip_factor), slice(None, None, skip_factor))

ax.set_title("Predicted angles")

U = 0.5 * np.cos(angles)
V = 0.5 * np.sin(angles)
Y, X = np.mgrid[0 : U.shape[0], 0 : U.shape[1]]

ax.quiver(X[skip], Y[skip], U[skip], V[skip], color="black", headwidth=2)


def plot_foreground(ax, predictions: ModelPredictions, bbox=None):
foreground = predictions.foreground

if bbox is not None:
foreground = foreground[bbox[0] : bbox[2], bbox[1] : bbox[3]]

ax.imshow(foreground, vmin=0, vmax=1, cmap="coolwarm", interpolation="none")
ax.set_title("Predicted foreground")


def plot_model_predictions(
dataset: Nuc2SegDataset,
model_predictions: ModelPredictions,
output_path=None,
bbox=None,
):
fig, ax = plt.subplots(figsize=(10, 10), nrows=3, dpi=1000)

plot_labels(ax[0], dataset, bbox=bbox)
plot_angles(ax[1], model_predictions, bbox=bbox)
plot_foreground(ax[2], model_predictions, bbox=bbox)

fig.tight_layout()
fig.savefig(output_path)

0 comments on commit d1afce8

Please sign in to comment.