Skip to content

Commit

Permalink
add export to checkpoint, load from checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickcleeve2 committed Dec 28, 2023
1 parent 6309155 commit b49bdb8
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 1 deletion.
48 changes: 48 additions & 0 deletions nnunetv2/inference/predict_from_raw_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,54 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str,
print('Using torch.compile')
self.network = torch.compile(self.network)

def load_from_checkpoint(self, checkpoint_path: str):
"""Load model from single checkpoint"""

# load the full checkpoint
model_checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

# load the dataset and plans
dataset_json = model_checkpoint['dataset']
plans_json = model_checkpoint['plans']
plans_manager = PlansManager(plans_json)

# load the model parameters
parameters = []
checkpoint_name = "final" # always use final checkpoint for now
for i, k in enumerate(sorted(model_checkpoint['folds'])):

checkpoint = model_checkpoint['folds'][k][checkpoint_name]

if i == 0: # use first fold to get trainer and configuration name
trainer_name = checkpoint['trainer_name']
configuration_name = checkpoint['init_args']['configuration']
inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \
'inference_allowed_mirroring_axes' in checkpoint.keys() else None

parameters.append(checkpoint['network_weights'])

configuration_manager = plans_manager.get_configuration(configuration_name)

# restore network
num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
trainer_name, 'nnunetv2.training.nnUNetTrainer')
network = trainer_class.build_network_architecture(plans_manager, dataset_json, configuration_manager,
num_input_channels, enable_deep_supervision=False)
self.plans_manager = plans_manager
self.configuration_manager = configuration_manager
self.list_of_parameters = parameters
self.network = network
self.dataset_json = dataset_json
self.trainer_name = trainer_name
self.allowed_mirroring_axes = inference_allowed_mirroring_axes
self.label_manager = plans_manager.get_label_manager(dataset_json)
if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \
and not isinstance(self.network, OptimizedModule):
print('Using torch.compile')
self.network = torch.compile(self.network)


