Skip to content

Commit

Permalink
Merge pull request #2 from tansey-lab/jq_fix_logging
Browse files Browse the repository at this point in the history
Replace Print Statements with Logging/TQDM
  • Loading branch information
jeffquinn-msk authored Feb 9, 2024
2 parents 388966c + 052f96b commit 26ef3aa
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 51 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ test = ["pytest", "pytest-mock", "tox", "coverage"]

[project.scripts]
preprocess = "nuc2seg.cli.preprocess:main"
train_nuc2seg = "nuc2seg.cli.train:main"

[build-system]
requires = ["setuptools>=43.0.0", "wheel"]
Expand Down
153 changes: 153 additions & 0 deletions src/nuc2seg/cli/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import argparse
import logging
import numpy as np
import torch

from nuc2seg import log_config
from nuc2seg.train import train
from nuc2seg.unet_model import SparseUNet

logger = logging.getLogger(__name__)


def get_parser():
parser = argparse.ArgumentParser(
description="Train a UNet model on preprocessed data."
)
log_config.add_logging_args(parser)
parser.add_argument(
"--preprocessed-tiles-dir",
help="Directory containing preprocessed tiles.",
type=str,
required=True,
)
parser.add_argument(
"--model-weights-output",
help="File to save model weights to.",
type=str,
required=True,
)
parser.add_argument(
"--n-classes",
help="Number of classes to segment.",
type=int,
required=True,
)
parser.add_argument(
"--seed",
help="Seed to use for PRNG.",
type=int,
default=0,
)
parser.add_argument(
"--epochs",
help="Number of epochs to train for.",
type=int,
default=50,
)
parser.add_argument(
"--batch-size",
help="Batch size.",
type=int,
default=1,
)
parser.add_argument(
"--learning-rate",
help="Learning rate.",
type=float,
default=1e-5,
)
parser.add_argument(
"--val-percent",
help="Percentage of data to use for validation.",
type=float,
default=0.1,
)
parser.add_argument(
"--save-checkpoint",
help="Save model checkpoint.",
action="store_true",
default=True,
)
parser.add_argument(
"--amp",
help="Use automatic mixed precision.",
action="store_true",
default=False,
)
parser.add_argument(
"--weight-decay",
help="Weight decay.",
type=float,
default=1e-8,
)
parser.add_argument(
"--momentum",
help="Momentum.",
type=float,
default=0.999,
)
parser.add_argument(
"--gradient-clipping",
help="Gradient clipping.",
type=float,
default=1.0,
)
parser.add_argument(
"--validation-frequency",
help="Frequency of validation.",
type=int,
default=500,
)
parser.add_argument(
"--max-workers",
help="Maximum number of workers to use for data loading.",
type=int,
default=1,
)
parser.add_argument(
"--device",
help="Device to use for training.",
type=str,
default="cpu",
choices=["cpu", "cuda"],
)

return parser


def get_args():
parser = get_parser()

args = parser.parse_args()

return args


def main():
args = get_args()

log_config.configure_logging(args)

np.random.seed(args.seed)
model = SparseUNet(600, args.n_classes + 2, (64, 64))

train(
model,
device=args.device,
tiles_dir=args.preprocessed_tiles_dir,
epochs=args.epochs,
batch_size=args.batch_size,
learning_rate=args.learning_rate,
val_percent=args.val_percent,
save_checkpoint=args.save_checkpoint,
amp=args.amp,
weight_decay=args.weight_decay,
momentum=args.momentum,
gradient_clipping=args.gradient_clipping,
max_workers=args.max_workers,
validation_frequency=args.validation_frequency,
)

logger.info(f"Saving model weights to {args.model_weights_output}")
torch.save(model.state_dict(), args.model_weights_output)
12 changes: 6 additions & 6 deletions src/nuc2seg/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from scipy.special import softmax, expit
from torch import Tensor
from matplotlib import pyplot as plt
import tqdm

from nuc2seg.xenium_utils import pol2cart


