Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Plotting Code #5

Merged
merged 1 commit into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ test = ["pytest", "pytest-mock", "tox", "coverage"]
preprocess = "nuc2seg.cli.preprocess:main"
train = "nuc2seg.cli.train:main"
predict = "nuc2seg.cli.predict:main"
plot_predictions = "nuc2seg.cli.plot_predictions:main"

[build-system]
requires = ["setuptools>=43.0.0", "wheel"]
Expand Down
100 changes: 100 additions & 0 deletions src/nuc2seg/cli/plot_predictions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import argparse
import logging
import os.path
import tqdm

from nuc2seg import log_config
from nuc2seg.data import Nuc2SegDataset, TiledDataset, ModelPredictions, generate_tiles
from nuc2seg.plotting import plot_model_predictions

logger = logging.getLogger(__name__)


def get_parser():
parser = argparse.ArgumentParser(
description="Evaluate a UNet model on preprocessed data."
)
log_config.add_logging_args(parser)
parser.add_argument(
"--predictions",
help="Model prediction output in h5 format.",
type=str,
required=True,
)
parser.add_argument(
"--dataset",
help="Path to dataset in h5 format.",
type=str,
required=True,
)
parser.add_argument(
"--tile-height",
help="Height of the tiles.",
type=int,
default=64,
)
parser.add_argument(
"--tile-width",
help="Width of the tiles.",
type=int,
default=64,
)
parser.add_argument(
"--overlap-percentage",
help="What percent of each tile dimension overlaps with the next tile.",
type=float,
default=0.25,
)
parser.add_argument(
"--output-dir",
help="Output directory for plots.",
type=str,
required=True,
)
return parser


def get_args():
parser = get_parser()

args = parser.parse_args()

return args


def main():
args = get_args()

log_config.configure_logging(args)

logger.info(f"Loading dataset from {args.dataset}")

ds = Nuc2SegDataset.load_h5(args.dataset)

tiled_dataset = TiledDataset(
ds,
tile_height=args.tile_height,
tile_width=args.tile_width,
tile_overlap=args.overlap_percentage,
)
predictions = ModelPredictions.load_h5(args.predictions)

tile_generator = generate_tiles(
tiler=tiled_dataset.tiler,
x_extent=ds.x_extent_pixels,
y_extent=ds.y_extent_pixels,
tile_size=(args.tile_width, args.tile_height),
overlap_fraction=args.overlap_percentage,
)

os.makedirs(args.output_dir, exist_ok=True)

for bbox in tqdm.tqdm(tile_generator, total=len(tiled_dataset), unit="plot"):
plot_model_predictions(
dataset=ds,
model_predictions=predictions,
bbox=bbox,
output_path=os.path.join(
args.output_dir, "_".join([str(x) for x in bbox]) + ".pdf"
),
)
70 changes: 68 additions & 2 deletions src/nuc2seg/plotting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import geopandas
from shapely import box
from matplotlib import pyplot as plt
from nuc2seg.data import Nuc2SegDataset, ModelPredictions
import numpy as np
import math


def plot_tiling(bboxes, output_path):
Expand All @@ -14,5 +17,68 @@ def plot_tiling(bboxes, output_path):
fig.savefig(output_path)


def plot_model_predictions():
pass
def plot_labels(ax, dataset: Nuc2SegDataset, bbox=None):
label_plot = dataset.labels.copy()
transcripts = dataset.transcripts.copy()
label_plot[label_plot >= 1] = 5
label_plot[label_plot == -1] = 2

if bbox is not None:
label_plot = label_plot[bbox[0] : bbox[2], bbox[1] : bbox[3]]
mask_x = (transcripts[:, 0] >= bbox[0]) & (transcripts[:, 0] < bbox[2])
mask_y = (transcripts[:, 1] >= bbox[1]) & (transcripts[:, 1] < bbox[3])
mask = mask_x & mask_y
transcripts = transcripts[mask]

transcripts[:, 0] = transcripts[:, 0] - bbox[0]
transcripts[:, 1] = transcripts[:, 1] - bbox[1]

ax.set_title("Labels and transcripts")
ax.imshow(label_plot, cmap="copper", interpolation="none")

ax.scatter(
transcripts[:, 1], transcripts[:, 0], color="red", zorder=100, s=0.1, alpha=0.3
)


def plot_angles(ax, predictions: ModelPredictions, skip_factor=1, bbox=None):
angles = predictions.angles

if bbox is not None:
angles = angles[bbox[0] : bbox[2], bbox[1] : bbox[3]]

skip = (slice(None, None, skip_factor), slice(None, None, skip_factor))

ax.set_title("Predicted angles")

U = 0.5 * np.cos(angles)
V = 0.5 * np.sin(angles)
Y, X = np.mgrid[0 : U.shape[0], 0 : U.shape[1]]

ax.quiver(X[skip], Y[skip], U[skip], V[skip], color="black", headwidth=2)


def plot_foreground(ax, predictions: ModelPredictions, bbox=None):
foreground = predictions.foreground

if bbox is not None:
foreground = foreground[bbox[0] : bbox[2], bbox[1] : bbox[3]]

ax.imshow(foreground, vmin=0, vmax=1, cmap="coolwarm", interpolation="none")
ax.set_title("Predicted foreground")


def plot_model_predictions(
dataset: Nuc2SegDataset,
model_predictions: ModelPredictions,
output_path=None,
bbox=None,
):
fig, ax = plt.subplots(figsize=(10, 10), nrows=3, dpi=1000)

plot_labels(ax[0], dataset, bbox=bbox)
plot_angles(ax[1], model_predictions, bbox=bbox)
plot_foreground(ax[2], model_predictions, bbox=bbox)

fig.tight_layout()
fig.savefig(output_path)
Loading