From 2f9c10ed4ab1554077ab119fee8f01c71c8233f4 Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Tue, 27 Aug 2024 16:32:31 -0400 Subject: [PATCH 01/15] Fix point bugs and finetuning issue Signed-off-by: heyufan1995 --- models/vista3d/configs/train.json | 7 +++--- models/vista3d/configs/train_continual.json | 20 +++++++--------- models/vista3d/docs/README.md | 26 +++++++++++++++------ models/vista3d/scripts/evaluator.py | 15 ++++++------ models/vista3d/scripts/inferer.py | 4 +++- 5 files changed, 43 insertions(+), 29 deletions(-) diff --git a/models/vista3d/configs/train.json b/models/vista3d/configs/train.json index f6fc6bbb..9b80f843 100644 --- a/models/vista3d/configs/train.json +++ b/models/vista3d/configs/train.json @@ -16,7 +16,7 @@ "early_stop": false, "fold": 0, "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", - "epochs": 100, + "epochs": 5, "val_interval": 1, "val_at_start": false, "sw_overlap": 0.625, @@ -28,8 +28,8 @@ "max_prompt": null, "max_backprompt": null, "max_foreprompt": null, - "drop_label_prob": 0.5, - "drop_point_prob": 0.5, + "drop_label_prob": 0.25, + "drop_point_prob": 0.25, "exclude_background": true, "use_cfp": true, "label_set": null, @@ -379,6 +379,7 @@ "exclude_background": "@exclude_background", "use_cfp": "@use_cfp", "label_set": "@label_set", + "val_head": "auto", "user_prompt": false } } diff --git a/models/vista3d/configs/train_continual.json b/models/vista3d/configs/train_continual.json index 3d9651ff..ce8a85b3 100644 --- a/models/vista3d/configs/train_continual.json +++ b/models/vista3d/configs/train_continual.json @@ -6,8 +6,8 @@ "finetune_model_path": "$@bundle_root + '/models/model.pt'", "n_train_samples": 10, "n_val_samples": 10, - "val_interval": 40, - "learning_rate": 0.0001, + "val_interval": 1, + "learning_rate": 0.00005, "lr_schedule#activate": false, "loss#smooth_dr": 0.01, "loss#smooth_nr": 0.0001, @@ -18,18 +18,14 @@ "default": [ [ 1, - 2 - ], - [ - 2, - 254 + 3 ] ] }, "patch_size": [ - 160, - 160, - 160 + 128, + 128, + 128 ], "label_set": "$[0] + list(x[1] for x in @label_mappings#default)", "val_label_set": "$[0] + list(x[0] for x in @label_mappings#default)", @@ -99,11 +95,13 @@ "num_workers": "@num_cache_workers", "progress": "@show_cache_progress" }, + "validate#evaluator#hyper_kwargs#val_label_set": "$list(range(len(@val_label_set)))", "validate#preprocessing#transforms": "$@train#deterministic_transforms + [@valid_remap]", "valid_remap": { "_target_": "monai.apps.vista3d.transforms.Relabeld", "keys": "label", "label_mappings": "${'default': [[c, i] for i, c in enumerate(@val_label_set)]}", "dtype": "$torch.uint8" - } + }, + "validate#handlers#3#key_metric_filename": "model_finetune.pt" } diff --git a/models/vista3d/docs/README.md b/models/vista3d/docs/README.md index 966dba99..fd600ae2 100644 --- a/models/vista3d/docs/README.md +++ b/models/vista3d/docs/README.md @@ -90,10 +90,9 @@ torchrun --standalone --nnodes=1 --nproc_per_node=2 -m monai.bundle run --config #### Execute continual learning When finetuning with new class names, please update `configs/train_continual.json`'s `label_mappings` accordingly. -The current label mapping `[[1, 2], [2, 254]]` indicates that training labels' class indices `1` and `2`, are mapped -to the VISTA model's class `2` and `254` respectively (format `[[src_class_0, dst_class_0], [src_class_1, dst_class_1], ...]`). -Since `254` is not used by VISTA, it is therefore indicating -training with a new class (the training label's class `2` will be trained as VISTA class `254`). +The current label mapping `[[1, 3]]` indicates that training labels' class indices `1` is mapped +to the VISTA model's class `3` (format `[[src_class_0, dst_class_0], [src_class_1, dst_class_1], ...]`). For new classes, user +can map to any value larger than 132. `label_set` is used to identify the VISTA model classes for providing training prompts. `val_label_set` is used to identify the original training label classes for computing foreground/background mask during validation. @@ -103,7 +102,10 @@ The default configs for both variables are derived from the `label_mappings` con "label_set": "$[0] + list(x[1] for x in @label_mappings#default)" "val_label_set": "$[0] + list(x[0] for x in @label_mappings#default)" ``` - +`drop_label_prob` and `drop_point_prob` means percentage to remove class prompts and point prompts respectively. If `drop_point_prob`=1, the +model is only finetuning for automatic segmentation, while `drop_label_prob`=1 means only finetuning for interactive segmentation. The VISTA3D foundation +model is trained with interactive only (drop_label_prob=1) and then froze the point branch and trained with fully automatic segmentation (`drop_point_prob=1`). +In this bundle, the training is simplified by jointly training with class prompts and point prompts. Single-GPU: ``` @@ -117,11 +119,21 @@ torchrun --nnodes=1 --nproc_per_node=8 -m monai.bundle run \ --config_file="['configs/train.json','configs/train_continual.json','configs/multi_gpu_train.json']" --epochs=320 --learning_rate=0.005 ``` -The patch size parameter is defined in `configs/train_continual.json`: `"patch_size": [160, 160, 160]`, and this works for the use cases +The patch size parameter is defined in `configs/train_continual.json`: `"patch_size": [128, 128, 128]`, and this works for the use cases of extending the current model to segment a few novel classes. Finetuning all supported classes may require large GPU memory and carefully designed multi-stage training processes. -Changing `patch_size` to a smaller value such as `"patch_size": [128, 128, 128]` used in `configs/train.json` would reduce the training memory footprint. +Changing `patch_size` to a smaller value such as `"patch_size": [96, 96, 96]` used in `configs/train.json` would reduce the training memory footprint. + +In `train_continual.json`, only subset of training and validation data are used, change `n_train_samples` and `n_val_samples` to use full dataset. + +In `train.json`, `validate[evaluator][val_head]` can be `auto` and `point`. If `auto`, the validation results will be automatic segmentation. If `point`, +the validation results will be sampling one positive point per object per patch. The validation scheme of combining auto and point is deprecated due to +speed issue. + +Note: `valid_remap` is a transform that maps the groundtruth label indexes, e.g. [0,2,3,5,6] to sequential and continuous labels [0,1,2,3,4]. This is +required by monai dice calculation. It is not related to mapping label index to VISTA3D defined global class index. The validation data is not mapped +to the VISTA3D global class index. #### Execute evaluation `n_train_samples` and `n_val_samples` are used to specify the number of samples to use for training and validation respectively. diff --git a/models/vista3d/scripts/evaluator.py b/models/vista3d/scripts/evaluator.py index 6b1d4a89..59e34298 100644 --- a/models/vista3d/scripts/evaluator.py +++ b/models/vista3d/scripts/evaluator.py @@ -207,6 +207,8 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten if batchdata is None: raise ValueError("Must provide batch data for current iteration.") label_set = engine.hyper_kwargs.get("label_set", None) + # this validation label set should be consistent with 'labels.unique()', used to generate fg/bg points + val_label_set = engine.hyper_kwargs.get("val_label_set", label_set) # If user provide prompts in the inference, input image must contain original affine. # the point coordinates are from the original_affine space, while image here is after preprocess transforms. if engine.hyper_kwargs["user_prompt"]: @@ -242,18 +244,17 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten output_classes = engine.hyper_kwargs["output_classes"] label_set = np.arange(output_classes).tolist() label_prompt = torch.tensor(label_set).to(engine.state.device).unsqueeze(-1) - # point prompt is generated withing vista3d,provide empty points + # point prompt is generated withing vista3d, provide empty points points = torch.zeros(label_prompt.shape[0], 1, 3).to(inputs.device) point_labels = -1 + torch.zeros(label_prompt.shape[0], 1).to(inputs.device) - if engine.hyper_kwargs["drop_point_prob"] > 0.99: + # validation for either auto or point. + if engine.hyper_kwargs.get("val_head", "auto") == 'auto': # automatic only validation - points = None - point_labels = None - if engine.hyper_kwargs["drop_label_prob"] > 0.99: + # remove val_label_set, vista3d will not sample points from gt labels. + val_label_set = None + else: # point only validation label_prompt = None - # this validation label set should be consistent with 'labels.unique()', used to generate fg/bg points - val_label_set = engine.hyper_kwargs.get("val_label_set", label_set) # put iteration outputs into engine.state engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: labels} diff --git a/models/vista3d/scripts/inferer.py b/models/vista3d/scripts/inferer.py index c30ba386..b7c9bc43 100644 --- a/models/vista3d/scripts/inferer.py +++ b/models/vista3d/scripts/inferer.py @@ -25,7 +25,7 @@ class Vista3dInferer(Inferer): Args: roi_size: the sliding window patch size. overlap: sliding window overlap ratio. - use_cfp: use class prompt for point head. + use_cfp: use class prompt for point head. Deprecated. """ def __init__(self, roi_size, overlap, use_cfp, use_point_window=False, sw_batch_size=1) -> None: @@ -91,6 +91,7 @@ def __call__( roi_size=self.roi_size, sw_batch_size=self.sw_batch_size, transpose=True, + with_coord=True, predictor=network, mode="gaussian", sw_device=device, @@ -113,6 +114,7 @@ def __call__( roi_size=self.roi_size, sw_batch_size=self.sw_batch_size, transpose=True, + with_coord=True, predictor=network, mode="gaussian", sw_device=device, From d214d0642c613466dc5290aa88e93996128f607a Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Tue, 27 Aug 2024 16:38:53 -0400 Subject: [PATCH 02/15] fixes racing condition when InvertD is used along with ThreadDataLoader Signed-off-by: heyufan1995 --- models/vista3d/configs/inference.json | 3 ++- models/vista3d/scripts/evaluator.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/models/vista3d/configs/inference.json b/models/vista3d/configs/inference.json index fd158695..6e654a78 100644 --- a/models/vista3d/configs/inference.json +++ b/models/vista3d/configs/inference.json @@ -4,6 +4,7 @@ "$import os", "$import scripts", "$import numpy as np", + "$import copy", "$import json" ], "bundle_root": "./", @@ -146,7 +147,7 @@ { "_target_": "Invertd", "keys": "pred", - "transform": "@preprocessing", + "transform": "$copy.deepcopy(@preprocessing)", "orig_keys": "@image_key", "nearest_interp": true, "to_tensor": true diff --git a/models/vista3d/scripts/evaluator.py b/models/vista3d/scripts/evaluator.py index 59e34298..21aec86a 100644 --- a/models/vista3d/scripts/evaluator.py +++ b/models/vista3d/scripts/evaluator.py @@ -19,7 +19,7 @@ from monai.engines.evaluator import SupervisedEvaluator from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch from monai.inferers import Inferer, SimpleInferer -from monai.transforms import Transform +from monai.transforms import Transform, reset_ops_id from monai.utils import ForwardMode, RankFilter, min_version, optional_import from monai.utils.enums import CommonKeys as Keys from torch.utils.data import DataLoader @@ -281,6 +281,7 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten labels=labels, label_set=val_label_set, ) + inputs = reset_ops_id(inputs) # Add dim 0 for decollate batch engine.state.output["label_prompt"] = label_prompt.unsqueeze(0) if label_prompt is not None else None engine.state.output["points"] = points.unsqueeze(0) if points is not None else None From 14e238592b5c3097d9de2f70ac75cee82d5f125d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Aug 2024 20:44:39 +0000 Subject: [PATCH 03/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- models/vista3d/configs/train_continual.json | 6 +++--- models/vista3d/docs/README.md | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/models/vista3d/configs/train_continual.json b/models/vista3d/configs/train_continual.json index ce8a85b3..b700a178 100644 --- a/models/vista3d/configs/train_continual.json +++ b/models/vista3d/configs/train_continual.json @@ -7,7 +7,7 @@ "n_train_samples": 10, "n_val_samples": 10, "val_interval": 1, - "learning_rate": 0.00005, + "learning_rate": 5e-05, "lr_schedule#activate": false, "loss#smooth_dr": 0.01, "loss#smooth_nr": 0.0001, @@ -95,7 +95,7 @@ "num_workers": "@num_cache_workers", "progress": "@show_cache_progress" }, - "validate#evaluator#hyper_kwargs#val_label_set": "$list(range(len(@val_label_set)))", + "validate#evaluator#hyper_kwargs#val_label_set": "$list(range(len(@val_label_set)))", "validate#preprocessing#transforms": "$@train#deterministic_transforms + [@valid_remap]", "valid_remap": { "_target_": "monai.apps.vista3d.transforms.Relabeld", @@ -103,5 +103,5 @@ "label_mappings": "${'default': [[c, i] for i, c in enumerate(@val_label_set)]}", "dtype": "$torch.uint8" }, - "validate#handlers#3#key_metric_filename": "model_finetune.pt" + "validate#handlers#3#key_metric_filename": "model_finetune.pt" } diff --git a/models/vista3d/docs/README.md b/models/vista3d/docs/README.md index fd600ae2..39268acb 100644 --- a/models/vista3d/docs/README.md +++ b/models/vista3d/docs/README.md @@ -105,7 +105,7 @@ The default configs for both variables are derived from the `label_mappings` con `drop_label_prob` and `drop_point_prob` means percentage to remove class prompts and point prompts respectively. If `drop_point_prob`=1, the model is only finetuning for automatic segmentation, while `drop_label_prob`=1 means only finetuning for interactive segmentation. The VISTA3D foundation model is trained with interactive only (drop_label_prob=1) and then froze the point branch and trained with fully automatic segmentation (`drop_point_prob=1`). -In this bundle, the training is simplified by jointly training with class prompts and point prompts. +In this bundle, the training is simplified by jointly training with class prompts and point prompts. Single-GPU: ``` @@ -125,15 +125,15 @@ multi-stage training processes. Changing `patch_size` to a smaller value such as `"patch_size": [96, 96, 96]` used in `configs/train.json` would reduce the training memory footprint. -In `train_continual.json`, only subset of training and validation data are used, change `n_train_samples` and `n_val_samples` to use full dataset. +In `train_continual.json`, only subset of training and validation data are used, change `n_train_samples` and `n_val_samples` to use full dataset. In `train.json`, `validate[evaluator][val_head]` can be `auto` and `point`. If `auto`, the validation results will be automatic segmentation. If `point`, -the validation results will be sampling one positive point per object per patch. The validation scheme of combining auto and point is deprecated due to -speed issue. +the validation results will be sampling one positive point per object per patch. The validation scheme of combining auto and point is deprecated due to +speed issue. Note: `valid_remap` is a transform that maps the groundtruth label indexes, e.g. [0,2,3,5,6] to sequential and continuous labels [0,1,2,3,4]. This is required by monai dice calculation. It is not related to mapping label index to VISTA3D defined global class index. The validation data is not mapped -to the VISTA3D global class index. +to the VISTA3D global class index. #### Execute evaluation `n_train_samples` and `n_val_samples` are used to specify the number of samples to use for training and validation respectively. From a9e030c94a740a115828302d5142b02c4669537b Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Thu, 29 Aug 2024 10:38:03 -0400 Subject: [PATCH 04/15] Fix comments Signed-off-by: heyufan1995 --- models/vista3d/configs/inference.json | 3 --- models/vista3d/configs/train.json | 4 ---- models/vista3d/docs/README.md | 4 ++-- models/vista3d/scripts/inferer.py | 6 +----- models/vista3d/scripts/trainer.py | 3 +-- 5 files changed, 4 insertions(+), 16 deletions(-) diff --git a/models/vista3d/configs/inference.json b/models/vista3d/configs/inference.json index 6e654a78..ada91042 100644 --- a/models/vista3d/configs/inference.json +++ b/models/vista3d/configs/inference.json @@ -48,7 +48,6 @@ 128 ], "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", - "use_cfp": false, "use_point_window": true, "network_def": "$monai.networks.nets.vista3d132(in_channels=@input_channels)", "network": "$@network_def.to(@device)", @@ -128,7 +127,6 @@ "roi_size": "@patch_size", "overlap": 0.5, "sw_batch_size": "@sw_batch_size", - "use_cfp": "@use_cfp", "use_point_window": "@use_point_window" }, "postprocessing": { @@ -193,7 +191,6 @@ "val_handlers": "@handlers", "amp": true, "hyper_kwargs": { - "use_cfp": "@use_cfp", "user_prompt": true, "everything_labels": "@everything_labels" } diff --git a/models/vista3d/configs/train.json b/models/vista3d/configs/train.json index 9b80f843..aca8cb9a 100644 --- a/models/vista3d/configs/train.json +++ b/models/vista3d/configs/train.json @@ -31,7 +31,6 @@ "drop_label_prob": 0.25, "drop_point_prob": 0.25, "exclude_background": true, - "use_cfp": true, "label_set": null, "val_label_set": "@label_set", "amp": true, @@ -277,7 +276,6 @@ "drop_label_prob": "@drop_label_prob", "drop_point_prob": "@drop_point_prob", "exclude_background": "@exclude_background", - "use_cfp": "@use_cfp", "label_set": "@label_set", "patch_size": "@patch_size", "user_prompt": false @@ -315,7 +313,6 @@ "_target_": "scripts.inferer.Vista3dInferer", "roi_size": "@patch_size_valid", "overlap": "@sw_overlap", - "use_cfp": "@use_cfp" }, "handlers": [ { @@ -377,7 +374,6 @@ "drop_label_prob": "@drop_label_prob", "drop_point_prob": "@drop_point_prob", "exclude_background": "@exclude_background", - "use_cfp": "@use_cfp", "label_set": "@label_set", "val_head": "auto", "user_prompt": false diff --git a/models/vista3d/docs/README.md b/models/vista3d/docs/README.md index fd600ae2..14f2976b 100644 --- a/models/vista3d/docs/README.md +++ b/models/vista3d/docs/README.md @@ -102,8 +102,8 @@ The default configs for both variables are derived from the `label_mappings` con "label_set": "$[0] + list(x[1] for x in @label_mappings#default)" "val_label_set": "$[0] + list(x[0] for x in @label_mappings#default)" ``` -`drop_label_prob` and `drop_point_prob` means percentage to remove class prompts and point prompts respectively. If `drop_point_prob`=1, the -model is only finetuning for automatic segmentation, while `drop_label_prob`=1 means only finetuning for interactive segmentation. The VISTA3D foundation +`drop_label_prob` and `drop_point_prob` means percentage to remove class prompts and point prompts respectively. If `drop_point_prob=1`, the +model is only finetuning for automatic segmentation, while `drop_label_prob=1` means only finetuning for interactive segmentation. The VISTA3D foundation model is trained with interactive only (drop_label_prob=1) and then froze the point branch and trained with fully automatic segmentation (`drop_point_prob=1`). In this bundle, the training is simplified by jointly training with class prompts and point prompts. diff --git a/models/vista3d/scripts/inferer.py b/models/vista3d/scripts/inferer.py index b7c9bc43..25f48525 100644 --- a/models/vista3d/scripts/inferer.py +++ b/models/vista3d/scripts/inferer.py @@ -25,14 +25,12 @@ class Vista3dInferer(Inferer): Args: roi_size: the sliding window patch size. overlap: sliding window overlap ratio. - use_cfp: use class prompt for point head. Deprecated. """ - def __init__(self, roi_size, overlap, use_cfp, use_point_window=False, sw_batch_size=1) -> None: + def __init__(self, roi_size, overlap, use_point_window=False, sw_batch_size=1) -> None: Inferer.__init__(self) self.roi_size = roi_size self.overlap = overlap - self.use_cfp = use_cfp self.sw_batch_size = sw_batch_size self.use_point_window = use_point_window self.sliding_window_inferer = point_based_window_inferer if use_point_window else sliding_window_inference @@ -104,7 +102,6 @@ def __call__( prev_mask=prev_mask, labels=labels, label_set=label_set, - use_cfp=self.use_cfp, ) except Exception: val_outputs = None @@ -127,6 +124,5 @@ def __call__( prev_mask=prev_mask, labels=labels, label_set=label_set, - use_cfp=self.use_cfp, ) return val_outputs diff --git a/models/vista3d/scripts/trainer.py b/models/vista3d/scripts/trainer.py index 7a559afc..e96daf9c 100644 --- a/models/vista3d/scripts/trainer.py +++ b/models/vista3d/scripts/trainer.py @@ -182,8 +182,7 @@ def _compute_pred_loss(): input_images=inputs, point_coords=point, point_labels=point_label, - class_vector=label_prompt, - use_cfp=engine.hyper_kwargs["use_cfp"], + class_vector=label_prompt ) # engine.state.output[Keys.PRED] = outputs engine.fire_event(IterationEvents.FORWARD_COMPLETED) From 9d13d20372ff8f3553b456e39d611c130affdb21 Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Mon, 16 Sep 2024 16:15:51 -0400 Subject: [PATCH 05/15] Small fix Signed-off-by: heyufan1995 --- models/vista3d/scripts/trainer.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/models/vista3d/scripts/trainer.py b/models/vista3d/scripts/trainer.py index c69f631b..701c6894 100644 --- a/models/vista3d/scripts/trainer.py +++ b/models/vista3d/scripts/trainer.py @@ -179,14 +179,7 @@ def _iteration(self, engine, batchdata: dict[str, torch.Tensor]): def _compute_pred_loss(): outputs = engine.network( -<<<<<<< HEAD - input_images=inputs, - point_coords=point, - point_labels=point_label, - class_vector=label_prompt -======= input_images=inputs, point_coords=point, point_labels=point_label, class_vector=label_prompt ->>>>>>> 6bdfd30b63b3d1a799c80f1d17a783ff3a66c66c ) # engine.state.output[Keys.PRED] = outputs engine.fire_event(IterationEvents.FORWARD_COMPLETED) From 25d10fca84e0033a572ef8f0ff03421f6e807497 Mon Sep 17 00:00:00 2001 From: heyufan1995 Date: Wed, 18 Sep 2024 12:53:29 -0400 Subject: [PATCH 06/15] Resolve oom and update readme Signed-off-by: heyufan1995 --- models/vista3d/configs/train_continual.json | 2 +- models/vista3d/docs/README.md | 306 ++++++++++---------- models/vista3d/docs/labels.json | 5 +- models/vista3d/scripts/inferer.py | 43 +-- 4 files changed, 177 insertions(+), 179 deletions(-) diff --git a/models/vista3d/configs/train_continual.json b/models/vista3d/configs/train_continual.json index b700a178..ab60063a 100644 --- a/models/vista3d/configs/train_continual.json +++ b/models/vista3d/configs/train_continual.json @@ -32,7 +32,7 @@ "num_classes": 255, "output_classes": "$len(@label_set)", "optimizer": { - "_target_": "Novograd", + "_target_": "torch.optim.AdamW", "lr": "@learning_rate", "params": "$@network.parameters()" }, diff --git a/models/vista3d/docs/README.md b/models/vista3d/docs/README.md index 0f1481ec..798faef8 100644 --- a/models/vista3d/docs/README.md +++ b/models/vista3d/docs/README.md @@ -1,22 +1,17 @@ # Model Overview -Vista3D model train/inference pipeline +Vista3D model fintuning/evaluation/inference pipeline. VISTA3D is trained using over 20 partial datasets with more complicated pipeline. To avoid confusion, we will only provide finetuning/continual learning APIs for users to finetune on their +own datasets. -## Training configuration -The training was performed with the following: -- GPU: at least 16GB GPU memory -- Actual Model Input: 128 x 128 x 128 -- AMP: True -- Optimizer: Adam -- Learning Rate: 1e-2 -- Loss: BCE loss and L1 loss +## Continual learning -## Data -Note that VISTA3D is trained from a huge collection of datasets and cannot be simply reproduced in this bundle. +For continual learning, user can change `configs/train_continual.json`. More advanced users can change configurations in `configs/train.json`. The hyperparameters in `configs/train_continual.json` will overwrite ones in `configs/train.json`. Most hyperparameters are straighforward and user can tell based on their names. We list hyperparameters that needs to be modified. -The spleen Task from the Medical Segmentation Decathalon is selected as an example to show how to do train, continuous learning and evaluate. Users can find more details on the datasets at http://medicaldecathlon.com/. +### Data -To train with other datasets, users need to provide a json data split for training and continuous learning (`configs/msd_task09_spleen_folds.json` is an example for reference). The data split should meet the following format with a 5-fold split ('testing' labels are optional): -``` +The spleen Task from the Medical Segmentation Decathalon is selected as an example to show how to continuous learning. Users can find more details on the datasets at http://medicaldecathlon.com/. + +To train with other datasets, users need to provide a json data split for training and continuous learning (`configs/msd_task09_spleen_folds.json` is an example for reference). The data split should meet the following format ('testing' labels are optional): +```json { "training": [ {"image": "img0001.nii.gz", "label": "label0001.nii.gz", "fold": 0}, @@ -31,135 +26,112 @@ To train with other datasets, users need to provide a json data split for traini } ``` -### Input -1 channel -- List of 3D CT patches - -### Output -In Training Mode: Training loss - -In Evaluation Mode: Segmentation - -## Performance - -#### TensorRT speedup -The `vista3d` bundle supports acceleration with TensorRT. The table below displays the speedup ratios observed on an A100 80G GPU. Please note for 32bit precision models, they are benchmarked with tf32 weight format. - -| method | torch_tf32(ms) | torch_amp(ms) | trt_tf32(ms) | trt_fp16(ms) | speedup amp | speedup tf32 | speedup fp16 | amp vs fp16| -| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | -| model computation | 108.53| 91.9 | 106.84 | 60.02 | 1.18 | 1.02 | 1.81 | 1.53 | -| end2end | 6740 | 5166 | 5242 | 3386 | 1.30 | 1.29 | 1.99 | 1.53 | - -Where: -- `model computation` means the speedup ratio of model's inference with a random input without preprocessing and postprocessing -- `end2end` means run the bundle end-to-end with the TensorRT based model. -- `torch_tf32` and `torch_amp` are for the PyTorch models with or without `amp` mode. -- `trt_tf32` and `trt_fp16` are for the TensorRT based models converted in corresponding precision. -- `speedup amp`, `speedup tf32` and `speedup fp16` are the speedup ratios of corresponding models versus the PyTorch float32 model -- `amp vs fp16` is the speedup ratio between the PyTorch amp model and the TensorRT float16 based model. - -This result is benchmarked under: - - TensorRT: 10.3.0+cuda12.6 - - Torch-TensorRT Version: 2.4.0 - - CPU Architecture: x86-64 - - OS: ubuntu 20.04 - - Python version:3.10.12 - - CUDA version: 12.6 - - GPU models and configuration: A100 80G - -## MONAI Bundle Commands -In addition to the Pythonic APIs, a few command line interfaces (CLI) are provided to interact with the bundle. The CLI supports flexible use cases, such as overriding configs at runtime and predefining arguments in a file. - -For more details usage instructions, visit the [MONAI Bundle Configuration Page](https://docs.monai.io/en/latest/config_syntax.html). - - -#### Execute training: - -``` -python -m monai.bundle run --config_file configs/train.json ``` +Note the data is not the absolute path to the image and label file. The actual image file will be `os.path.join(dataset_dir, data["training"][item]["image"])`, where `dataset_dir` is defined in `configs/train_continual.json`. Also 5-fold cross-validation is not required! `fold=0` is defined in train.json, which means any data item with fold==0 will be used as validation and other fold will be used for training. So if you only have 2 data, you can manually set one data to be validation by setting "fold": 0 in its datalist and the other to be training by setting "fold" to any number other than 0. +``` + +### Best practice to generate data list +User can use monai to generate the 5-fold data lists. Full exampls can be found in VISTA3D open source [codebase](https://github.com/Project-MONAI/VISTA/blob/main/vista3d/data/make_datalists.py) +```python +from monai.data.utils import partition_dataset +from monai.bundle import ConfigParser +base_url = "/path_to_your_folder/" +json_name = "./your_5_folds.json" +# create matching image and label lists. +# The code to generate the lists is based on your local data structure. +# You can use glob.glob("**.nii.gz") e.t.c. +image_list = ['images/1.nii.gz', 'images/2.nii.gz', ...] +label_list = ['labels/1.nii.gz', 'labels/2.nii.gz', ...] +items = [{"image": img, "label": lab} for img, lab in zip(image_list, label_list)] +# 80% for training 20% for testing. +train_test = partition_dataset(items, ratios=[0.8, 0.2], shuffle=True, seed=0) +print(f"training: {len(train_test[0])}, testing: {len(train_test[1])}") +# num_partitions-fold split for the training set. +train_val = partition_dataset(train_test[0], num_partitions=5, shuffle=True, seed=0) +print(f"training validation folds sizes: {[len(x) for x in train_val]}") +# add the fold index to each training data. +training = [] +for f, x in enumerate(train_val): + for item in x: + item["fold"] = f + training.append(item) +# save json file +parser = ConfigParser({}) +parser["training"] = training +parser["testing"] = train_test[1] +print(f"writing {json_name}\n\n") +if os.path.exists(json_name): + logger.warning(f"rewrite existing datalist file: {json_name}") +ConfigParser.export_config_file(parser.config, json_name, indent=4) +``` + +### Configurations + +#### `label_mappings` +The core concept of label_mapping is to convert ground-truth label index of each dataset to a unified class index. For example, "Spleen" in MSD09 groundtruth will be represented by 1, while in AbdomenCT-1K it's 3. We unified a global label index [`label_dict`](./labels.json) to represent all 132 classes, and create a label mapping to map those local index to this global index. So when a user is training on their own dataset, we need to know this mapping. -Please note that if the default dataset path is not modified with the actual path in the bundle config files, you can also override it by using `--dataset_dir`: +The current label mapping `[[1, 3]]` indicates that training labels' class indices `1` is mapped +to the VISTA model's class `3` (format `[[src_class_0, dst_class_0], [src_class_1, dst_class_1], ...]`). So during inference, "3" is used to segment spleen. -``` -python -m monai.bundle run --config_file configs/train.json --dataset_dir -``` +Since it's finetuning, you can map your local class to any global class. If you use [[1,4]], where "4" represents pancreas, the finetuning can still work but requires more training data and epoch because the class "4" is already assigned and trained with pancreas. If you use [[1,3]], where "3" already represents spleen, the finetuning will converge much faster. -#### Execute finetune: +#### Best practice to set label_mapping +For a class that represent the same or similar class as the global index, directly map it to the global index. For example, "mouse left lung" (e.g. index 2 in the mouse dataset) can be mapped to the 28 "left lung upper lobe"(or 29 "left lung lower lobe") with [[2,28]]. After finetuning, 28 now represents "mouse left lung" and will be used for segmentation. If you want to segment 4 substructures of aorta, you can map one of the substructuress to 6 aorta and the rest to any unused classes (class > 132), [[1,6],[2,133],[3,134],[4,135]]. For a completely novel class that none of the VISTA global classes are related, directly map to unused classes (class > 132). ``` -python -m monai.bundle run --config_file configs/train.json --finetune True --epochs 5 +NOTE: Do not map to global index value >= 255. `num_classes=255` in the config only represent the maximum mapping index, while the actual output class number only depends on your label_mapping definition. The 255 value in the inference output is also used to represent 'NaN' value. ``` +#### `n_train_samples` and `n_val_samples` +In `train_continual.json`, only `n_train_samples` and `n_val_samples` are used for training and validation. Remember to change these two values. -Please note that the path of model weights is "/models/model.pt", you can also override it by using `--finetune_model_path`: +#### `patch_size` +The patch size parameter is defined in `configs/train_continual.json`: `"patch_size": [128, 128, 128]`. For finetuning purposes, this value needs to be changed acccording to user's task and GPU memory. Usually a larger patch_size will give better final results. -``` -python -m monai.bundle run --config_file configs/train.json --finetune True --finetune_model_path -``` - -#### Enable early stop in training: - -``` -python -m monai.bundle run --config_file configs/train.json --early_stop True -``` - -#### Override the `train` config to execute multi-GPU training: +#### `resample_to_spacing` +The resample_to_spacing parameter is defined in `configs/train_continual.json` and it represents the resolution the model will be trained on. The `1.5,1.5,1.5` mm default is suitable for large CT organs, but for other tasks, this value should be changed to achive the optimal performance. +#### Advanced user: `drop_label_prob` and `drop_point_prob` (in train.json) +VISTA3D is trained to perform both automatic (class prompts) and interactive point segmentation. +`drop_label_prob` and `drop_point_prob` means percentage to remove class prompts and point prompts during training respectively. If `drop_point_prob=1`, the +model is only finetuning for automatic segmentation, while `drop_label_prob=1` means only finetuning for interactive segmentation. The VISTA3D foundation +model is trained with interactive only (drop_label_prob=1) and then froze the point branch and trained with fully automatic segmentation (`drop_point_prob=1`). +In this bundle, the training is simplified by jointly training with class prompts and point prompts and both of the drop ratio is set to 0.25. ``` -torchrun --standalone --nnodes=1 --nproc_per_node=2 -m monai.bundle run --config_file "['configs/train.json','configs/multi_gpu_train.json']" +NOTE: If user doesn't use interactive segmentation, set `drop_point_prob=1` and `drop_label_prob=0` in train.json might provide a faster and easier finetuning process. ``` +#### Other explanatory items +In `train.json`, `validate[evaluator][val_head]` can be `auto` and `point`. If `auto`, the validation results will be automatic segmentation. If `point`, +the validation results will be sampling one positive point per object per patch. The validation scheme of combining auto and point is deprecated due to +speed issue. - -#### Execute continual learning -When finetuning with new class names, please update `configs/train_continual.json`'s `label_mappings` accordingly. - -The current label mapping `[[1, 3]]` indicates that training labels' class indices `1` is mapped -to the VISTA model's class `3` (format `[[src_class_0, dst_class_0], [src_class_1, dst_class_1], ...]`). For new classes, user -can map to any value larger than 132. +In `train_continual.json`, `valid_remap` is a transform that maps the groundtruth label indexes, e.g. [0,2,3,5,6] to sequential and continuous labels [0,1,2,3,4]. This is +required by monai dice calculation. It is not related to mapping label index to VISTA3D defined global class index. The validation data is not mapped +to the VISTA3D global class index. `label_set` is used to identify the VISTA model classes for providing training prompts. `val_label_set` is used to identify the original training label classes for computing foreground/background mask during validation. - The default configs for both variables are derived from the `label_mappings` config and include `[0]`: ``` "label_set": "$[0] + list(x[1] for x in @label_mappings#default)" "val_label_set": "$[0] + list(x[0] for x in @label_mappings#default)" ``` -`drop_label_prob` and `drop_point_prob` means percentage to remove class prompts and point prompts respectively. If `drop_point_prob=1`, the -model is only finetuning for automatic segmentation, while `drop_label_prob=1` means only finetuning for interactive segmentation. The VISTA3D foundation -model is trained with interactive only (drop_label_prob=1) and then froze the point branch and trained with fully automatic segmentation (`drop_point_prob=1`). -In this bundle, the training is simplified by jointly training with class prompts and point prompts. + +### Commands Single-GPU: -``` +```bash python -m monai.bundle run \ - --config_file="['configs/train.json','configs/train_continual.json']" --epochs=320 --learning_rate=0.005 + --config_file="['configs/train.json','configs/train_continual.json']" --epochs=320 --learning_rate=0.00005 ``` Multi-GPU: -``` +```bash torchrun --nnodes=1 --nproc_per_node=8 -m monai.bundle run \ - --config_file="['configs/train.json','configs/train_continual.json','configs/multi_gpu_train.json']" --epochs=320 --learning_rate=0.005 + --config_file="['configs/train.json','configs/train_continual.json','configs/multi_gpu_train.json']" --epochs=320 --learning_rate=0.00005 ``` -The patch size parameter is defined in `configs/train_continual.json`: `"patch_size": [128, 128, 128]`, and this works for the use cases -of extending the current model to segment a few novel classes. Finetuning all supported classes may require large GPU memory and carefully designed -multi-stage training processes. - -Changing `patch_size` to a smaller value such as `"patch_size": [96, 96, 96]` used in `configs/train.json` would reduce the training memory footprint. - -In `train_continual.json`, only subset of training and validation data are used, change `n_train_samples` and `n_val_samples` to use full dataset. - -In `train.json`, `validate[evaluator][val_head]` can be `auto` and `point`. If `auto`, the validation results will be automatic segmentation. If `point`, -the validation results will be sampling one positive point per object per patch. The validation scheme of combining auto and point is deprecated due to -speed issue. - -Note: `valid_remap` is a transform that maps the groundtruth label indexes, e.g. [0,2,3,5,6] to sequential and continuous labels [0,1,2,3,4]. This is -required by monai dice calculation. It is not related to mapping label index to VISTA3D defined global class index. The validation data is not mapped -to the VISTA3D global class index. -#### Execute evaluation -`n_train_samples` and `n_val_samples` are used to specify the number of samples to use for training and validation respectively. +## Evaluation `configs/data.yaml` shows potential configurations for each specific dataset for evaluation. @@ -176,28 +148,63 @@ torchrun --nnodes=1 --nproc_per_node=8 -m monai.bundle run \ ``` -#### Execute inference: -Notice the VISTA3d bundle requires at least one prompt for segmentation. It supports label prompt, which is the index of the class for automatic segmentation. -It also supports point click prompts for binary segmentation. User can provide both prompts at the same time. To segment an image, set the input_dict to -: +## Inference: +For inference, VISTA3d bundle requires at least one prompt for segmentation. It supports label prompt, which is the index of the class for automatic segmentation. +It also supports point click prompts for binary interactive segmentation. User can provide both prompts at the same time. + +All the configurations for inference is stored in inference.json, change those parameters: +### `input_dict` +`input_dict` defines the image to segment and the prompt for segmentation. ``` "input_dict": "$[{'image': '/data/Task09_Spleen/imagesTs/spleen_15.nii.gz', 'label_prompt':[1]}]", "input_dict": "$[{'image': '/data/Task09_Spleen/imagesTs/spleen_15.nii.gz', 'points':[[138,245,18], [271,343,27]], 'point_labels':[1,0]}]" ``` -- The input_dict must contain the absolute path to the nii image file, and must contain at least one prompt. The keys are "label_prompt", "points" and "point_labels". -- label_prompt is in the format of [B], points is [1, N, 3], point_labels is [1, N]. B is number of foreground object. **B must be 1 if label_prompt and points are provided together** -- N is number of click points, 3 is x,y,z coordinates **IN THE ORIGINAL IMAGE SPACE**. The inferer only supports SINGLE OBJECT point click segmentatation. -- point_labels 0 means background, 1 means foreground, -1 means ignoring this point. -- label_prompt and points key can be missing, but cannot be missing at the same time. -- points and point_labels must pe provided together. -- The label_prompt can perform multiple foreground object segmentation, e.g. [2,3,4,5] means segment those classes. Point prompts must NOT be provided. -- For segment everything, use label_prompt: list(set([i+1 for i in range(132)]) - set([22, 23, 15, 25, 19, 2, 26, 27, 28, 29, 117])) -- The point prompts for "Kidney", "Lung", "Bone" (class index [2, 20, 21]) are not allowed since those prompts will be divided into sub-categories (e.g. left kidney and right kidney). Use point prompts for the sub-categories as defined in the inference.json. +- The input_dict must include the key `image` which contain the absolute path to the nii image file, and includes prompt keys of `label_prompt`, `points` and `point_labels`. +- The `label_prompt` is a list of length `B`, which can perform `B` foreground objects segmentation, e.g. `[2,3,4,5]`. If `B>1`, Point prompts must NOT be provided. +- The `points` is of shape `[N, 3]` like `[[x1,y1,z1],[x2,y2,z2],...[xN,yN,zN]]`, representing `N` point coordinates **IN THE ORIGINAL IMAGE SPACE** of a single foreground object. `point_labels` is a list of length [N] like [1,1,0,-1,...], which +matches the `points`. 0 means background, 1 means foreground, -1 means ignoring this point. `points` and `point_labels` must pe provided together and match length. +- **B must be 1 if label_prompt and points are provided together**. The inferer only supports SINGLE OBJECT point click segmentatation. +- If no prompt is provided, the model will use `everything_labels` to segment 118 classes: list(set([i+1 for i in range(132)]) - set([2,16,18,20,21,23,24,25,26,27,128,129,130,131,132])). +- The `points` together with `label_prompts` for "Kidney", "Lung", "Bone" (class index [2, 20, 21]) are not allowed since those prompts will be divided into sub-categories (e.g. left kidney and right kidney). Use `points` for the sub-categories as defined in the inference.json. + +### `label_prompt` and `label_dict` +The `label_dict` defined in [`labels.json`](../docs/labels.json) has in total 132 classes. However, there are 5 we do not support and we keep them due to legacy issue. So in total +VISTA3D support 127 classes. +``` +"16, # prostate or uterus" since we already have "prostate" class, +"18, # rectum", insufficient data or dataset excluded. +"130, # liver tumor" already have hepatic tumor. +"129, # kidney mass" insufficient data or dataset excluded. +"131, # vertebrae L6", insufficient data or dataset excluded. +``` +These 5 are excluded in the `everything_labels`. Another 7 tumor and vessel classes are also removed since they will overlap with other organs and make the output messy. To segment those 7 classes, we recommend users to directly set `label_prompt` to those indexes and avoid using them in `everything_labels`. For "Kidney", "Lung", "Bone" (class index [2, 20, 21]), VISTA3D did not directly use the class index for segmentation, but instead convert them to their subclass indexes as defined by `subclass` dict. For example, "2-Kidney" is converted to "14-Left Kidney" + "5-Right Kidney" since "2" is defined in `subclasss` dict. + + +``` +Note: if the finetuning mapped the local user data index to global index "2, 20, 21", remove the `subclass` dict from inference.json since those values defined in `subclass` will trigger the wrong subclass segmentation. +``` + +### `resample_spacing` +The optimal inference resample spacing should be changed according to the task. For monkey data, a high resolution of [1,1,1] showed better automatic inference results. This spacing applies to both automatic and interactive segmentation. For zero-shot interactive segmentation for non-human CTs e.g. mouse CT or even rock/stone CT, using original resolution (set `resample_spacing` to [-1,-1,-1]) may give better interactive results. + +### `use_point_window` +When user click a point, there is no need to perform whole image sliding window inference. Set "use_point_window" to true in the inference.json to enable this function. +A window centered at the clicked points will be used for inference. All values outside of the window will set to be "NaN" unless "prev_mask" is passed to the inferer (255 is used to represent NaN). +If no point click exists, this function will not be used. Notice if "use_point_window" is true and user provided point clicks, there will be obvious cut-off box artefacts. + +### Inference GPU benchmarks +Benchmarks on a 16GB V100 GPU with 400G system cpu memory. +| Volume size at 1.5x1.5x1.5 mm | 333x333x603 | 512x512x512 | 512x512x768 | 1024x1024x512 | 1024x1024x768 | +| :---: | :---: | :---: | :---: | :---: | :---: | +|RunTime| 1m07s | 2m09s | 3m25s| 9m20s| killed | +## Commands +The bundle only provides single-gpu inference. +### Single image inference ``` python -m monai.bundle run --config_file configs/inference.json ``` -#### Execute batch inference for segmenting everything +### Batch inference for segmenting everything ``` python -m monai.bundle run --config_file="['configs/inference.json', 'configs/batch_inference.json']" --input_dir="/data/Task09_Spleen/imagesTr" --output_dir="./eval_task09" ``` @@ -205,40 +212,49 @@ python -m monai.bundle run --config_file="['configs/inference.json', 'configs/ba `configs/batch_inference.json` by default runs the segment everything workflow (classes defined by `everything_labels`) on all (`*.nii.gz`) files in `input_dir`. This default is overridable by changing the input folder `input_dir`, or the input image name suffix `input_suffix`, or directly setting the list of filenames `input_list`. -Set `"postprocessing#transforms#0#_disabled_": false` to move the postprocessing to cpu to reduce the GPU memory footprint. -#### Execute inference with the TensorRT model: +### Execute inference with the TensorRT model: ``` python -m monai.bundle run --config_file "['configs/inference.json', 'configs/inference_trt.json']" ``` +### TroubleShoot for Out-of-Memory +- Changing `patch_size` to a smaller value such as `"patch_size": [96, 96, 96]` would reduce the training/inference memory footprint. +- Changing `train_dataset_cache_rate` and `val_dataset_cache_rate` to a smaller value like `0.1` can solve the out-of-cpu memory issue when using huge finetuning dataset. +- Set `"postprocessing#transforms#0#_disabled_": false` to move the postprocessing to cpu to reduce the GPU memory footprint. -## Automatic segmentation label prompts : -The mapping between organ name and label prompt is in the [json file](labels.json) -## Fast Point Window Inference: -When user click a point, there is no need to perform whole image sliding window inference. Set "use_point_window" to true in the inference.json to enable this function. -A window centered at the clicked points will be used for inference. All values outside of the window will set to be "NaN" unless "prev_mask" is passed to the inferer. -If no point click exists, this function will not be used. Notice if "use_point_window" is true and user provided point clicks, there will be obvious cut-off box artefacts. +### TensorRT speedup +The `vista3d` bundle supports acceleration with TensorRT. The table below displays the speedup ratios observed on an A100 80G GPU. Please note for 32bit precision models, they are benchmarked with tf32 weight format. -# References -- Roth, H., Farag, A., Turkbey, E. B., Lu, L., Liu, J., & Summers, R. M. (2016). Data From Pancreas-CT (Version 2) [Data set]. The Cancer Imaging Archive. https://doi.org/10.7937/K9/TCIA.2016.tNB1kqBU +| method | torch_tf32(ms) | torch_amp(ms) | trt_tf32(ms) | trt_fp16(ms) | speedup amp | speedup tf32 | speedup fp16 | amp vs fp16| +| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | +| model computation | 108.53| 91.9 | 106.84 | 60.02 | 1.18 | 1.02 | 1.81 | 1.53 | +| end2end | 6740 | 5166 | 5242 | 3386 | 1.30 | 1.29 | 1.99 | 1.53 | -- J. Ma et al., "AbdomenCT-1K: Is Abdominal Organ Segmentation a Solved Problem?," in IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. 44, no. 10, pp. 6695-6714, 1 Oct. 2022, doi: 10.1109/TPAMI.2021.3100536. +Where: +- `model computation` means the speedup ratio of model's inference with a random input without preprocessing and postprocessing +- `end2end` means run the bundle end-to-end with the TensorRT based model. +- `torch_tf32` and `torch_amp` are for the PyTorch models with or without `amp` mode. +- `trt_tf32` and `trt_fp16` are for the TensorRT based models converted in corresponding precision. +- `speedup amp`, `speedup tf32` and `speedup fp16` are the speedup ratios of corresponding models versus the PyTorch float32 model +- `amp vs fp16` is the speedup ratio between the PyTorch amp model and the TensorRT float16 based model. -- JI YUANFENG. (2022). Amos: A large-scale abdominal multi-organ benchmark for versatile medical image segmentation [Data set]. Zenodo. https://doi.org/10.5281/zenodo.7155725 +This result is benchmarked under: + - TensorRT: 10.3.0+cuda12.6 + - Torch-TensorRT Version: 2.4.0 + - CPU Architecture: x86-64 + - OS: ubuntu 20.04 + - Python version:3.10.12 + - CUDA version: 12.6 + - GPU models and configuration: A100 80G +# References - Antonelli, M., Reinke, A., Bakas, S. et al. The Medical Segmentation Decathlon. Nat Commun 13, 4128 (2022). https://doi.org/10.1038/s41467-022-30695-9 -- Rister, B., Yi, D., Shivakumar, K. et al. CT-ORG, a new dataset for multiple organ segmentation in computed tomography. Sci Data 7, 381 (2020). https://doi.org/10.1038/s41597-020-00715-8 - -- Jakob Wasserthal. (2022). Dataset with segmentations of 104 important anatomical structures in 1204 CT images (1.0) [Data set]. Zenodo. https://doi.org/10.5281/zenodo.6802614 - -- Gibson, E., Giganti, F., Hu, Y., Bonmati, E., Bandula, S., Gurusamy, K., Davidson, B., Pereira, S. P., Clarkson, M. J., & Barratt, D. C. (2018). Multi-organ Abdominal CT Reference Standard Segmentations (1.0) [Data set]. Zenodo. https://doi.org/10.5281/zenodo.1169361 - -- Multi-Atlas Labeling Beyond the Cranial Vault - Workshop and Challenge https://www.synapse.org/#!Synapse:syn3193805/wiki/217753 +- VISTA3D: Versatile Imaging SegmenTation and Annotation model for 3D Computed Tomography. arxiv (2024) https://arxiv.org/abs/2406.05285 # License diff --git a/models/vista3d/docs/labels.json b/models/vista3d/docs/labels.json index dcdd73dd..51f93162 100644 --- a/models/vista3d/docs/labels.json +++ b/models/vista3d/docs/labels.json @@ -130,8 +130,5 @@ "kidney mass": 129, "liver tumor": 130, "vertebrae L6": 131, - "airway": 132, - "FDG-avid lesion": 133, - "lung nodule": 134, - "lumbar spine": 135 + "airway": 132 } diff --git a/models/vista3d/scripts/inferer.py b/models/vista3d/scripts/inferer.py index 25f48525..3c2d99a9 100644 --- a/models/vista3d/scripts/inferer.py +++ b/models/vista3d/scripts/inferer.py @@ -14,10 +14,9 @@ import torch from monai.apps.vista3d.inferer import point_based_window_inferer -from monai.inferers import Inferer, sliding_window_inference +from monai.inferers import Inferer, SlidingWindowInfererAdapt from torch import Tensor - class Vista3dInferer(Inferer): """ Vista3D Inferer @@ -33,7 +32,6 @@ def __init__(self, roi_size, overlap, use_point_window=False, sw_batch_size=1) - self.overlap = overlap self.sw_batch_size = sw_batch_size self.use_point_window = use_point_window - self.sliding_window_inferer = point_based_window_inferer if use_point_window else sliding_window_inference def __call__( self, @@ -62,11 +60,6 @@ def __call__( prev_mask: [1, B, H, W, D], THE VALUE IS BEFORE SIGMOID! """ - sliding_window_inferer = ( - point_based_window_inferer - if (self.use_point_window and point_coords is not None) - else sliding_window_inference - ) prompt_class = copy.deepcopy(class_vector) if class_vector is not None: # Check if network has attribute 'point_head' directly or within its 'module' @@ -79,12 +72,14 @@ def __call__( if torch.any(class_vector > point_head.last_supported): class_vector = None - if isinstance(inputs, list): - device = inputs[0].device - else: - device = inputs.device - try: - val_outputs = sliding_window_inferer( + val_outputs = None + torch.cuda.empty_cache() + if self.use_point_window and point_coords is not None: + if isinstance(inputs, list): + device = inputs[0].device + else: + device = inputs.device + val_outputs = point_based_window_inferer( inputs=inputs, roi_size=self.roi_size, sw_batch_size=self.sw_batch_size, @@ -102,27 +97,17 @@ def __call__( prev_mask=prev_mask, labels=labels, label_set=label_set, - ) - except Exception: - val_outputs = None - torch.cuda.empty_cache() - val_outputs = sliding_window_inferer( - inputs=inputs, + ) + else: + val_outputs = SlidingWindowInfererAdapt( roi_size=self.roi_size, sw_batch_size=self.sw_batch_size, - transpose=True, - with_coord=True, - predictor=network, - mode="gaussian", - sw_device=device, - device="cpu", - overlap=self.overlap, + with_coord=True)(inputs,network,transpose=True, point_coords=point_coords, point_labels=point_labels, class_vector=class_vector, prompt_class=prompt_class, prev_mask=prev_mask, labels=labels, - label_set=label_set, - ) + label_set=label_set) return val_outputs From d8fe7588621bcb1698a1257bc5ffb2a30e6e1a8b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Sep 2024 16:56:08 +0000 Subject: [PATCH 07/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- models/vista3d/docs/README.md | 28 ++++++++++++++-------------- models/vista3d/scripts/inferer.py | 4 ++-- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/models/vista3d/docs/README.md b/models/vista3d/docs/README.md index 798faef8..6ba50472 100644 --- a/models/vista3d/docs/README.md +++ b/models/vista3d/docs/README.md @@ -4,7 +4,7 @@ own datasets. ## Continual learning -For continual learning, user can change `configs/train_continual.json`. More advanced users can change configurations in `configs/train.json`. The hyperparameters in `configs/train_continual.json` will overwrite ones in `configs/train.json`. Most hyperparameters are straighforward and user can tell based on their names. We list hyperparameters that needs to be modified. +For continual learning, user can change `configs/train_continual.json`. More advanced users can change configurations in `configs/train.json`. The hyperparameters in `configs/train_continual.json` will overwrite ones in `configs/train.json`. Most hyperparameters are straighforward and user can tell based on their names. We list hyperparameters that needs to be modified. ### Data @@ -37,8 +37,8 @@ from monai.data.utils import partition_dataset from monai.bundle import ConfigParser base_url = "/path_to_your_folder/" json_name = "./your_5_folds.json" -# create matching image and label lists. -# The code to generate the lists is based on your local data structure. +# create matching image and label lists. +# The code to generate the lists is based on your local data structure. # You can use glob.glob("**.nii.gz") e.t.c. image_list = ['images/1.nii.gz', 'images/2.nii.gz', ...] label_list = ['labels/1.nii.gz', 'labels/2.nii.gz', ...] @@ -71,13 +71,13 @@ ConfigParser.export_config_file(parser.config, json_name, indent=4) The core concept of label_mapping is to convert ground-truth label index of each dataset to a unified class index. For example, "Spleen" in MSD09 groundtruth will be represented by 1, while in AbdomenCT-1K it's 3. We unified a global label index [`label_dict`](./labels.json) to represent all 132 classes, and create a label mapping to map those local index to this global index. So when a user is training on their own dataset, we need to know this mapping. The current label mapping `[[1, 3]]` indicates that training labels' class indices `1` is mapped -to the VISTA model's class `3` (format `[[src_class_0, dst_class_0], [src_class_1, dst_class_1], ...]`). So during inference, "3" is used to segment spleen. +to the VISTA model's class `3` (format `[[src_class_0, dst_class_0], [src_class_1, dst_class_1], ...]`). So during inference, "3" is used to segment spleen. -Since it's finetuning, you can map your local class to any global class. If you use [[1,4]], where "4" represents pancreas, the finetuning can still work but requires more training data and epoch because the class "4" is already assigned and trained with pancreas. If you use [[1,3]], where "3" already represents spleen, the finetuning will converge much faster. +Since it's finetuning, you can map your local class to any global class. If you use [[1,4]], where "4" represents pancreas, the finetuning can still work but requires more training data and epoch because the class "4" is already assigned and trained with pancreas. If you use [[1,3]], where "3" already represents spleen, the finetuning will converge much faster. #### Best practice to set label_mapping -For a class that represent the same or similar class as the global index, directly map it to the global index. For example, "mouse left lung" (e.g. index 2 in the mouse dataset) can be mapped to the 28 "left lung upper lobe"(or 29 "left lung lower lobe") with [[2,28]]. After finetuning, 28 now represents "mouse left lung" and will be used for segmentation. If you want to segment 4 substructures of aorta, you can map one of the substructuress to 6 aorta and the rest to any unused classes (class > 132), [[1,6],[2,133],[3,134],[4,135]]. For a completely novel class that none of the VISTA global classes are related, directly map to unused classes (class > 132). +For a class that represent the same or similar class as the global index, directly map it to the global index. For example, "mouse left lung" (e.g. index 2 in the mouse dataset) can be mapped to the 28 "left lung upper lobe"(or 29 "left lung lower lobe") with [[2,28]]. After finetuning, 28 now represents "mouse left lung" and will be used for segmentation. If you want to segment 4 substructures of aorta, you can map one of the substructuress to 6 aorta and the rest to any unused classes (class > 132), [[1,6],[2,133],[3,134],[4,135]]. For a completely novel class that none of the VISTA global classes are related, directly map to unused classes (class > 132). ``` NOTE: Do not map to global index value >= 255. `num_classes=255` in the config only represent the maximum mapping index, while the actual output class number only depends on your label_mapping definition. The 255 value in the inference output is also used to represent 'NaN' value. ``` @@ -88,14 +88,14 @@ In `train_continual.json`, only `n_train_samples` and `n_val_samples` are used f The patch size parameter is defined in `configs/train_continual.json`: `"patch_size": [128, 128, 128]`. For finetuning purposes, this value needs to be changed acccording to user's task and GPU memory. Usually a larger patch_size will give better final results. #### `resample_to_spacing` -The resample_to_spacing parameter is defined in `configs/train_continual.json` and it represents the resolution the model will be trained on. The `1.5,1.5,1.5` mm default is suitable for large CT organs, but for other tasks, this value should be changed to achive the optimal performance. +The resample_to_spacing parameter is defined in `configs/train_continual.json` and it represents the resolution the model will be trained on. The `1.5,1.5,1.5` mm default is suitable for large CT organs, but for other tasks, this value should be changed to achive the optimal performance. #### Advanced user: `drop_label_prob` and `drop_point_prob` (in train.json) -VISTA3D is trained to perform both automatic (class prompts) and interactive point segmentation. +VISTA3D is trained to perform both automatic (class prompts) and interactive point segmentation. `drop_label_prob` and `drop_point_prob` means percentage to remove class prompts and point prompts during training respectively. If `drop_point_prob=1`, the model is only finetuning for automatic segmentation, while `drop_label_prob=1` means only finetuning for interactive segmentation. The VISTA3D foundation model is trained with interactive only (drop_label_prob=1) and then froze the point branch and trained with fully automatic segmentation (`drop_point_prob=1`). -In this bundle, the training is simplified by jointly training with class prompts and point prompts and both of the drop ratio is set to 0.25. +In this bundle, the training is simplified by jointly training with class prompts and point prompts and both of the drop ratio is set to 0.25. ``` NOTE: If user doesn't use interactive segmentation, set `drop_point_prob=1` and `drop_label_prob=0` in train.json might provide a faster and easier finetuning process. ``` @@ -150,7 +150,7 @@ torchrun --nnodes=1 --nproc_per_node=8 -m monai.bundle run \ ## Inference: For inference, VISTA3d bundle requires at least one prompt for segmentation. It supports label prompt, which is the index of the class for automatic segmentation. -It also supports point click prompts for binary interactive segmentation. User can provide both prompts at the same time. +It also supports point click prompts for binary interactive segmentation. User can provide both prompts at the same time. All the configurations for inference is stored in inference.json, change those parameters: ### `input_dict` @@ -161,7 +161,7 @@ All the configurations for inference is stored in inference.json, change those p ``` - The input_dict must include the key `image` which contain the absolute path to the nii image file, and includes prompt keys of `label_prompt`, `points` and `point_labels`. - The `label_prompt` is a list of length `B`, which can perform `B` foreground objects segmentation, e.g. `[2,3,4,5]`. If `B>1`, Point prompts must NOT be provided. -- The `points` is of shape `[N, 3]` like `[[x1,y1,z1],[x2,y2,z2],...[xN,yN,zN]]`, representing `N` point coordinates **IN THE ORIGINAL IMAGE SPACE** of a single foreground object. `point_labels` is a list of length [N] like [1,1,0,-1,...], which +- The `points` is of shape `[N, 3]` like `[[x1,y1,z1],[x2,y2,z2],...[xN,yN,zN]]`, representing `N` point coordinates **IN THE ORIGINAL IMAGE SPACE** of a single foreground object. `point_labels` is a list of length [N] like [1,1,0,-1,...], which matches the `points`. 0 means background, 1 means foreground, -1 means ignoring this point. `points` and `point_labels` must pe provided together and match length. - **B must be 1 if label_prompt and points are provided together**. The inferer only supports SINGLE OBJECT point click segmentatation. - If no prompt is provided, the model will use `everything_labels` to segment 118 classes: list(set([i+1 for i in range(132)]) - set([2,16,18,20,21,23,24,25,26,27,128,129,130,131,132])). @@ -177,7 +177,7 @@ VISTA3D support 127 classes. "129, # kidney mass" insufficient data or dataset excluded. "131, # vertebrae L6", insufficient data or dataset excluded. ``` -These 5 are excluded in the `everything_labels`. Another 7 tumor and vessel classes are also removed since they will overlap with other organs and make the output messy. To segment those 7 classes, we recommend users to directly set `label_prompt` to those indexes and avoid using them in `everything_labels`. For "Kidney", "Lung", "Bone" (class index [2, 20, 21]), VISTA3D did not directly use the class index for segmentation, but instead convert them to their subclass indexes as defined by `subclass` dict. For example, "2-Kidney" is converted to "14-Left Kidney" + "5-Right Kidney" since "2" is defined in `subclasss` dict. +These 5 are excluded in the `everything_labels`. Another 7 tumor and vessel classes are also removed since they will overlap with other organs and make the output messy. To segment those 7 classes, we recommend users to directly set `label_prompt` to those indexes and avoid using them in `everything_labels`. For "Kidney", "Lung", "Bone" (class index [2, 20, 21]), VISTA3D did not directly use the class index for segmentation, but instead convert them to their subclass indexes as defined by `subclass` dict. For example, "2-Kidney" is converted to "14-Left Kidney" + "5-Right Kidney" since "2" is defined in `subclasss` dict. ``` @@ -185,7 +185,7 @@ Note: if the finetuning mapped the local user data index to global index "2, 20, ``` ### `resample_spacing` -The optimal inference resample spacing should be changed according to the task. For monkey data, a high resolution of [1,1,1] showed better automatic inference results. This spacing applies to both automatic and interactive segmentation. For zero-shot interactive segmentation for non-human CTs e.g. mouse CT or even rock/stone CT, using original resolution (set `resample_spacing` to [-1,-1,-1]) may give better interactive results. +The optimal inference resample spacing should be changed according to the task. For monkey data, a high resolution of [1,1,1] showed better automatic inference results. This spacing applies to both automatic and interactive segmentation. For zero-shot interactive segmentation for non-human CTs e.g. mouse CT or even rock/stone CT, using original resolution (set `resample_spacing` to [-1,-1,-1]) may give better interactive results. ### `use_point_window` When user click a point, there is no need to perform whole image sliding window inference. Set "use_point_window" to true in the inference.json to enable this function. @@ -193,7 +193,7 @@ A window centered at the clicked points will be used for inference. All values o If no point click exists, this function will not be used. Notice if "use_point_window" is true and user provided point clicks, there will be obvious cut-off box artefacts. ### Inference GPU benchmarks -Benchmarks on a 16GB V100 GPU with 400G system cpu memory. +Benchmarks on a 16GB V100 GPU with 400G system cpu memory. | Volume size at 1.5x1.5x1.5 mm | 333x333x603 | 512x512x512 | 512x512x768 | 1024x1024x512 | 1024x1024x768 | | :---: | :---: | :---: | :---: | :---: | :---: | |RunTime| 1m07s | 2m09s | 3m25s| 9m20s| killed | diff --git a/models/vista3d/scripts/inferer.py b/models/vista3d/scripts/inferer.py index 3c2d99a9..57eaac0e 100644 --- a/models/vista3d/scripts/inferer.py +++ b/models/vista3d/scripts/inferer.py @@ -97,12 +97,12 @@ def __call__( prev_mask=prev_mask, labels=labels, label_set=label_set, - ) + ) else: val_outputs = SlidingWindowInfererAdapt( roi_size=self.roi_size, sw_batch_size=self.sw_batch_size, - with_coord=True)(inputs,network,transpose=True, + with_coord=True)(inputs,network,transpose=True, point_coords=point_coords, point_labels=point_labels, class_vector=class_vector, From d13bdaf081a011d933ac1cde88858bbbc8497caa Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 20 Sep 2024 08:52:29 +0800 Subject: [PATCH 08/15] fix ci Signed-off-by: Yiheng Wang --- models/vista3d/configs/metadata.json | 3 ++- models/vista3d/scripts/inferer.py | 12 ++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/models/vista3d/configs/metadata.json b/models/vista3d/configs/metadata.json index be73909a..dcf174b3 100644 --- a/models/vista3d/configs/metadata.json +++ b/models/vista3d/configs/metadata.json @@ -1,7 +1,8 @@ { "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json", - "version": "0.4.8", + "version": "0.4.9", "changelog": { + "0.4.9": "fix oom issue and update readme", "0.4.8": "use 0.3 overlap for inference", "0.4.7": "update tensorrt benchmark results", "0.4.6": "add tensorrt benchmark result and remove the metric part", diff --git a/models/vista3d/scripts/inferer.py b/models/vista3d/scripts/inferer.py index 57eaac0e..345b1a89 100644 --- a/models/vista3d/scripts/inferer.py +++ b/models/vista3d/scripts/inferer.py @@ -17,6 +17,7 @@ from monai.inferers import Inferer, SlidingWindowInfererAdapt from torch import Tensor + class Vista3dInferer(Inferer): """ Vista3D Inferer @@ -100,14 +101,17 @@ def __call__( ) else: val_outputs = SlidingWindowInfererAdapt( - roi_size=self.roi_size, - sw_batch_size=self.sw_batch_size, - with_coord=True)(inputs,network,transpose=True, + roi_size=self.roi_size, sw_batch_size=self.sw_batch_size, with_coord=True + )( + inputs, + network, + transpose=True, point_coords=point_coords, point_labels=point_labels, class_vector=class_vector, prompt_class=prompt_class, prev_mask=prev_mask, labels=labels, - label_set=label_set) + label_set=label_set, + ) return val_outputs From 5f59db3f35ac57be968bd03c9eca1c0153478b00 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 20 Sep 2024 02:33:44 +0000 Subject: [PATCH 09/15] minor change on readme Signed-off-by: Yiheng Wang --- models/vista3d/docs/README.md | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/models/vista3d/docs/README.md b/models/vista3d/docs/README.md index 6ba50472..b1862f3f 100644 --- a/models/vista3d/docs/README.md +++ b/models/vista3d/docs/README.md @@ -164,12 +164,18 @@ All the configurations for inference is stored in inference.json, change those p - The `points` is of shape `[N, 3]` like `[[x1,y1,z1],[x2,y2,z2],...[xN,yN,zN]]`, representing `N` point coordinates **IN THE ORIGINAL IMAGE SPACE** of a single foreground object. `point_labels` is a list of length [N] like [1,1,0,-1,...], which matches the `points`. 0 means background, 1 means foreground, -1 means ignoring this point. `points` and `point_labels` must pe provided together and match length. - **B must be 1 if label_prompt and points are provided together**. The inferer only supports SINGLE OBJECT point click segmentatation. -- If no prompt is provided, the model will use `everything_labels` to segment 118 classes: list(set([i+1 for i in range(132)]) - set([2,16,18,20,21,23,24,25,26,27,128,129,130,131,132])). -- The `points` together with `label_prompts` for "Kidney", "Lung", "Bone" (class index [2, 20, 21]) are not allowed since those prompts will be divided into sub-categories (e.g. left kidney and right kidney). Use `points` for the sub-categories as defined in the inference.json. +- If no prompt is provided, the model will use `everything_labels` to segment 118 classes: + +```Python +list(set([i+1 for i in range(132)]) - set([2,16,18,20,21,23,24,25,26,27,128,129,130,131,132])) +``` + +- The `points` together with `label_prompts` for "Kidney", "Lung", "Bone" (class index [2, 20, 21]) are not allowed since those prompts will be divided into sub-categories (e.g. left kidney and right kidney). Use `points` for the sub-categories as defined in the `inference.json`. ### `label_prompt` and `label_dict` The `label_dict` defined in [`labels.json`](../docs/labels.json) has in total 132 classes. However, there are 5 we do not support and we keep them due to legacy issue. So in total VISTA3D support 127 classes. + ``` "16, # prostate or uterus" since we already have "prostate" class, "18, # rectum", insufficient data or dataset excluded. @@ -177,8 +183,8 @@ VISTA3D support 127 classes. "129, # kidney mass" insufficient data or dataset excluded. "131, # vertebrae L6", insufficient data or dataset excluded. ``` -These 5 are excluded in the `everything_labels`. Another 7 tumor and vessel classes are also removed since they will overlap with other organs and make the output messy. To segment those 7 classes, we recommend users to directly set `label_prompt` to those indexes and avoid using them in `everything_labels`. For "Kidney", "Lung", "Bone" (class index [2, 20, 21]), VISTA3D did not directly use the class index for segmentation, but instead convert them to their subclass indexes as defined by `subclass` dict. For example, "2-Kidney" is converted to "14-Left Kidney" + "5-Right Kidney" since "2" is defined in `subclasss` dict. +These 5 are excluded in the `everything_labels`. Another 7 tumor and vessel classes are also removed since they will overlap with other organs and make the output messy. To segment those 7 classes, we recommend users to directly set `label_prompt` to those indexes and avoid using them in `everything_labels`. For "Kidney", "Lung", "Bone" (class index [2, 20, 21]), VISTA3D did not directly use the class index for segmentation, but instead convert them to their subclass indexes as defined by `subclass` dict. For example, "2-Kidney" is converted to "14-Left Kidney" + "5-Right Kidney" since "2" is defined in `subclasss` dict. ``` Note: if the finetuning mapped the local user data index to global index "2, 20, 21", remove the `subclass` dict from inference.json since those values defined in `subclass` will trigger the wrong subclass segmentation. From 42ad2e435ce25109799b7134c487a51b2beccf15 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 20 Sep 2024 02:50:27 +0000 Subject: [PATCH 10/15] sset fallback for trt Signed-off-by: Yiheng Wang --- models/vista3d/configs/inference_trt.json | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/models/vista3d/configs/inference_trt.json b/models/vista3d/configs/inference_trt.json index 9c3d52dc..d399d919 100644 --- a/models/vista3d/configs/inference_trt.json +++ b/models/vista3d/configs/inference_trt.json @@ -3,7 +3,8 @@ "$from monai.networks import trt_compile" ], "trt_args": { - "dynamic_batchsize": "$[1, @inferer#sw_batch_size, @inferer#sw_batch_size]" + "dynamic_batchsize": "$[1, @inferer#sw_batch_size, @inferer#sw_batch_size]", + "fallback": true }, "network": "$trt_compile(@network_def.to(@device), @bundle_root + '/models/model.pt', args=@trt_args, submodule=['image_encoder.encoder', 'class_head'])" } From 14d2f1c469b6ed6fb1ec2a700292e065fa0253bb Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 20 Sep 2024 07:44:46 +0000 Subject: [PATCH 11/15] update trt script and readme Signed-off-by: Yiheng Wang --- models/vista3d/configs/inference_trt.json | 6 ++++-- models/vista3d/docs/README.md | 9 ++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/models/vista3d/configs/inference_trt.json b/models/vista3d/configs/inference_trt.json index d399d919..02fb2fe9 100644 --- a/models/vista3d/configs/inference_trt.json +++ b/models/vista3d/configs/inference_trt.json @@ -1,10 +1,12 @@ { + "max_dynamic_batchsize": 4, + "enable_class_head": false, "+imports": [ "$from monai.networks import trt_compile" ], "trt_args": { - "dynamic_batchsize": "$[1, @inferer#sw_batch_size, @inferer#sw_batch_size]", + "dynamic_batchsize": "$[1, @inferer#sw_batch_size, @max_dynamic_batchsize]", "fallback": true }, - "network": "$trt_compile(@network_def.to(@device), @bundle_root + '/models/model.pt', args=@trt_args, submodule=['image_encoder.encoder', 'class_head'])" + "network": "$trt_compile(@network_def.to(@device), @bundle_root + '/models/model.pt', args=@trt_args, submodule=['image_encoder.encoder', 'class_head'] if @enable_class_head else ['image_encoder.encoder'])" } diff --git a/models/vista3d/docs/README.md b/models/vista3d/docs/README.md index b1862f3f..a94b2c21 100644 --- a/models/vista3d/docs/README.md +++ b/models/vista3d/docs/README.md @@ -164,7 +164,7 @@ All the configurations for inference is stored in inference.json, change those p - The `points` is of shape `[N, 3]` like `[[x1,y1,z1],[x2,y2,z2],...[xN,yN,zN]]`, representing `N` point coordinates **IN THE ORIGINAL IMAGE SPACE** of a single foreground object. `point_labels` is a list of length [N] like [1,1,0,-1,...], which matches the `points`. 0 means background, 1 means foreground, -1 means ignoring this point. `points` and `point_labels` must pe provided together and match length. - **B must be 1 if label_prompt and points are provided together**. The inferer only supports SINGLE OBJECT point click segmentatation. -- If no prompt is provided, the model will use `everything_labels` to segment 118 classes: +- If no prompt is provided, the model will use `everything_labels` to segment 117 classes: ```Python list(set([i+1 for i in range(132)]) - set([2,16,18,20,21,23,24,25,26,27,128,129,130,131,132])) @@ -225,6 +225,13 @@ This default is overridable by changing the input folder `input_dir`, or the inp python -m monai.bundle run --config_file "['configs/inference.json', 'configs/inference_trt.json']" ``` +By default, the argument `enable_class_head` is set to `false` in `configs/inference_trt.json`. This means that the `class_head` module of the network will not be converted into a TensorRT model. Setting this to `true` may accelerate the process, but there are some limitations: + +The `label_prompt` will be converted into a tensor and input into the `class_head` module. The batch size of this input tensor will equal the length of the original `label_prompt` list (if no prompt is provided, the length is 117). + +To make the TensorRT model work on the `class_head` module, you should set a suitable dynamic batch size range. The maximum dynamic batch size can be configured using the argument `max_dynamic_batchsize` in `configs/inference_trt.json`. If the length of the `label_prompt` list exceeds `max_dynamic_batchsize`, the engine will fall back to using the normal PyTorch model for inference. Setting a larger `max_dynamic_batchsize` can cover more input cases but may require more GPU memory (the default value is 4, which requires 16 GB of GPU memory). Therefore, please set it to a suitable value according to your actual requirements. + + ### TroubleShoot for Out-of-Memory - Changing `patch_size` to a smaller value such as `"patch_size": [96, 96, 96]` would reduce the training/inference memory footprint. - Changing `train_dataset_cache_rate` and `val_dataset_cache_rate` to a smaller value like `0.1` can solve the out-of-cpu memory issue when using huge finetuning dataset. From 08c84ab04b33da91f58982835b72c64747d983a5 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 20 Sep 2024 09:42:20 +0000 Subject: [PATCH 12/15] update readme and trt Signed-off-by: Yiheng Wang --- models/vista3d/configs/inference_trt.json | 18 ++++++++++++------ models/vista3d/docs/README.md | 6 ++---- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/models/vista3d/configs/inference_trt.json b/models/vista3d/configs/inference_trt.json index 02fb2fe9..f13b3468 100644 --- a/models/vista3d/configs/inference_trt.json +++ b/models/vista3d/configs/inference_trt.json @@ -1,12 +1,18 @@ { - "max_dynamic_batchsize": 4, - "enable_class_head": false, "+imports": [ "$from monai.networks import trt_compile" ], - "trt_args": { - "dynamic_batchsize": "$[1, @inferer#sw_batch_size, @max_dynamic_batchsize]", - "fallback": true + "max_prompt_size": 8, + "head_trt_enabled": false, + "network_trt_args": { + "dynamic_batchsize": "$[1, @inferer#sw_batch_size, @inferer#sw_batch_size]" }, - "network": "$trt_compile(@network_def.to(@device), @bundle_root + '/models/model.pt', args=@trt_args, submodule=['image_encoder.encoder', 'class_head'] if @enable_class_head else ['image_encoder.encoder'])" + "network_dev": "$@network_def.to(@device)", + "encoder": "$trt_compile(@network_dev, @bundle_root + '/models/model.pt', args=@network_trt_args, submodule=['image_encoder.encoder'])", + "head_trt_args": { + "dynamic_batchsize": "$[1, 1, @max_prompt_size]", + "fallback": "$True" + }, + "head": "$trt_compile(@network_dev, @bundle_root + '/models/model.pt', args=@head_trt_args, submodule=['class_head']) if @head_trt_enabled else @network_dev", + "network": "$None if @encoder is None else @head" } diff --git a/models/vista3d/docs/README.md b/models/vista3d/docs/README.md index a94b2c21..8a78c0ad 100644 --- a/models/vista3d/docs/README.md +++ b/models/vista3d/docs/README.md @@ -225,11 +225,9 @@ This default is overridable by changing the input folder `input_dir`, or the inp python -m monai.bundle run --config_file "['configs/inference.json', 'configs/inference_trt.json']" ``` -By default, the argument `enable_class_head` is set to `false` in `configs/inference_trt.json`. This means that the `class_head` module of the network will not be converted into a TensorRT model. Setting this to `true` may accelerate the process, but there are some limitations: +By default, the argument `head_trt_enabled` is set to `false` in `configs/inference_trt.json`. This means that the `class_head` module of the network will not be converted into a TensorRT model. Setting this to `true` may accelerate the process, but there are some limitations: -The `label_prompt` will be converted into a tensor and input into the `class_head` module. The batch size of this input tensor will equal the length of the original `label_prompt` list (if no prompt is provided, the length is 117). - -To make the TensorRT model work on the `class_head` module, you should set a suitable dynamic batch size range. The maximum dynamic batch size can be configured using the argument `max_dynamic_batchsize` in `configs/inference_trt.json`. If the length of the `label_prompt` list exceeds `max_dynamic_batchsize`, the engine will fall back to using the normal PyTorch model for inference. Setting a larger `max_dynamic_batchsize` can cover more input cases but may require more GPU memory (the default value is 4, which requires 16 GB of GPU memory). Therefore, please set it to a suitable value according to your actual requirements. +The `label_prompt` will be converted into a tensor and input into the `class_head` module. The batch size of this input tensor will equal the length of the original `label_prompt` list (if no prompt is provided, the length is 117). To make the TensorRT model work on the `class_head` module, you should set a suitable dynamic batch size range. The maximum dynamic batch size can be configured using the argument `max_prompt_size` in `configs/inference_trt.json`. If the length of the `label_prompt` list exceeds `max_prompt_size`, the engine will fall back to using the normal PyTorch model for inference. Setting a larger `max_prompt_size` can cover more input cases but may require more GPU memory (the default value is 4, which requires 16 GB of GPU memory). Therefore, please set it to a suitable value according to your actual requirements. ### TroubleShoot for Out-of-Memory From c72956b8e27dd2507a4f769a6ea98cac78345f62 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Fri, 20 Sep 2024 09:53:25 +0000 Subject: [PATCH 13/15] update readme Signed-off-by: Yiheng Wang --- models/vista3d/docs/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/vista3d/docs/README.md b/models/vista3d/docs/README.md index 8a78c0ad..a460ca11 100644 --- a/models/vista3d/docs/README.md +++ b/models/vista3d/docs/README.md @@ -227,7 +227,7 @@ python -m monai.bundle run --config_file "['configs/inference.json', 'configs/in By default, the argument `head_trt_enabled` is set to `false` in `configs/inference_trt.json`. This means that the `class_head` module of the network will not be converted into a TensorRT model. Setting this to `true` may accelerate the process, but there are some limitations: -The `label_prompt` will be converted into a tensor and input into the `class_head` module. The batch size of this input tensor will equal the length of the original `label_prompt` list (if no prompt is provided, the length is 117). To make the TensorRT model work on the `class_head` module, you should set a suitable dynamic batch size range. The maximum dynamic batch size can be configured using the argument `max_prompt_size` in `configs/inference_trt.json`. If the length of the `label_prompt` list exceeds `max_prompt_size`, the engine will fall back to using the normal PyTorch model for inference. Setting a larger `max_prompt_size` can cover more input cases but may require more GPU memory (the default value is 4, which requires 16 GB of GPU memory). Therefore, please set it to a suitable value according to your actual requirements. +Since the `label_prompt` will be converted into a tensor and input into the `class_head` module, the batch size of this input tensor will equal the length of the original `label_prompt` list (if no prompt is provided, the length is 117). To make the TensorRT model work on the `class_head` module, you should set a suitable dynamic batch size range. The maximum dynamic batch size can be configured using the argument `max_prompt_size` in `configs/inference_trt.json`. If the length of the `label_prompt` list exceeds `max_prompt_size`, the engine will fall back to using the normal PyTorch model for inference. Setting a larger `max_prompt_size` can cover more input cases but may require more GPU memory (the default value is 4, which requires 16 GB of GPU memory). Therefore, please set it to a reasonable value according to your actual requirements. ### TroubleShoot for Out-of-Memory From 2b01a0b3ecb5ba18c917b24ae6fda65dd13c7229 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 23 Sep 2024 03:26:08 +0000 Subject: [PATCH 14/15] change vista2d Signed-off-by: Yiheng Wang --- models/vista2d/configs/metadata.json | 3 ++- models/vista2d/docs/README.md | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/models/vista2d/configs/metadata.json b/models/vista2d/configs/metadata.json index fcec60da..067b579a 100644 --- a/models/vista2d/configs/metadata.json +++ b/models/vista2d/configs/metadata.json @@ -1,7 +1,8 @@ { "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json", - "version": "0.2.7", + "version": "0.2.8", "changelog": { + "0.2.8": "remove relative path in readme", "0.2.7": "enhance readme", "0.2.6": "update tensorrt benchmark results", "0.2.5": "add tensorrt benchmark results", diff --git a/models/vista2d/docs/README.md b/models/vista2d/docs/README.md index 4fd639a5..4ff5cfa1 100644 --- a/models/vista2d/docs/README.md +++ b/models/vista2d/docs/README.md @@ -2,7 +2,7 @@ The **VISTA2D** is a cell segmentation training and inference pipeline for cell imaging [[`Blog`](https://developer.nvidia.com/blog/advancing-cell-segmentation-and-morphology-analysis-with-nvidia-ai-foundation-model-vista-2d/)]. -A pretrained model was trained on collection of 15K public microscopy images. The data collection and training can be reproduced following the [tutorial](../download_preprocessor/). Alternatively, the model can be retrained on your own dataset. The pretrained vista2d model achieves good performance on diverse set of cell types, microscopy image modalities, and can be further finetuned if necessary. The codebase utilizes several components from other great works including [SegmentAnything](https://github.com/facebookresearch/segment-anything) and [Cellpose](https://www.cellpose.org/), which must be pip installed as dependencies. Vista2D codebase follows MONAI bundle format and its [specifications](https://docs.monai.io/en/stable/mb_specification.html). +A pretrained model was trained on collection of 15K public microscopy images. The data collection and training can be reproduced following the `download_preprocessor/`. Alternatively, the model can be retrained on your own dataset. The pretrained vista2d model achieves good performance on diverse set of cell types, microscopy image modalities, and can be further finetuned if necessary. The codebase utilizes several components from other great works including [SegmentAnything](https://github.com/facebookresearch/segment-anything) and [Cellpose](https://www.cellpose.org/), which must be pip installed as dependencies. Vista2D codebase follows MONAI bundle format and its [specifications](https://docs.monai.io/en/stable/mb_specification.html).
From 66c65e77c773dbfc22b6df7df69c4ca740473908 Mon Sep 17 00:00:00 2001 From: Yiheng Wang Date: Mon, 23 Sep 2024 03:40:09 +0000 Subject: [PATCH 15/15] use tutorial labels dict Signed-off-by: Yiheng Wang --- models/vista3d/docs/labels.json | 134 -------------------------------- models/vista3d/large_files.yml | 2 + 2 files changed, 2 insertions(+), 134 deletions(-) delete mode 100644 models/vista3d/docs/labels.json diff --git a/models/vista3d/docs/labels.json b/models/vista3d/docs/labels.json deleted file mode 100644 index 51f93162..00000000 --- a/models/vista3d/docs/labels.json +++ /dev/null @@ -1,134 +0,0 @@ -{ - "liver": 1, - "kidney": 2, - "spleen": 3, - "pancreas": 4, - "right kidney": 5, - "aorta": 6, - "inferior vena cava": 7, - "right adrenal gland": 8, - "left adrenal gland": 9, - "gallbladder": 10, - "esophagus": 11, - "stomach": 12, - "duodenum": 13, - "left kidney": 14, - "bladder": 15, - "prostate or uterus": 16, - "portal vein and splenic vein": 17, - "rectum": 18, - "small bowel": 19, - "lung": 20, - "bone": 21, - "brain": 22, - "lung tumor": 23, - "pancreatic tumor": 24, - "hepatic vessel": 25, - "hepatic tumor": 26, - "colon cancer primaries": 27, - "left lung upper lobe": 28, - "left lung lower lobe": 29, - "right lung upper lobe": 30, - "right lung middle lobe": 31, - "right lung lower lobe": 32, - "vertebrae L5": 33, - "vertebrae L4": 34, - "vertebrae L3": 35, - "vertebrae L2": 36, - "vertebrae L1": 37, - "vertebrae T12": 38, - "vertebrae T11": 39, - "vertebrae T10": 40, - "vertebrae T9": 41, - "vertebrae T8": 42, - "vertebrae T7": 43, - "vertebrae T6": 44, - "vertebrae T5": 45, - "vertebrae T4": 46, - "vertebrae T3": 47, - "vertebrae T2": 48, - "vertebrae T1": 49, - "vertebrae C7": 50, - "vertebrae C6": 51, - "vertebrae C5": 52, - "vertebrae C4": 53, - "vertebrae C3": 54, - "vertebrae C2": 55, - "vertebrae C1": 56, - "trachea": 57, - "left iliac artery": 58, - "right iliac artery": 59, - "left iliac vena": 60, - "right iliac vena": 61, - "colon": 62, - "left rib 1": 63, - "left rib 2": 64, - "left rib 3": 65, - "left rib 4": 66, - "left rib 5": 67, - "left rib 6": 68, - "left rib 7": 69, - "left rib 8": 70, - "left rib 9": 71, - "left rib 10": 72, - "left rib 11": 73, - "left rib 12": 74, - "right rib 1": 75, - "right rib 2": 76, - "right rib 3": 77, - "right rib 4": 78, - "right rib 5": 79, - "right rib 6": 80, - "right rib 7": 81, - "right rib 8": 82, - "right rib 9": 83, - "right rib 10": 84, - "right rib 11": 85, - "right rib 12": 86, - "left humerus": 87, - "right humerus": 88, - "left scapula": 89, - "right scapula": 90, - "left clavicula": 91, - "right clavicula": 92, - "left femur": 93, - "right femur": 94, - "left hip": 95, - "right hip": 96, - "sacrum": 97, - "left gluteus maximus": 98, - "right gluteus maximus": 99, - "left gluteus medius": 100, - "right gluteus medius": 101, - "left gluteus minimus": 102, - "right gluteus minimus": 103, - "left autochthon": 104, - "right autochthon": 105, - "left iliopsoas": 106, - "right iliopsoas": 107, - "left atrial appendage": 108, - "brachiocephalic trunk": 109, - "left brachiocephalic vein": 110, - "right brachiocephalic vein": 111, - "left common carotid artery": 112, - "right common carotid artery": 113, - "costal cartilages": 114, - "heart": 115, - "left kidney cyst": 116, - "right kidney cyst": 117, - "prostate": 118, - "pulmonary vein": 119, - "skull": 120, - "spinal cord": 121, - "sternum": 122, - "left subclavian artery": 123, - "right subclavian artery": 124, - "superior vena cava": 125, - "thyroid gland": 126, - "vertebrae S1": 127, - "bone lesion": 128, - "kidney mass": 129, - "liver tumor": 130, - "vertebrae L6": 131, - "airway": 132 -} diff --git a/models/vista3d/large_files.yml b/models/vista3d/large_files.yml index 031bf2df..586150ff 100644 --- a/models/vista3d/large_files.yml +++ b/models/vista3d/large_files.yml @@ -3,3 +3,5 @@ large_files: url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_vista3d.pt" hash_val: "6ce45a8edde4400c5d28d5e74d7b61d5" hash_type: "md5" + - path: "docs/labels.json" + url: "https://github.com/Project-MONAI/tutorials/blob/e66be5955d2b4f5959884ca026932762954b19c5/vista_3d/label_dict.json"