Skip to content

Commit

Permalink
actually save the model
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffquinn-msk committed Feb 9, 2024
1 parent 96da651 commit 052f96b
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
17 changes: 17 additions & 0 deletions src/nuc2seg/cli/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import logging
import numpy as np
import torch

from nuc2seg import log_config
from nuc2seg.train import train
Expand All @@ -20,6 +21,12 @@ def get_parser():
type=str,
required=True,
)
parser.add_argument(
"--model-weights-output",
help="File to save model weights to.",
type=str,
required=True,
)
parser.add_argument(
"--n-classes",
help="Number of classes to segment.",
Expand Down Expand Up @@ -86,6 +93,12 @@ def get_parser():
type=float,
default=1.0,
)
parser.add_argument(
"--validation-frequency",
help="Frequency of validation.",
type=int,
default=500,
)
parser.add_argument(
"--max-workers",
help="Maximum number of workers to use for data loading.",
Expand Down Expand Up @@ -133,4 +146,8 @@ def main():
momentum=args.momentum,
gradient_clipping=args.gradient_clipping,
max_workers=args.max_workers,
validation_frequency=args.validation_frequency,
)

logger.info(f"Saving model weights to {args.model_weights_output}")
torch.save(model.state_dict(), args.model_weights_output)
4 changes: 3 additions & 1 deletion src/nuc2seg/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def evaluate(net, dataloader, device, amp):
# iterate over the validation set
# with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
# for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
for idx, batch in enumerate(tqdm.tqdm(dataloader, desc="Validation", unit="batch")):
for idx, batch in enumerate(
tqdm.tqdm(dataloader, desc="Validation", unit="batch", position=3)
):
x, y, z, labels, label_mask = (
batch["X"],
batch["Y"],
Expand Down
3 changes: 2 additions & 1 deletion src/nuc2seg/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def train(
momentum: float = 0.999,
gradient_clipping: float = 1.0,
max_workers: int = 1,
validation_frequency: int = 500,
):
# Create the dataset
dataset = XeniumDataset(tiles_dir)
Expand Down Expand Up @@ -155,7 +156,7 @@ def train(
epoch_loss += loss.item()

# for Evaluating model performance/ convergence
if global_step % 500 == 0:
if global_step % validation_frequency == 0:
validation_score = evaluate(model, val_loader, device, amp)
validation_scores.append(validation_score)

Expand Down

0 comments on commit 052f96b

Please sign in to comment.