Expand Down Expand Up @@ -48,9 +50,9 @@ def evaluate(net, dataloader, device, amp):
# iterate over the validation set
# with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
# for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
for idx, batch in enumerate(dataloader):
if idx % 10 == 0:
print(f"{idx+1}/{num_val_batches}")
for idx, batch in enumerate(
tqdm.tqdm(dataloader, desc="Validation", unit="batch", position=3)
):
x, y, z, labels, label_mask = (
batch["X"],
batch["Y"],
Expand Down Expand Up @@ -193,9 +195,7 @@ def score_segmentation(segments, nuclei):
percent_common = np.zeros(n_nuclei)
nuclei_label_counts = np.zeros(n_nuclei)
label_nuclei_counts = np.zeros(np.unique(nuclei_segments)[-1] + 1)
for i in range(n_nuclei):
if i % 100 == 0:
print(f"{i+1}/{n_nuclei}")
for i in tqdm.trange(n_nuclei, desc="Scoring nuclei"):

local_labels = nuclei_segments[inv_map == i]
local_uniques, local_counts = np.unique(
Expand Down
22 changes: 8 additions & 14 deletions src/nuc2seg/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import numpy as np
import geopandas
import pandas

import tqdm
import logging
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
from scipy.special import expit, softmax

from xenium_utils import pol2cart, create_pixel_geodf, load_nuclei

logger = logging.getLogger(__name__)


def temp_forward(model, x, y, z):
mask = z > -1
Expand Down Expand Up @@ -36,9 +39,7 @@ def stitch_tile_predictions(model, dataset, tile_buffer=8):
tile_width, tile_height = dataset[0]["labels"].numpy().shape

results = np.zeros((x_max + tile_width, y_max + tile_height, dataset.n_classes + 2))
for idx in range(len(dataset)):
if idx % 100 == 0:
print(f"{idx}/{len(dataset)}")
for idx in tqdm.trange(len(dataset), desc="Stitching tiles"):
tile = dataset[idx]

x, y, z, labels, angles, classes, label_mask, nucleus_mask, location = (
Expand Down Expand Up @@ -80,12 +81,6 @@ def stitch_tile_predictions(model, dataset, tile_buffer=8):
if y_start < y_max:
y_end_offset -= tile_buffer

print(
x_start + x_start_offset,
x_start + x_end_offset,
y_start + y_start_offset,
y_start + y_end_offset,
)
results[
x_start + x_start_offset : x_start + x_end_offset,
y_start + y_start_offset : y_start + y_end_offset,
Expand Down Expand Up @@ -115,8 +110,7 @@ def greedy_expansion(
foreground_mask,
max_expansion_steps=50,
):
for step in range(max_expansion_steps):
print(f"Expansion step {step+1}/{max_expansion_steps}")
for step in tqdm.trange(max_expansion_steps, desc="greedy_expansion", unit="step"):
# Filter down to unassigned pixels that would flow to an assigned pixel
pixel_labels_flat = pixel_labels_arr[start_xy[:, 0], start_xy[:, 1]]
flow_labels_flat = flow_labels[start_xy[:, 0], start_xy[:, 1]]
Expand All @@ -129,7 +123,7 @@ def greedy_expansion(

# If there are no pixels to update, just exit early
if update_mask.sum() == 0:
print("No pixels left to update. Stopping expansion.")
logger.debug("No pixels left to update. Stopping expansion.")
break

# Update the filtered pixels to have the assignment of their flow neighbor
Expand Down Expand Up @@ -373,7 +367,7 @@ def greedy_cell_segmentation(
# Update the pixel labels with the new cells
n_nuclei = pixel_labels_arr.max() + 1
for offset, connected_label in enumerate(uniques):
print(f"Segmenting cells without nuclei ({offset+1}/{len(uniques)}")
logger.info(f"Segmenting cells without nuclei ({offset+1}/{len(uniques)}")
mask = connected_labels == connected_label
pixel_labels_arr[start_xy[mask, 0], start_xy[mask, 1]] = n_nuclei + offset

Expand Down
16 changes: 6 additions & 10 deletions src/nuc2seg/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch
import torch.nn as nn
import tqdm

from nuc2seg.data_loading import XeniumDataset, xenium_collate_fn
from torch import optim
from torch.utils.data import DataLoader, random_split
Expand Down Expand Up @@ -45,6 +47,7 @@ def train(
momentum: float = 0.999,
gradient_clipping: float = 1.0,
max_workers: int = 1,
validation_frequency: int = 500,
):
# Create the dataset
dataset = XeniumDataset(tiles_dir)
Expand Down Expand Up @@ -96,11 +99,11 @@ def train(
validation_scores = []

# 5. Begin training
for epoch in range(1, epochs + 1):
for epoch in tqdm.trange(1, epochs + 1, position=0, desc="Epoch"):
model.train()
epoch_loss = 0
# with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
for batch in train_loader:
for batch in tqdm.tqdm(train_loader, position=1, desc="Batch"):
x, y, z, labels, angles, classes, label_mask, nucleus_mask = (
batch["X"],
batch["Y"],
Expand Down Expand Up @@ -141,10 +144,6 @@ def train(
class_pred[nucleus_mask], classes[nucleus_mask] - 1
)

# TODO: include DICE coefficient in loss?

print(loss)

# Backprop
optimizer.zero_grad(set_to_none=True)
grad_scaler.scale(loss).backward()
Expand All @@ -157,11 +156,8 @@ def train(
epoch_loss += loss.item()

# for Evaluating model performance/ convergence
if global_step % 500 == 0:
if global_step % validation_frequency == 0:
validation_score = evaluate(model, val_loader, device, amp)
print("Previous validation scores:")
print(validation_scores)
print(f"Current: {validation_score:.2f}")
validation_scores.append(validation_score)


Expand Down
Loading

0 comments on commit 26ef3aa

Please sign in to comment.