diff --git a/pyproject.toml b/pyproject.toml index f0d26fc..3b095a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ test = ["pytest", "pytest-mock", "tox", "coverage"] [project.scripts] preprocess = "nuc2seg.cli.preprocess:main" +train_nuc2seg = "nuc2seg.cli.train:main" [build-system] requires = ["setuptools>=43.0.0", "wheel"] diff --git a/src/nuc2seg/cli/train.py b/src/nuc2seg/cli/train.py new file mode 100644 index 0000000..d5eea60 --- /dev/null +++ b/src/nuc2seg/cli/train.py @@ -0,0 +1,153 @@ +import argparse +import logging +import numpy as np +import torch + +from nuc2seg import log_config +from nuc2seg.train import train +from nuc2seg.unet_model import SparseUNet + +logger = logging.getLogger(__name__) + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Train a UNet model on preprocessed data." + ) + log_config.add_logging_args(parser) + parser.add_argument( + "--preprocessed-tiles-dir", + help="Directory containing preprocessed tiles.", + type=str, + required=True, + ) + parser.add_argument( + "--model-weights-output", + help="File to save model weights to.", + type=str, + required=True, + ) + parser.add_argument( + "--n-classes", + help="Number of classes to segment.", + type=int, + required=True, + ) + parser.add_argument( + "--seed", + help="Seed to use for PRNG.", + type=int, + default=0, + ) + parser.add_argument( + "--epochs", + help="Number of epochs to train for.", + type=int, + default=50, + ) + parser.add_argument( + "--batch-size", + help="Batch size.", + type=int, + default=1, + ) + parser.add_argument( + "--learning-rate", + help="Learning rate.", + type=float, + default=1e-5, + ) + parser.add_argument( + "--val-percent", + help="Percentage of data to use for validation.", + type=float, + default=0.1, + ) + parser.add_argument( + "--save-checkpoint", + help="Save model checkpoint.", + action="store_true", + default=True, + ) + parser.add_argument( + "--amp", + help="Use automatic mixed precision.", + action="store_true", + default=False, + ) + parser.add_argument( + "--weight-decay", + help="Weight decay.", + type=float, + default=1e-8, + ) + parser.add_argument( + "--momentum", + help="Momentum.", + type=float, + default=0.999, + ) + parser.add_argument( + "--gradient-clipping", + help="Gradient clipping.", + type=float, + default=1.0, + ) + parser.add_argument( + "--validation-frequency", + help="Frequency of validation.", + type=int, + default=500, + ) + parser.add_argument( + "--max-workers", + help="Maximum number of workers to use for data loading.", + type=int, + default=1, + ) + parser.add_argument( + "--device", + help="Device to use for training.", + type=str, + default="cpu", + choices=["cpu", "cuda"], + ) + + return parser + + +def get_args(): + parser = get_parser() + + args = parser.parse_args() + + return args + + +def main(): + args = get_args() + + log_config.configure_logging(args) + + np.random.seed(args.seed) + model = SparseUNet(600, args.n_classes + 2, (64, 64)) + + train( + model, + device=args.device, + tiles_dir=args.preprocessed_tiles_dir, + epochs=args.epochs, + batch_size=args.batch_size, + learning_rate=args.learning_rate, + val_percent=args.val_percent, + save_checkpoint=args.save_checkpoint, + amp=args.amp, + weight_decay=args.weight_decay, + momentum=args.momentum, + gradient_clipping=args.gradient_clipping, + max_workers=args.max_workers, + validation_frequency=args.validation_frequency, + ) + + logger.info(f"Saving model weights to {args.model_weights_output}") + torch.save(model.state_dict(), args.model_weights_output) diff --git a/src/nuc2seg/evaluate.py b/src/nuc2seg/evaluate.py index 673cd28..6e64cd9 100755 --- a/src/nuc2seg/evaluate.py +++ b/src/nuc2seg/evaluate.py @@ -3,6 +3,8 @@ from scipy.special import softmax, expit from torch import Tensor from matplotlib import pyplot as plt +import tqdm + from nuc2seg.xenium_utils import pol2cart @@ -48,9 +50,9 @@ def evaluate(net, dataloader, device, amp): # iterate over the validation set # with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp): # for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False): - for idx, batch in enumerate(dataloader): - if idx % 10 == 0: - print(f"{idx+1}/{num_val_batches}") + for idx, batch in enumerate( + tqdm.tqdm(dataloader, desc="Validation", unit="batch", position=3) + ): x, y, z, labels, label_mask = ( batch["X"], batch["Y"], @@ -193,9 +195,7 @@ def score_segmentation(segments, nuclei): percent_common = np.zeros(n_nuclei) nuclei_label_counts = np.zeros(n_nuclei) label_nuclei_counts = np.zeros(np.unique(nuclei_segments)[-1] + 1) - for i in range(n_nuclei): - if i % 100 == 0: - print(f"{i+1}/{n_nuclei}") + for i in tqdm.trange(n_nuclei, desc="Scoring nuclei"): local_labels = nuclei_segments[inv_map == i] local_uniques, local_counts = np.unique( diff --git a/src/nuc2seg/segment.py b/src/nuc2seg/segment.py index e3d0bb4..4c9e30d 100755 --- a/src/nuc2seg/segment.py +++ b/src/nuc2seg/segment.py @@ -2,13 +2,16 @@ import numpy as np import geopandas import pandas - +import tqdm +import logging from scipy.sparse import csr_matrix from scipy.sparse.csgraph import connected_components from scipy.special import expit, softmax from xenium_utils import pol2cart, create_pixel_geodf, load_nuclei +logger = logging.getLogger(__name__) + def temp_forward(model, x, y, z): mask = z > -1 @@ -36,9 +39,7 @@ def stitch_tile_predictions(model, dataset, tile_buffer=8): tile_width, tile_height = dataset[0]["labels"].numpy().shape results = np.zeros((x_max + tile_width, y_max + tile_height, dataset.n_classes + 2)) - for idx in range(len(dataset)): - if idx % 100 == 0: - print(f"{idx}/{len(dataset)}") + for idx in tqdm.trange(len(dataset), desc="Stitching tiles"): tile = dataset[idx] x, y, z, labels, angles, classes, label_mask, nucleus_mask, location = ( @@ -80,12 +81,6 @@ def stitch_tile_predictions(model, dataset, tile_buffer=8): if y_start < y_max: y_end_offset -= tile_buffer - print( - x_start + x_start_offset, - x_start + x_end_offset, - y_start + y_start_offset, - y_start + y_end_offset, - ) results[ x_start + x_start_offset : x_start + x_end_offset, y_start + y_start_offset : y_start + y_end_offset, @@ -115,8 +110,7 @@ def greedy_expansion( foreground_mask, max_expansion_steps=50, ): - for step in range(max_expansion_steps): - print(f"Expansion step {step+1}/{max_expansion_steps}") + for step in tqdm.trange(max_expansion_steps, desc="greedy_expansion", unit="step"): # Filter down to unassigned pixels that would flow to an assigned pixel pixel_labels_flat = pixel_labels_arr[start_xy[:, 0], start_xy[:, 1]] flow_labels_flat = flow_labels[start_xy[:, 0], start_xy[:, 1]] @@ -129,7 +123,7 @@ def greedy_expansion( # If there are no pixels to update, just exit early if update_mask.sum() == 0: - print("No pixels left to update. Stopping expansion.") + logger.debug("No pixels left to update. Stopping expansion.") break # Update the filtered pixels to have the assignment of their flow neighbor @@ -373,7 +367,7 @@ def greedy_cell_segmentation( # Update the pixel labels with the new cells n_nuclei = pixel_labels_arr.max() + 1 for offset, connected_label in enumerate(uniques): - print(f"Segmenting cells without nuclei ({offset+1}/{len(uniques)}") + logger.info(f"Segmenting cells without nuclei ({offset+1}/{len(uniques)}") mask = connected_labels == connected_label pixel_labels_arr[start_xy[mask, 0], start_xy[mask, 1]] = n_nuclei + offset diff --git a/src/nuc2seg/train.py b/src/nuc2seg/train.py index 4c898f0..ea947d2 100755 --- a/src/nuc2seg/train.py +++ b/src/nuc2seg/train.py @@ -1,5 +1,7 @@ import torch import torch.nn as nn +import tqdm + from nuc2seg.data_loading import XeniumDataset, xenium_collate_fn from torch import optim from torch.utils.data import DataLoader, random_split @@ -45,6 +47,7 @@ def train( momentum: float = 0.999, gradient_clipping: float = 1.0, max_workers: int = 1, + validation_frequency: int = 500, ): # Create the dataset dataset = XeniumDataset(tiles_dir) @@ -96,11 +99,11 @@ def train( validation_scores = [] # 5. Begin training - for epoch in range(1, epochs + 1): + for epoch in tqdm.trange(1, epochs + 1, position=0, desc="Epoch"): model.train() epoch_loss = 0 # with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar: - for batch in train_loader: + for batch in tqdm.tqdm(train_loader, position=1, desc="Batch"): x, y, z, labels, angles, classes, label_mask, nucleus_mask = ( batch["X"], batch["Y"], @@ -141,10 +144,6 @@ def train( class_pred[nucleus_mask], classes[nucleus_mask] - 1 ) - # TODO: include DICE coefficient in loss? - - print(loss) - # Backprop optimizer.zero_grad(set_to_none=True) grad_scaler.scale(loss).backward() @@ -157,11 +156,8 @@ def train( epoch_loss += loss.item() # for Evaluating model performance/ convergence - if global_step % 500 == 0: + if global_step % validation_frequency == 0: validation_score = evaluate(model, val_loader, device, amp) - print("Previous validation scores:") - print(validation_scores) - print(f"Current: {validation_score:.2f}") validation_scores.append(validation_score) diff --git a/src/nuc2seg/xenium_utils.py b/src/nuc2seg/xenium_utils.py index 19a010b..c5bf0d3 100755 --- a/src/nuc2seg/xenium_utils.py +++ b/src/nuc2seg/xenium_utils.py @@ -135,9 +135,11 @@ def estimate_cell_types( final_expression_rates, ) = ([], [], [], []) prev_expression_profiles = np.zeros_like(cur_expression_profiles) - for idx, n_components in enumerate(range(min_components, max_components + 1)): - # Print a notification of which component we're on - print(f"K={n_components}") + for idx, n_components in enumerate( + tqdm.trange( + min_components, max_components + 1, desc="estimate_cell_types", position=0 + ) + ): # Warm start from the previous iteration if warm_start and n_components > min_components: @@ -159,8 +161,8 @@ def estimate_cell_types( cur_prior_probs = next_prior_probs cur_expression_profiles = next_expression_profiles - print("priors:", cur_prior_probs) - print("expression:", cur_expression_profiles) + logger.debug("priors:", cur_prior_probs) + logger.debug("expression:", cur_expression_profiles) else: # Cold start from a random location cur_expression_profiles = np.random.dirichlet( @@ -169,9 +171,12 @@ def estimate_cell_types( cur_prior_probs = np.ones(n_components) / n_components converge = tol + 1 - for step in range(max_em_steps): - # Print a notification of which step we are on - print(f"\tStep {step+1}/{max_em_steps}") + for step in tqdm.trange( + max_em_steps, + desc=f"EM for n_components {n_components}", + unit="step", + position=1, + ): # E-step: estimate cell type assignments (posterior probabilities) logits = np.log(cur_prior_probs[None]) + ( @@ -192,16 +197,16 @@ def estimate_cell_types( cur_prior_probs = (cur_cell_types.sum(axis=0) + 1) / ( cur_cell_types.shape[0] + n_components ) - print(cur_prior_probs) + logger.debug(f"cur_prior_probs: {cur_prior_probs}") # Track convergence of the cell type profiles converge = np.linalg.norm( prev_expression_profiles - cur_expression_profiles ) / np.linalg.norm(prev_expression_profiles) - print(f"\t\tConvergence: {converge:.4f}") + logger.debug(f"Convergence: {converge:.4f}") if converge <= tol: - print(f"\t\tStopping early.") + logger.debug(f"Stopping early.") break # Save the results @@ -214,10 +219,9 @@ def estimate_cell_types( gene_counts, cur_expression_profiles, cur_prior_probs ) - print(f"K={n_components}") - print(f"AIC: {aic_scores[idx]:.4f}") - print(f"BIC: {bic_scores[idx]:.4f}") - print() + logger.debug(f"K={n_components}") + logger.debug(f"AIC: {aic_scores[idx]:.4f}") + logger.debug(f"BIC: {bic_scores[idx]:.4f}") return { "bic": bic_scores, @@ -271,9 +275,9 @@ def calculate_pixel_loglikes( transcript_counts = np.zeros( tx_count_grid.shape + (max_transcript_count,), dtype=int ) - for row_idx, row in tx_geo_df.iterrows(): - if row_idx % 1000 == 0: - print(row_idx) + for row_idx, row in tqdm.tqdm( + tx_geo_df.iterrows(), desc="calculate_pixel_loglikes" + ): x, y = int(row["x_location"]), int(row["y_location"]) tx_mask = transcript_ids[x, y] == row["gene_id"] if tx_mask.sum() != 0: @@ -302,14 +306,16 @@ def calculate_pixel_loglikes( # NOTE: yes we could vectorize this, but that leads to memory issues. expression_loglikes = np.zeros_like(rate_loglikes) log_expression = np.log(expression_profiles) + total = expression_loglikes.shape[0] * expression_loglikes.shape[1] + pbar = tqdm.tqdm(total=total, desc="calculate_pixel_loglikes") + for i in range(expression_loglikes.shape[0]): - if i % 100 == 0: - print(i) for j in range(expression_loglikes.shape[1]): expression_loglikes[i, j] = ( transcript_counts[i, j, :, None] * log_expression.T[transcript_ids[i, j]] ).sum(axis=0) + pbar.update(1) # Add the two log-likelihoods together to get the pixelwise log-likelihood return rate_loglikes, expression_loglikes @@ -678,7 +684,7 @@ def spatial_as_sparse_arrays( # Track how many of each class label this tile has temp = np.zeros(n_classes + 2, dtype=int) uniques = np.unique(classes_local, return_counts=True) - print(uniques) + logger.debug(f"uniques: {uniques}") temp[uniques[0] + 1] = uniques[1] class_local_counts.append(temp)