def manual_initialization(self, network: nn.Module, plans_manager: PlansManager,
configuration_manager: ConfigurationManager, parameters: Optional[List[dict]],
dataset_json: dict, trainer_name: str,
Expand Down
18 changes: 17 additions & 1 deletion nnunetv2/model_sharing/entry_points.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from nnunetv2.model_sharing.model_download import download_and_install_from_url
from nnunetv2.model_sharing.model_export import export_pretrained_model
from nnunetv2.model_sharing.model_export import export_pretrained_model, export_model_checkpoint
from nnunetv2.model_sharing.model_import import install_model_from_zip_file


Expand Down Expand Up @@ -59,3 +59,19 @@ def export_pretrained_model_entry():
export_pretrained_model(dataset_name_or_id=args.d, output_file=args.o, configurations=args.c, trainer=args.tr,
plans_identifier=args.p, folds=args.f, strict=not args.not_strict, save_checkpoints=args.chk,
export_crossval_predictions=args.exp_cv_preds)


def export_model_to_checkpoint():
import argparse
parser = argparse.ArgumentParser(description="Export nnunet model checkpoint as a single .pth file")
parser.add_argument("--path", type=str, help="path to nnunet model directory")
parser.add_argument("--checkpoint_path", type=str, help="path to save the checkpoint", required=False, default=None)
parser.add_argument("--checkpoint_name", type=str, help="name of the checkpoint", required=False, default="model_checkpoint.pth",)

args = parser.parse_args()

export_model_checkpoint(
path=args.path,
checkpoint_path=args.checkpoint_path,
checkpoint_name=args.checkpoint_name,
)
83 changes: 83 additions & 0 deletions nnunetv2/model_sharing/model_export.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import glob
import zipfile

import torch
from nnunetv2.utilities.file_path_utilities import *


Expand Down Expand Up @@ -119,6 +121,87 @@ def export_pretrained_model(dataset_name_or_id: Union[int, str], output_file: st
zipf.write(inference_information_txt_file, os.path.relpath(inference_information_txt_file, nnUNet_results))
print('Done')

def export_model_checkpoint(
path: str,
checkpoint_path: str = None,
checkpoint_name: str = "model_checkpoint.pth",
) -> None:
"""Save NNUNet model checkpoint as a single .pth file
args:
path: path to the nnunet model directory
"""
# nnunet model directory structure for ensemble:
# model
# dataset.json
# plans.json
# fold_n:
# checkpoint_best.pth
# checkpoint_final.pth

# we want to convert it to a single .pth file with the following structure:
# model_checkpoint.pth
# dataset: dataset.json
# plans: plans.json
# fold_n:
# best: checkpoint_best.pth
# final: checkpoint_final.pth

# this makes it more portable and easier to load

def load_json(path: str):
with open(path, "r") as f:
return json.load(f)

# confirm that the path is a nnunet model directory
if not os.path.isdir(path):
raise ValueError(f"{path} is not a directory")
if not os.path.exists(os.path.join(path, "dataset.json")):
raise ValueError(f"{path} does not contain a dataset.json file")
if not os.path.exists(os.path.join(path, "plans.json")):
raise ValueError(f"{path} does not contain a plans.json file")

print(f"Exporting model checkpoint from {path}...")

model_checkpoint = {}

# paths
dataset_json_path = os.path.join(path, "dataset.json")
plan_json_path = os.path.join(path, "plans.json")

# load the dataset and plans
print("Loading dataset and plans configurations...")
model_checkpoint["dataset"] = load_json(dataset_json_path)
model_checkpoint["plans"] = load_json(plan_json_path)

# load the folds
model_checkpoint["folds"] = {}

# get all the fold directories,
fold_dirs = sorted(glob.glob(os.path.join(path, "fold_*")))
print(f"Found {len(fold_dirs)} folds...")
for fold_dir in fold_dirs:
fold_name = os.path.basename(fold_dir)
print(f"Processing fold {fold_name}...")

# load the best/ final checkpoint
best_checkpoint_path = os.path.join(fold_dir, "checkpoint_best.pth")
final_checkpoint_path = os.path.join(fold_dir, "checkpoint_final.pth")

model_checkpoint["folds"][fold_name] = {
"best": torch.load(best_checkpoint_path, map_location=torch.device("cpu")),
"final": torch.load(
final_checkpoint_path, map_location=torch.device("cpu")
),
}

# save as single torch checkpoint
if checkpoint_path is None:
checkpoint_path = os.path.join(path, checkpoint_name)
torch.save(model_checkpoint, checkpoint_path)
print(f"Exported model checkpoint to {checkpoint_path}")



if __name__ == '__main__':
export_pretrained_model(2, '/home/fabian/temp/dataset2.zip', strict=False, export_crossval_predictions=True, folds=(0, ))
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ nnUNetv2_plot_overlay_pngs = "nnunetv2.utilities.overlay_plots:entry_point_gener
nnUNetv2_download_pretrained_model_by_url = "nnunetv2.model_sharing.entry_points:download_by_url"
nnUNetv2_install_pretrained_model_from_zip = "nnunetv2.model_sharing.entry_points:install_from_zip_entry_point"
nnUNetv2_export_model_to_zip = "nnunetv2.model_sharing.entry_points:export_pretrained_model_entry"
nnUNetv2_export_model_to_checkpoint = "nnunetv2.model_sharing.entry_points:export_model_to_checkpoint"
nnUNetv2_move_plans_between_datasets = "nnunetv2.experiment_planning.plans_for_pretraining.move_plans_between_datasets:entry_point_move_plans_between_datasets"
nnUNetv2_evaluate_folder = "nnunetv2.evaluation.evaluate_predictions:evaluate_folder_entry_point"
nnUNetv2_evaluate_simple = "nnunetv2.evaluation.evaluate_predictions:evaluate_simple_entry_point"
Expand Down

0 comments on commit b49bdb8

Please sign in to comment.