From 9d1b26a1d37100448d2d4e2232f518e9d01a19c9 Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Fri, 26 Jul 2024 16:09:51 +0800 Subject: [PATCH] Add vista3d (#605) Fixes #604 . ### Description A few sentences describing the changes proposed in this pull request. ### Status **Ready/Work in progress/Hold** ### Please ensure all the checkboxes: - [x] Codeformat tests passed locally by running `./runtests.sh --codeformat`. - [ ] In-line docstrings updated. - [ ] Update `version` and `changelog` in `metadata.json` if changing an existing bundle. - [ ] Please ensure the naming rules in config files meet our requirements (please refer to: `CONTRIBUTING.md`). - [ ] Ensure versions of packages such as `monai`, `pytorch` and `numpy` are correct in `metadata.json`. - [ ] Descriptions should be consistent with the content, such as `eval_metrics` of the provided weights and TorchScript modules. - [ ] Files larger than 25MB are excluded and replaced by providing download links in `large_file.yml`. - [ ] Avoid using path that contains personal information within config files (such as use `/home/your_name/` for `"bundle_root"`). --------- Signed-off-by: Yiheng Wang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- ci/bundle_custom_data.py | 1 + ci/unit_tests/test_vista3d.py | 442 ++++++++++++++++ ci/unit_tests/test_vista3d_mgpu.py | 179 +++++++ ci/unit_tests/utils.py | 1 - models/vista3d/LICENSE | 201 ++++++++ models/vista3d/configs/batch_inference.json | 7 + models/vista3d/configs/data.yaml | 39 ++ models/vista3d/configs/evaluate.json | 177 +++++++ models/vista3d/configs/inference.json | 200 +++++++ models/vista3d/configs/logging.conf | 27 + models/vista3d/configs/metadata.json | 210 ++++++++ models/vista3d/configs/mgpu_evaluate.json | 29 ++ models/vista3d/configs/multi_gpu_train.json | 42 ++ models/vista3d/configs/train.json | 394 ++++++++++++++ models/vista3d/configs/train_continual.json | 109 ++++ models/vista3d/docs/README.md | 215 ++++++++ models/vista3d/docs/data_license.txt | 6 + models/vista3d/docs/labels.json | 137 +++++ models/vista3d/large_files.yml | 3 + models/vista3d/msd_task09_spleen_folds.json | 271 ++++++++++ models/vista3d/scripts/__init__.py | 16 + .../scripts/early_stop_score_function.py | 15 + models/vista3d/scripts/evaluator.py | 292 +++++++++++ models/vista3d/scripts/inferer.py | 132 +++++ models/vista3d/scripts/monai_trans_utils.py | 317 ++++++++++++ models/vista3d/scripts/monai_utils.py | 412 +++++++++++++++ models/vista3d/scripts/trainer.py | 217 ++++++++ models/vista3d/scripts/utils.py | 470 +++++++++++++++++ models/vista3d/scripts/vista3d/__init__.py | 1 + .../vista3d/scripts/vista3d/build_vista3d.py | 29 ++ .../scripts/vista3d/modeling/__init__.py | 12 + .../scripts/vista3d/modeling/class_head.py | 51 ++ .../scripts/vista3d/modeling/point_head.py | 113 ++++ .../scripts/vista3d/modeling/sam_blocks.py | 292 +++++++++++ .../scripts/vista3d/modeling/segresnetds.py | 488 ++++++++++++++++++ .../scripts/vista3d/modeling/vista3d.py | 262 ++++++++++ 36 files changed, 5808 insertions(+), 1 deletion(-) create mode 100644 ci/unit_tests/test_vista3d.py create mode 100644 ci/unit_tests/test_vista3d_mgpu.py create mode 100644 models/vista3d/LICENSE create mode 100644 models/vista3d/configs/batch_inference.json create mode 100644 models/vista3d/configs/data.yaml create mode 100644 models/vista3d/configs/evaluate.json create mode 100644 models/vista3d/configs/inference.json create mode 100644 models/vista3d/configs/logging.conf create mode 100644 models/vista3d/configs/metadata.json create mode 100644 models/vista3d/configs/mgpu_evaluate.json create mode 100644 models/vista3d/configs/multi_gpu_train.json create mode 100644 models/vista3d/configs/train.json create mode 100644 models/vista3d/configs/train_continual.json create mode 100644 models/vista3d/docs/README.md create mode 100644 models/vista3d/docs/data_license.txt create mode 100644 models/vista3d/docs/labels.json create mode 100644 models/vista3d/large_files.yml create mode 100644 models/vista3d/msd_task09_spleen_folds.json create mode 100644 models/vista3d/scripts/__init__.py create mode 100644 models/vista3d/scripts/early_stop_score_function.py create mode 100644 models/vista3d/scripts/evaluator.py create mode 100644 models/vista3d/scripts/inferer.py create mode 100644 models/vista3d/scripts/monai_trans_utils.py create mode 100644 models/vista3d/scripts/monai_utils.py create mode 100644 models/vista3d/scripts/trainer.py create mode 100644 models/vista3d/scripts/utils.py create mode 100644 models/vista3d/scripts/vista3d/__init__.py create mode 100755 models/vista3d/scripts/vista3d/build_vista3d.py create mode 100755 models/vista3d/scripts/vista3d/modeling/__init__.py create mode 100644 models/vista3d/scripts/vista3d/modeling/class_head.py create mode 100644 models/vista3d/scripts/vista3d/modeling/point_head.py create mode 100644 models/vista3d/scripts/vista3d/modeling/sam_blocks.py create mode 100644 models/vista3d/scripts/vista3d/modeling/segresnetds.py create mode 100644 models/vista3d/scripts/vista3d/modeling/vista3d.py diff --git a/ci/bundle_custom_data.py b/ci/bundle_custom_data.py index 711bce3f..2965350c 100644 --- a/ci/bundle_custom_data.py +++ b/ci/bundle_custom_data.py @@ -36,6 +36,7 @@ "breast_density_classification", "mednist_reg", "brats_mri_axial_slices_generative_diffusion", + "vista3d", ] # This dict is used for our CI tests to install required dependencies that cannot be installed by `pip install` directly. diff --git a/ci/unit_tests/test_vista3d.py b/ci/unit_tests/test_vista3d.py new file mode 100644 index 00000000..b12d3e69 --- /dev/null +++ b/ci/unit_tests/test_vista3d.py @@ -0,0 +1,442 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import sys +import tempfile +import unittest + +import nibabel as nib +import numpy as np +from monai.bundle import ConfigWorkflow +from parameterized import parameterized +from utils import check_workflow + +TEST_CASE_INFER = [ + { + "bundle_root": "models/vista3d", + "input_dict": {"label_prompt": [25], "points": [[123, 212, 151]], "point_labels": [1]}, + "patch_size": [32, 32, 32], + "checkpointloader#_disabled_": True, # do not load weights" + "initialize": ["$monai.utils.set_determinism(seed=123)"], + } +] +TEST_CASE_INFER_STR_PROMPT = [ + { + "bundle_root": "models/vista3d", + "input_dict": {"label_prompt": ["spleen"], "points": [[123, 212, 151]], "point_labels": [1]}, + "patch_size": [32, 32, 32], + "checkpointloader#_disabled_": True, # do not load weights" + "initialize": ["$monai.utils.set_determinism(seed=123)"], + } +] +TEST_CASE_INFER_MULTI_PROMPT = [ + { + "bundle_root": "models/vista3d", + "input_dict": {"label_prompt": [25, 24, 1]}, + "patch_size": [32, 32, 32], + "checkpointloader#_disabled_": True, # do not load weights" + "initialize": ["$monai.utils.set_determinism(seed=123)"], + } +] +TEST_CASE_INFER_MULTI_STR_PROMPT = [ + { + "bundle_root": "models/vista3d", + "input_dict": {"label_prompt": ["hepatic vessel", "pancreatic tumor", "liver"]}, + "patch_size": [32, 32, 32], + "checkpointloader#_disabled_": True, # do not load weights" + "initialize": ["$monai.utils.set_determinism(seed=123)"], + } +] +TEST_CASE_INFER_MULTI_NEW_STR_PROMPT = [ + { + "bundle_root": "models/vista3d", + "input_dict": {"label_prompt": ["new class 1", "new class 2", "new class 3"]}, + "patch_size": [32, 32, 32], + "checkpointloader#_disabled_": True, # do not load weights" + "initialize": ["$monai.utils.set_determinism(seed=123)"], + } +] +TEST_CASE_INFER_SUBCLASS = [ + { + "bundle_root": "models/vista3d", + "input_dict": {"label_prompt": [2, 20, 21]}, + "patch_size": [32, 32, 32], + "checkpointloader#_disabled_": True, # do not load weights" + "initialize": ["$monai.utils.set_determinism(seed=123)"], + } +] +TEST_CASE_INFER_NO_PROMPT = [ + { + "bundle_root": "models/vista3d", + "input_dict": {}, # put an empty dict, and will add an image in the test function + "patch_size": [32, 32, 32], + "checkpointloader#_disabled_": True, # do not load weights" + "initialize": ["$monai.utils.set_determinism(seed=123)"], + } +] +TEST_CASE_EVAL = [ + { + "bundle_root": "models/vista3d", + "patch_size": [32, 32, 32], + "initialize": ["$monai.utils.set_determinism(seed=123)"], + } +] +TEST_CASE_TRAIN = [ + { + "bundle_root": "models/vista3d", + "patch_size": [32, 32, 32], + "epochs": 2, + "val_interval": 1, + "initialize": ["$monai.utils.set_determinism(seed=123)"], + } +] +TEST_CASE_TRAIN_CONTINUAL = [ + { + "bundle_root": "models/vista3d", + "patch_size": [32, 32, 32], + "epochs": 2, + "val_interval": 1, + "initialize": ["$monai.utils.set_determinism(seed=123)"], + "finetune": False, + } +] +TEST_CASE_ERROR_PROMPTS = [ + [ + { + "bundle_root": "models/vista3d", + "input_dict": {}, + "patch_size": [32, 32, 32], + "checkpointloader#_disabled_": True, # do not load weights" + "everything_labels": None, + "initialize": ["$monai.utils.set_determinism(seed=123)"], + "error": "Prompt must be given for inference.", + } + ], + [ + { + "bundle_root": "models/vista3d", + "input_dict": {"label_prompt": [[25, 26, 27]]}, + "patch_size": [32, 32, 32], + "checkpointloader#_disabled_": True, # do not load weights" + "initialize": ["$monai.utils.set_determinism(seed=123)"], + "error": "Label prompt must be a list of single scalar, [1,2,3,4,...,].", + } + ], + [ + { + "bundle_root": "models/vista3d", + "input_dict": {"label_prompt": 25}, + "patch_size": [32, 32, 32], + "checkpointloader#_disabled_": True, # do not load weights" + "initialize": ["$monai.utils.set_determinism(seed=123)"], + "error": "Label prompt must be a list, [1,2,3,4,...,].", + } + ], + [ + { + "bundle_root": "models/vista3d", + "input_dict": {"label_prompt": [256]}, + "patch_size": [32, 32, 32], + "checkpointloader#_disabled_": True, # do not load weights" + "initialize": ["$monai.utils.set_determinism(seed=123)"], + "error": "Current bundle only supports label prompt smaller than 255.", + } + ], + [ + { + "bundle_root": "models/vista3d", + "input_dict": {"label_prompt": [25], "points": [[123, 212, 151]]}, + "patch_size": [32, 32, 32], + "checkpointloader#_disabled_": True, # do not load weights" + "initialize": ["$monai.utils.set_determinism(seed=123)"], + "error": "Point labels must be given if points are given.", + } + ], + [ + { + "bundle_root": "models/vista3d", + "input_dict": {"label_prompt": [25], "point_labels": [1]}, + "patch_size": [32, 32, 32], + "checkpointloader#_disabled_": True, # do not load weights" + "initialize": ["$monai.utils.set_determinism(seed=123)"], + "error": "Points must be given if point labels are given.", + } + ], + [ + { + "bundle_root": "models/vista3d", + "input_dict": {"label_prompt": [25], "points": [[1, 123, 212, 151]], "point_labels": [1]}, + "patch_size": [32, 32, 32], + "checkpointloader#_disabled_": True, # do not load weights" + "initialize": ["$monai.utils.set_determinism(seed=123)"], + "error": "Points must be three dimensional (x,y,z) in the shape of [[x,y,z],...,[x,y,z]].", + } + ], + [ + { + "bundle_root": "models/vista3d", + "input_dict": {"label_prompt": [25], "points": [[[123, 212, 151]]], "point_labels": [1]}, + "patch_size": [32, 32, 32], + "checkpointloader#_disabled_": True, # do not load weights" + "initialize": ["$monai.utils.set_determinism(seed=123)"], + "error": "Points must be three dimensional (x,y,z) in the shape of [[x,y,z],...,[x,y,z]].", + } + ], + [ + { + "bundle_root": "models/vista3d", + "input_dict": {"label_prompt": [25], "points": [[123, 212, 151]], "point_labels": [1, 1]}, + "patch_size": [32, 32, 32], + "checkpointloader#_disabled_": True, # do not load weights" + "initialize": ["$monai.utils.set_determinism(seed=123)"], + "error": "Points must match point labels.", + } + ], + [ + { + "bundle_root": "models/vista3d", + "input_dict": {"label_prompt": [1], "points": [[123, 212, 151]], "point_labels": [-2]}, + "patch_size": [32, 32, 32], + "checkpointloader#_disabled_": True, # do not load weights" + "initialize": ["$monai.utils.set_determinism(seed=123)"], + "error": "Point labels can only be -1,0,1 and 2,3 for special flags.", + } + ], + [ + { + "bundle_root": "models/vista3d", + "input_dict": {"label_prompt": [25, 26], "points": [[123, 212, 151]], "point_labels": [1]}, + "patch_size": [32, 32, 32], + "checkpointloader#_disabled_": True, # do not load weights" + "initialize": ["$monai.utils.set_determinism(seed=123)"], + "error": "Label prompt can only be a single object if provided with point prompts.", + } + ], +] + + +def test_order(test_name1, test_name2): + def get_order(name): + if "train_config" in name: + return 1 + if "train_continual" in name: + return 2 + if "eval" in name: + return 3 + return 4 + + return get_order(test_name1) - get_order(test_name2) + + +class TestVista3d(unittest.TestCase): + def setUp(self): + self.dataset_dir = tempfile.mkdtemp() + self.dataset_size = 5 + input_shape = (64, 64, 64) + for s in range(self.dataset_size): + test_image = np.random.randint(low=0, high=2, size=input_shape).astype(np.int8) + test_label = np.random.randint(low=0, high=2, size=input_shape).astype(np.int8) + image_filename = os.path.join(self.dataset_dir, f"image_{s}.nii.gz") + label_filename = os.path.join(self.dataset_dir, f"label_{s}.nii.gz") + nib.save(nib.Nifti1Image(test_image, np.eye(4)), image_filename) + nib.save(nib.Nifti1Image(test_label, np.eye(4)), label_filename) + + def tearDown(self): + shutil.rmtree(self.dataset_dir) + + @parameterized.expand([TEST_CASE_TRAIN]) + def test_train_config(self, override): + train_size = self.dataset_size // 2 + train_datalist = [ + { + "image": os.path.join(self.dataset_dir, f"image_{i}.nii.gz"), + "label": os.path.join(self.dataset_dir, f"label_{i}.nii.gz"), + } + for i in range(train_size) + ] + val_datalist = [ + { + "image": os.path.join(self.dataset_dir, f"image_{i}.nii.gz"), + "label": os.path.join(self.dataset_dir, f"label_{i}.nii.gz"), + } + for i in range(train_size, self.dataset_size) + ] + override["train_datalist"] = train_datalist + override["val_datalist"] = val_datalist + + bundle_root = override["bundle_root"] + sys.path = [bundle_root] + sys.path + trainer = ConfigWorkflow( + workflow_type="train", + config_file=os.path.join(bundle_root, "configs/train.json"), + logging_file=os.path.join(bundle_root, "configs/logging.conf"), + meta_file=os.path.join(bundle_root, "configs/metadata.json"), + **override, + ) + check_workflow(trainer, check_properties=False) + + @parameterized.expand([TEST_CASE_EVAL]) + def test_eval_config(self, override): + train_size = self.dataset_size // 2 + train_datalist = [ + { + "image": os.path.join(self.dataset_dir, f"image_{i}.nii.gz"), + "label": os.path.join(self.dataset_dir, f"label_{i}.nii.gz"), + } + for i in range(train_size) + ] + val_datalist = [ + { + "image": os.path.join(self.dataset_dir, f"image_{i}.nii.gz"), + "label": os.path.join(self.dataset_dir, f"label_{i}.nii.gz"), + } + for i in range(train_size, self.dataset_size) + ] + override["train_datalist"] = train_datalist + override["val_datalist"] = val_datalist + bundle_root = override["bundle_root"] + sys.path = [bundle_root] + sys.path + config_files = [ + os.path.join(bundle_root, "configs/train.json"), + os.path.join(bundle_root, "configs/train_continual.json"), + os.path.join(bundle_root, "configs/evaluate.json"), + os.path.join(bundle_root, "configs/data.yaml"), + ] + trainer = ConfigWorkflow( + workflow_type="train", + config_file=config_files, + logging_file=os.path.join(bundle_root, "configs/logging.conf"), + meta_file=os.path.join(bundle_root, "configs/metadata.json"), + **override, + ) + check_workflow(trainer, check_properties=False) + + @parameterized.expand([TEST_CASE_TRAIN_CONTINUAL]) + def test_train_continual_config(self, override): + train_size = self.dataset_size // 2 + train_datalist = [ + { + "image": os.path.join(self.dataset_dir, f"image_{i}.nii.gz"), + "label": os.path.join(self.dataset_dir, f"label_{i}.nii.gz"), + } + for i in range(train_size) + ] + val_datalist = [ + { + "image": os.path.join(self.dataset_dir, f"image_{i}.nii.gz"), + "label": os.path.join(self.dataset_dir, f"label_{i}.nii.gz"), + } + for i in range(train_size, self.dataset_size) + ] + override["train_datalist"] = train_datalist + override["val_datalist"] = val_datalist + + bundle_root = override["bundle_root"] + sys.path = [bundle_root] + sys.path + trainer = ConfigWorkflow( + workflow_type="train", + config_file=[ + os.path.join(bundle_root, "configs/train.json"), + os.path.join(bundle_root, "configs/train_continual.json"), + ], + logging_file=os.path.join(bundle_root, "configs/logging.conf"), + meta_file=os.path.join(bundle_root, "configs/metadata.json"), + **override, + ) + check_workflow(trainer, check_properties=False) + + @parameterized.expand( + [ + TEST_CASE_INFER, + TEST_CASE_INFER_MULTI_PROMPT, + TEST_CASE_INFER_NO_PROMPT, + TEST_CASE_INFER_SUBCLASS, + TEST_CASE_INFER_STR_PROMPT, + TEST_CASE_INFER_MULTI_STR_PROMPT, + TEST_CASE_INFER_MULTI_NEW_STR_PROMPT, + ] + ) + def test_infer_config(self, override): + # update input_dict with dataset dir + input_dict = override["input_dict"] + input_dict["image"] = os.path.join(self.dataset_dir, "image_0.nii.gz") + override["input_dict"] = input_dict + + bundle_root = override["bundle_root"] + sys.path = [bundle_root] + sys.path + + inferrer = ConfigWorkflow( + workflow_type="infer", + config_file=os.path.join(bundle_root, "configs/inference.json"), + logging_file=os.path.join(bundle_root, "configs/logging.conf"), + meta_file=os.path.join(bundle_root, "configs/metadata.json"), + **override, + ) + # check_properties=False because this bundle does not have some required properties such as dataset_dir + check_workflow(inferrer, check_properties=False) + + @parameterized.expand( + [TEST_CASE_INFER, TEST_CASE_INFER_MULTI_PROMPT, TEST_CASE_INFER_NO_PROMPT, TEST_CASE_INFER_SUBCLASS] + ) + def test_batch_infer_config(self, override): + # update input_dict with dataset dir + params = override.copy() + params.pop("input_dict", None) + params["input_dir"] = self.dataset_dir + params["input_suffix"] = "image_*.nii.gz" + + bundle_root = override["bundle_root"] + sys.path = [bundle_root] + sys.path + config_files = [ + os.path.join(bundle_root, "configs/inference.json"), + os.path.join(bundle_root, "configs/batch_inference.json"), + ] + inferrer = ConfigWorkflow( + workflow_type="infer", + config_file=config_files, + logging_file=os.path.join(bundle_root, "configs/logging.conf"), + meta_file=os.path.join(bundle_root, "configs/metadata.json"), + **params, + ) + # check_properties=False because this bundle does not have some required properties such as dataset_dir + check_workflow(inferrer, check_properties=False) + + @parameterized.expand(TEST_CASE_ERROR_PROMPTS) + def test_error_prompt_infer_config(self, override): + # update input_dict with dataset dir + input_dict = override["input_dict"] + input_dict["image"] = os.path.join(self.dataset_dir, "image_0.nii.gz") + override["input_dict"] = input_dict + + bundle_root = override["bundle_root"] + sys.path = [bundle_root] + sys.path + + inferrer = ConfigWorkflow( + workflow_type="infer", + config_file=os.path.join(bundle_root, "configs/inference.json"), + logging_file=os.path.join(bundle_root, "configs/logging.conf"), + meta_file=os.path.join(bundle_root, "configs/metadata.json"), + **override, + ) + inferrer.initialize() + with self.assertRaises(RuntimeError) as context: + inferrer.run() + runtime_error = context.exception + original_exception = runtime_error.__cause__ + self.assertEqual(str(original_exception), override["error"]) + + +if __name__ == "__main__": + loader = unittest.TestLoader() + loader.sortTestMethodsUsing = test_order + unittest.main(testLoader=loader) diff --git a/ci/unit_tests/test_vista3d_mgpu.py b/ci/unit_tests/test_vista3d_mgpu.py new file mode 100644 index 00000000..7d125e21 --- /dev/null +++ b/ci/unit_tests/test_vista3d_mgpu.py @@ -0,0 +1,179 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import sys +import tempfile +import unittest + +import nibabel as nib +import numpy as np +import torch +from parameterized import parameterized +from utils import export_config_and_run_mgpu_cmd + +TEST_CASE_TRAIN_MGPU = [{"bundle_root": "models/vista3d", "patch_size": [32, 32, 32], "epochs": 2, "val_interval": 1}] + +TEST_CASE_EVAL_MGPU = [{"bundle_root": "models/vista3d", "patch_size": [32, 32, 32]}] + +TEST_CASE_TRAIN_CONTINUAL = [ + {"bundle_root": "models/vista3d", "patch_size": [32, 32, 32], "epochs": 2, "val_interval": 1, "finetune": False} +] + + +def test_order(test_name1, test_name2): + def get_order(name): + if "train_mgpu" in name: + return 1 + if "train_continual" in name: + return 2 + if "eval" in name: + return 3 + return 4 + + return get_order(test_name1) - get_order(test_name2) + + +class TestVista3d(unittest.TestCase): + def setUp(self): + self.dataset_dir = tempfile.mkdtemp() + self.dataset_size = 5 + input_shape = (64, 64, 64) + for s in range(self.dataset_size): + test_image = np.random.randint(low=0, high=2, size=input_shape).astype(np.int8) + test_label = np.random.randint(low=0, high=2, size=input_shape).astype(np.int8) + image_filename = os.path.join(self.dataset_dir, f"image_{s}.nii.gz") + label_filename = os.path.join(self.dataset_dir, f"label_{s}.nii.gz") + nib.save(nib.Nifti1Image(test_image, np.eye(4)), image_filename) + nib.save(nib.Nifti1Image(test_label, np.eye(4)), label_filename) + + def tearDown(self): + shutil.rmtree(self.dataset_dir) + + @parameterized.expand([TEST_CASE_TRAIN_MGPU]) + def test_train_mgpu_config(self, override): + train_size = self.dataset_size // 2 + train_datalist = [ + { + "image": os.path.join(self.dataset_dir, f"image_{i}.nii.gz"), + "label": os.path.join(self.dataset_dir, f"label_{i}.nii.gz"), + } + for i in range(train_size) + ] + val_datalist = [ + { + "image": os.path.join(self.dataset_dir, f"image_{i}.nii.gz"), + "label": os.path.join(self.dataset_dir, f"label_{i}.nii.gz"), + } + for i in range(train_size, self.dataset_size) + ] + override["train_datalist"] = train_datalist + override["val_datalist"] = val_datalist + + bundle_root = override["bundle_root"] + sys.path = [bundle_root] + sys.path + train_file = os.path.join(bundle_root, "configs/train.json") + mgpu_train_file = os.path.join(bundle_root, "configs/multi_gpu_train.json") + output_path = os.path.join(bundle_root, "configs/train_override.json") + n_gpu = torch.cuda.device_count() + export_config_and_run_mgpu_cmd( + config_file=[train_file, mgpu_train_file], + logging_file=os.path.join(bundle_root, "configs/logging.conf"), + meta_file=os.path.join(bundle_root, "configs/metadata.json"), + override_dict=override, + output_path=output_path, + ngpu=n_gpu, + ) + + @parameterized.expand([TEST_CASE_EVAL_MGPU]) + def test_eval_mgpu_config(self, override): + train_size = self.dataset_size // 2 + train_datalist = [ + { + "image": os.path.join(self.dataset_dir, f"image_{i}.nii.gz"), + "label": os.path.join(self.dataset_dir, f"label_{i}.nii.gz"), + } + for i in range(train_size) + ] + val_datalist = [ + { + "image": os.path.join(self.dataset_dir, f"image_{i}.nii.gz"), + "label": os.path.join(self.dataset_dir, f"label_{i}.nii.gz"), + } + for i in range(train_size, self.dataset_size) + ] + override["train_datalist"] = train_datalist + override["val_datalist"] = val_datalist + + bundle_root = override["bundle_root"] + sys.path = [bundle_root] + sys.path + config_files = [ + os.path.join(bundle_root, "configs/train.json"), + os.path.join(bundle_root, "configs/train_continual.json"), + os.path.join(bundle_root, "configs/evaluate.json"), + os.path.join(bundle_root, "configs/mgpu_evaluate.json"), + os.path.join(bundle_root, "configs/data.yaml"), + ] + output_path = os.path.join(bundle_root, "configs/evaluate_override.json") + n_gpu = torch.cuda.device_count() + export_config_and_run_mgpu_cmd( + config_file=config_files, + logging_file=os.path.join(bundle_root, "configs/logging.conf"), + meta_file=os.path.join(bundle_root, "configs/metadata.json"), + override_dict=override, + output_path=output_path, + ngpu=n_gpu, + ) + + @parameterized.expand([TEST_CASE_TRAIN_CONTINUAL]) + def test_train_continual_config(self, override): + train_size = self.dataset_size // 2 + train_datalist = [ + { + "image": os.path.join(self.dataset_dir, f"image_{i}.nii.gz"), + "label": os.path.join(self.dataset_dir, f"label_{i}.nii.gz"), + } + for i in range(train_size) + ] + val_datalist = [ + { + "image": os.path.join(self.dataset_dir, f"image_{i}.nii.gz"), + "label": os.path.join(self.dataset_dir, f"label_{i}.nii.gz"), + } + for i in range(train_size, self.dataset_size) + ] + override["train_datalist"] = train_datalist + override["val_datalist"] = val_datalist + + bundle_root = override["bundle_root"] + sys.path = [bundle_root] + sys.path + config_files = [ + os.path.join(bundle_root, "configs/train.json"), + os.path.join(bundle_root, "configs/train_continual.json"), + os.path.join(bundle_root, "configs/multi_gpu_train.json"), + ] + output_path = os.path.join(bundle_root, "configs/train_continual_override.json") + n_gpu = torch.cuda.device_count() + export_config_and_run_mgpu_cmd( + config_file=config_files, + logging_file=os.path.join(bundle_root, "configs/logging.conf"), + meta_file=os.path.join(bundle_root, "configs/metadata.json"), + override_dict=override, + output_path=output_path, + ngpu=n_gpu, + ) + + +if __name__ == "__main__": + loader = unittest.TestLoader() + loader.sortTestMethodsUsing = test_order + unittest.main(testLoader=loader) diff --git a/ci/unit_tests/utils.py b/ci/unit_tests/utils.py index ebd1137a..7fda766d 100644 --- a/ci/unit_tests/utils.py +++ b/ci/unit_tests/utils.py @@ -20,7 +20,6 @@ def export_overrided_config(config_file, override_dict, output_path): parser = ConfigParser() parser.read_config(config_file) parser.update(pairs=override_dict) - ConfigParser.export_config_file(parser.config, output_path, indent=4) diff --git a/models/vista3d/LICENSE b/models/vista3d/LICENSE new file mode 100644 index 00000000..261eeb9e --- /dev/null +++ b/models/vista3d/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/models/vista3d/configs/batch_inference.json b/models/vista3d/configs/batch_inference.json new file mode 100644 index 00000000..605e66d6 --- /dev/null +++ b/models/vista3d/configs/batch_inference.json @@ -0,0 +1,7 @@ +{ + "input_dir": "@bundle_root", + "input_suffix": "*.nii.gz", + "input_list": "$sorted(glob.glob(os.path.join(@input_dir, @input_suffix)))", + "input_dicts": "$[{'image': x, 'label_prompt': @everything_labels} for x in @input_list]", + "dataset#data": "@input_dicts" +} diff --git a/models/vista3d/configs/data.yaml b/models/vista3d/configs/data.yaml new file mode 100644 index 00000000..c379f363 --- /dev/null +++ b/models/vista3d/configs/data.yaml @@ -0,0 +1,39 @@ +# this file shows specific settings for evaluation on each dataset +validate#postprocessing#transforms#0#_disabled_: false +validate#handlers#2#_disabled_: true +validate#evaluator#key_val_metric: null +eval_folder: "/eval/" +ckpt_path: "$@bundle_root + '/models/model.pt'" +label_set: "$list(set([i+1 for i in range(117)]) - set([22, 23, 15, 25, 19, 2, 26, 27, 28, 29, 117]))" +val_label_set: "$list(range(118))" +label_mappings: "${}" +# label_mappings: +# default: +# - [1, 25] # lung tumor +# - [2, 26] # pancreatic tumor +# - [1, 27] # hepatic vessel +# - [2, 28] # hepatic tumor +# - [1, 29] # colon tumor +# - [1, 117] # bone tumor +# - [2, 117] # bone tumor +# - [10, 1] # liver +# - [12, 3] # spleen +# - [13, 4] # pancreas +# - [15, 30] # left lung upper lobe +# - [16, 31] # left lung lower lobe +# - [17, 32] # right lung upper lobe +# - [18, 33] # right lung middle lobe +# - [19, 34] # right lung lower lobe +# - [20, 5] # right kidney +# - [21, 14] # left kidney +# - [22, 71] # left rib 1 +# - [23, 72] # left rib 2 +# - [24, 73] +# - [25, 74] +# - [26, 75] +# - [27, 76] +# - [28, 77] +# - [29, 79] +# - [30, 80] +# - [31, 81] +# - [32, 82] # left rib 12 diff --git a/models/vista3d/configs/evaluate.json b/models/vista3d/configs/evaluate.json new file mode 100644 index 00000000..176cfd08 --- /dev/null +++ b/models/vista3d/configs/evaluate.json @@ -0,0 +1,177 @@ +{ + "data_list_file_path": "$@bundle_root + '/msd_task09_spleen_folds.json'", + "dataset_dir": "/data/Task09_Spleen", + "output_dir": "$@bundle_root + '/eval'", + "ckpt_path": "$@bundle_root + '/models/model.pt'", + "val_dataset_cache_rate": 0.0, + "patch_size": [ + 128, + 128, + 128 + ], + "resample_to_spacing": [ + 1.5, + 1.5, + 1.5 + ], + "cache_cls_idx#activate": false, + "label_mappings": { + "default": [ + [ + 1, + 25 + ] + ] + }, + "label_set": "$[0] + list(x[1] for x in @label_mappings#default)", + "val_label_set": "$[0] + list(x[0] for x in @label_mappings#default)", + "num_classes": 118, + "output_classes": 118, + "validate#preprocessing": { + "_target_": "Compose", + "transforms": [ + { + "_target_": "LoadImaged", + "keys": [ + "image", + "label" + ], + "image_only": true, + "ensure_channel_first": true + }, + { + "_target_": "CropForegroundd", + "keys": [ + "image" + ], + "source_key": "image", + "margin": 10, + "allow_smaller": true, + "start_coord_key": null, + "end_coord_key": null + }, + { + "_target_": "ScaleIntensityRanged", + "keys": "image", + "a_min": -963.8247715525971, + "a_max": 1053.678477684517, + "b_min": 0.0, + "b_max": 1.0, + "clip": true + }, + { + "_target_": "Orientationd", + "keys": [ + "image" + ], + "axcodes": "RAS" + }, + { + "_target_": "Spacingd", + "keys": [ + "image" + ], + "pixdim": "$@resample_to_spacing", + "mode": [ + "bilinear" + ] + }, + { + "_target_": "CastToTyped", + "keys": [ + "image", + "label" + ], + "dtype": [ + "$torch.float32", + "$torch.uint8" + ] + }, + { + "_target_": "scripts.monai_trans_utils.RelabelD", + "keys": "label", + "label_mappings": "@label_mappings", + "dtype": "$torch.uint8" + } + ] + }, + "validate#postprocessing": { + "_target_": "Compose", + "transforms": [ + { + "_target_": "EnsureTyped", + "keys": [ + "pred", + "label" + ], + "device": "cpu", + "_disabled_": true + }, + { + "_target_": "Activationsd", + "keys": "pred", + "sigmoid": true + }, + { + "_target_": "scripts.monai_trans_utils.VistaPostTransform", + "keys": "pred" + }, + { + "_target_": "Invertd", + "keys": "pred", + "transform": "@validate#preprocessing", + "orig_keys": "image", + "nearest_interp": true, + "to_tensor": true + }, + { + "_target_": "Lambdad", + "func": "$lambda x: torch.nan_to_num(x, nan=255)", + "keys": "pred" + }, + { + "_target_": "SaveImaged", + "keys": "pred", + "resample": false, + "output_dir": "@output_dir" + } + ] + }, + "validate#handlers": [ + { + "_target_": "CheckpointLoader", + "load_path": "@ckpt_path", + "load_dict": { + "model": "@network" + } + }, + { + "_target_": "StatsHandler", + "iteration_log": true, + "name": "validate_stats" + }, + { + "_target_": "MetricsSaver", + "_disabled_": false, + "save_dir": "@output_dir", + "metrics": [ + "val_mean_dice" + ], + "batch_transform": "$lambda x: [xx['image'].meta for xx in x]", + "metric_details": "*", + "summary_ops": "*" + } + ], + "validate#dataset": { + "_target_": "CacheDataset", + "data": "$list(@val_datalist)+list(@train_datalist)", + "transform": "@validate#preprocessing", + "cache_rate": "@val_dataset_cache_rate", + "hash_as_key": true, + "num_workers": "@num_cache_workers", + "progress": "@show_cache_progress" + }, + "run": [ + "$@validate#evaluator.run()" + ] +} diff --git a/models/vista3d/configs/inference.json b/models/vista3d/configs/inference.json new file mode 100644 index 00000000..3b5ebe59 --- /dev/null +++ b/models/vista3d/configs/inference.json @@ -0,0 +1,200 @@ +{ + "imports": [ + "$import glob", + "$import os", + "$import scripts", + "$import numpy as np" + ], + "bundle_root": "./", + "image_key": "image", + "output_dir": "$@bundle_root + '/eval'", + "output_ext": ".nii.gz", + "output_dtype": "$np.float32", + "output_postfix": "trans", + "separate_folder": true, + "input_dict": "${'image': '/data/Task09_Spleen/imagesTr/spleen_10.nii.gz', 'label_prompt': [3]}", + "everything_labels": "$list(set([i+1 for i in range(132)]) - set([2,16,18,20,21,23,24,25,26,27,128,129,130,131,132]))", + "subclass": { + "2": [ + 14, + 5 + ], + "20": [ + 28, + 29, + 30, + 31, + 32 + ], + "21": "$list(range(33, 57)) + list(range(63, 98)) + [114, 120, 122]" + }, + "input_channels": 1, + "resample_spacing": [ + 1.5, + 1.5, + 1.5 + ], + "sw_batch_size": 1, + "patch_size": [ + 128, + 128, + 128 + ], + "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", + "use_cfp": false, + "use_point_window": true, + "network_def": "$scripts.vista3d.vista_model_registry['vista3d_segresnet_d'](in_channels=@input_channels, image_size=@patch_size)", + "network": "$@network_def.to(@device)", + "preprocessing_transforms": [ + { + "_target_": "LoadImaged", + "keys": "@image_key", + "image_only": true + }, + { + "_target_": "EnsureChannelFirstd", + "keys": "@image_key" + }, + { + "_target_": "EnsureTyped", + "keys": "@image_key", + "device": "@device", + "track_meta": true + }, + { + "_target_": "Spacingd", + "keys": "@image_key", + "pixdim": "@resample_spacing", + "mode": "bilinear" + }, + { + "_target_": "CropForegroundd", + "keys": "@image_key", + "allow_smaller": true, + "margin": 10, + "source_key": "@image_key" + }, + { + "_target_": "scripts.monai_trans_utils.VistaPreTransform", + "keys": "@image_key", + "subclass": "@subclass", + "bundle_root": "@bundle_root" + }, + { + "_target_": "ScaleIntensityRanged", + "keys": "@image_key", + "a_min": -963.8247715525971, + "a_max": 1053.678477684517, + "b_min": 0, + "b_max": 1, + "clip": true + }, + { + "_target_": "Orientationd", + "keys": "@image_key", + "axcodes": "RAS" + }, + { + "_target_": "CastToTyped", + "keys": "@image_key", + "dtype": "$torch.float32" + } + ], + "preprocessing": { + "_target_": "Compose", + "transforms": "$@preprocessing_transforms " + }, + "dataset": { + "_target_": "Dataset", + "data": "$[@input_dict]", + "transform": "@preprocessing" + }, + "dataloader": { + "_target_": "ThreadDataLoader", + "dataset": "@dataset", + "batch_size": 1, + "shuffle": false, + "num_workers": 0 + }, + "inferer": { + "_target_": "scripts.inferer.Vista3dInferer", + "roi_size": "@patch_size", + "overlap": 0.5, + "sw_batch_size": "@sw_batch_size", + "use_cfp": "@use_cfp", + "use_point_window": "@use_point_window" + }, + "postprocessing": { + "_target_": "Compose", + "transforms": [ + { + "_target_": "ToDeviced", + "keys": "pred", + "device": "cpu", + "_disabled_": true + }, + { + "_target_": "scripts.monai_trans_utils.VistaPostTransform", + "keys": "pred" + }, + { + "_target_": "Invertd", + "keys": "pred", + "transform": "@preprocessing", + "orig_keys": "@image_key", + "nearest_interp": true, + "to_tensor": true + }, + { + "_target_": "Lambdad", + "func": "$lambda x: torch.nan_to_num(x, nan=255)", + "keys": "pred" + }, + { + "_target_": "SaveImaged", + "keys": "pred", + "resample": false, + "output_dir": "@output_dir", + "output_ext": "@output_ext", + "output_dtype": "@output_dtype", + "output_postfix": "@output_postfix", + "separate_folder": "@separate_folder" + } + ] + }, + "handlers": [ + { + "_target_": "StatsHandler", + "iteration_log": false + } + ], + "checkpointloader": { + "_target_": "CheckpointLoader", + "load_path": "$@bundle_root + '/models/model.pt'", + "load_dict": { + "model": "@network" + } + }, + "evaluator": { + "_target_": "scripts.evaluator.Vista3dEvaluator", + "device": "@device", + "val_data_loader": "@dataloader", + "network": "@network", + "inferer": "@inferer", + "postprocessing": "@postprocessing", + "val_handlers": "@handlers", + "amp": true, + "hyper_kwargs": { + "use_cfp": "@use_cfp", + "user_prompt": true, + "everything_labels": "@everything_labels" + } + }, + "initialize": [ + "$monai.utils.set_determinism(seed=123)", + "$@checkpointloader(@evaluator)" + ], + "run": [ + "$@evaluator.run()" + ] +} diff --git a/models/vista3d/configs/logging.conf b/models/vista3d/configs/logging.conf new file mode 100644 index 00000000..ad1b962b --- /dev/null +++ b/models/vista3d/configs/logging.conf @@ -0,0 +1,27 @@ +[loggers] +keys=root + +[handlers] +keys=consoleHandler,fileHandler + +[formatters] +keys=fullFormatter + +[logger_root] +level=INFO +handlers=consoleHandler,fileHandler + +[handler_consoleHandler] +class=StreamHandler +level=INFO +formatter=fullFormatter +args=(sys.stdout,) + +[handler_fileHandler] +class=FileHandler +level=INFO +formatter=fullFormatter +args=('training.log',) + +[formatter_fullFormatter] +format=%(asctime)s - %(name)s - %(levelname)s - %(message)s diff --git a/models/vista3d/configs/metadata.json b/models/vista3d/configs/metadata.json new file mode 100644 index 00000000..377bbfa1 --- /dev/null +++ b/models/vista3d/configs/metadata.json @@ -0,0 +1,210 @@ +{ + "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json", + "version": "0.4.1", + "changelog": { + "0.4.1": "initial OSS version" + }, + "monai_version": "1.3.1", + "pytorch_version": "2.2.2", + "numpy_version": "1.24.4", + "matplotlib": "3.8.3", + "einops": "0.7.0", + "required_packages_version": { + "scikit-image": "0.22.0", + "nibabel": "5.2.1", + "pytorch-ignite": "0.4.11", + "cucim": "23.08.00" + }, + "supported_apps": { + "vista3d-nim": "" + }, + "name": "VISTA3D", + "task": "Decathlon Spleen segmentation", + "description": "VISTA3D bundle", + "authors": "MONAI team", + "copyright": "Copyright (c) MONAI Consortium", + "data_source": "Task09_Spleen.tar from http://medicaldecathlon.com/", + "data_type": "nibabel", + "image_classes": "1 channel data, intensity scaled to [0, 1]", + "label_classes": "single channel data", + "pred_classes": "2 channels OneHot data", + "intended_use": "This is an example, not to be used for diagnostic purposes", + "references": [], + "network_data_format": { + "inputs": { + "image": { + "type": "image", + "format": "hounsfield", + "modality": "CT", + "num_channels": 1, + "spatial_shape": [ + 128, + 128, + 128 + ], + "dtype": "float32", + "value_range": [ + 0, + 1 + ], + "is_patch_data": true, + "channel_def": { + "0": "image" + } + } + }, + "outputs": { + "pred": { + "type": "image", + "format": "segmentation", + "num_channels": 1, + "spatial_shape": [ + 128, + 128, + 128 + ], + "dtype": "float32", + "value_range": [ + 0, + 1 + ], + "is_patch_data": true, + "channel_def": { + "0": "background", + "1": "liver", + "2": "kidney", + "3": "spleen", + "4": "pancreas", + "5": "right kidney", + "6": "aorta", + "7": "inferior vena cava", + "8": "right adrenal gland", + "9": "left adrenal gland", + "10": "gallbladder", + "11": "esophagus", + "12": "stomach", + "13": "duodenum", + "14": "left kidney", + "15": "bladder", + "16": "prostate or uterus", + "17": "portal vein and splenic vein", + "18": "rectum", + "19": "small bowel", + "20": "lung", + "21": "bone", + "22": "brain", + "23": "lung tumor", + "24": "pancreatic tumor", + "25": "hepatic vessel", + "26": "hepatic tumor", + "27": "colon cancer primaries", + "28": "left lung upper lobe", + "29": "left lung lower lobe", + "30": "right lung upper lobe", + "31": "right lung middle lobe", + "32": "right lung lower lobe", + "33": "vertebrae L5", + "34": "vertebrae L4", + "35": "vertebrae L3", + "36": "vertebrae L2", + "37": "vertebrae L1", + "38": "vertebrae T12", + "39": "vertebrae T11", + "40": "vertebrae T10", + "41": "vertebrae T9", + "42": "vertebrae T8", + "43": "vertebrae T7", + "44": "vertebrae T6", + "45": "vertebrae T5", + "46": "vertebrae T4", + "47": "vertebrae T3", + "48": "vertebrae T2", + "49": "vertebrae T1", + "50": "vertebrae C7", + "51": "vertebrae C6", + "52": "vertebrae C5", + "53": "vertebrae C4", + "54": "vertebrae C3", + "55": "vertebrae C2", + "56": "vertebrae C1", + "57": "trachea", + "58": "left iliac artery", + "59": "right iliac artery", + "60": "left iliac vena", + "61": "right iliac vena", + "62": "colon", + "63": "left rib 1", + "64": "left rib 2", + "65": "left rib 3", + "66": "left rib 4", + "67": "left rib 5", + "68": "left rib 6", + "69": "left rib 7", + "70": "left rib 8", + "71": "left rib 9", + "72": "left rib 10", + "73": "left rib 11", + "74": "left rib 12", + "75": "right rib 1", + "76": "right rib 2", + "77": "right rib 3", + "78": "right rib 4", + "79": "right rib 5", + "80": "right rib 6", + "81": "right rib 7", + "82": "right rib 8", + "83": "right rib 9", + "84": "right rib 10", + "85": "right rib 11", + "86": "right rib 12", + "87": "left humerus", + "88": "right humerus", + "89": "left scapula", + "90": "right scapula", + "91": "left clavicula", + "92": "right clavicula", + "93": "left femur", + "94": "right femur", + "95": "left hip", + "96": "right hip", + "97": "sacrum", + "98": "left gluteus maximus", + "99": "right gluteus maximus", + "100": "left gluteus medius", + "101": "right gluteus medius", + "102": "left gluteus minimus", + "103": "right gluteus minimus", + "104": "left autochthon", + "105": "right autochthon", + "106": "left iliopsoas", + "107": "right iliopsoas", + "108": "left atrial appendage", + "109": "brachiocephalic trunk", + "110": "left brachiocephalic vein", + "111": "right brachiocephalic vein", + "112": "left common carotid artery", + "113": "right common carotid artery", + "114": "costal cartilages", + "115": "heart", + "116": "left kidney cyst", + "117": "right kidney cyst", + "118": "prostate", + "119": "pulmonary vein", + "120": "skull", + "121": "spinal cord", + "122": "sternum", + "123": "left subclavian artery", + "124": "right subclavian artery", + "125": "superior vena cava", + "126": "thyroid gland", + "127": "vertebrae S1", + "128": "bone lesion", + "129": "kidney mass", + "130": "liver tumor", + "131": "vertebrae L6", + "132": "airway" + } + } + } + } +} diff --git a/models/vista3d/configs/mgpu_evaluate.json b/models/vista3d/configs/mgpu_evaluate.json new file mode 100644 index 00000000..f220c6f2 --- /dev/null +++ b/models/vista3d/configs/mgpu_evaluate.json @@ -0,0 +1,29 @@ +{ + "device": "$torch.device('cuda:' + os.environ['LOCAL_RANK'])", + "network": { + "_target_": "torch.nn.parallel.DistributedDataParallel", + "module": "$@network_def.to(@device)", + "device_ids": [ + "@device" + ] + }, + "validate#sampler": { + "_target_": "DistributedSampler", + "dataset": "@validate#dataset", + "even_divisible": false, + "shuffle": false + }, + "validate#dataloader#sampler": "@validate#sampler", + "validate#handlers#1#_disabled_": "$dist.get_rank() > 0", + "initialize": [ + "$import torch.distributed as dist", + "$dist.is_initialized() or dist.init_process_group(backend='nccl')", + "$torch.cuda.set_device(@device)" + ], + "run": [ + "$@validate#evaluator.run()" + ], + "finalize": [ + "$dist.is_initialized() and dist.destroy_process_group()" + ] +} diff --git a/models/vista3d/configs/multi_gpu_train.json b/models/vista3d/configs/multi_gpu_train.json new file mode 100644 index 00000000..9fcc961e --- /dev/null +++ b/models/vista3d/configs/multi_gpu_train.json @@ -0,0 +1,42 @@ +{ + "device": "$torch.device('cuda:' + os.environ['LOCAL_RANK'])", + "use_tensorboard": "$dist.get_rank() == 0", + "network": { + "_target_": "torch.nn.parallel.DistributedDataParallel", + "module": "$@network_def.to(@device)", + "find_unused_parameters": true, + "device_ids": [ + "@device" + ] + }, + "train#sampler": { + "_target_": "DistributedSampler", + "dataset": "@train#dataset", + "even_divisible": true, + "shuffle": true + }, + "train#dataloader#sampler": "@train#sampler", + "train#dataloader#shuffle": false, + "train#trainer#train_handlers": "$@train#handlers[: -1 if dist.get_rank() > 0 else None]", + "validate#sampler": { + "_target_": "DistributedSampler", + "dataset": "@validate#dataset", + "even_divisible": false, + "shuffle": false + }, + "validate#dataloader#sampler": "@validate#sampler", + "validate#evaluator#val_handlers": "$@validate#handlers[: -2 if dist.get_rank() > 0 else None]", + "initialize": [ + "$import torch.distributed as dist", + "$dist.is_initialized() or dist.init_process_group(backend='nccl')", + "$torch.cuda.set_device(@device)", + "$monai.utils.set_determinism(seed=123)" + ], + "run": [ + "$@validate#handlers#0.set_trainer(trainer=@train#trainer) if @early_stop else None", + "$@train#trainer.run()" + ], + "finalize": [ + "$dist.is_initialized() and dist.destroy_process_group()" + ] +} diff --git a/models/vista3d/configs/train.json b/models/vista3d/configs/train.json new file mode 100644 index 00000000..5996d8f5 --- /dev/null +++ b/models/vista3d/configs/train.json @@ -0,0 +1,394 @@ +{ + "imports": [ + "$import glob", + "$import os", + "$import scripts", + "$import ignite" + ], + "bundle_root": ".", + "ckpt_dir": "$@bundle_root + '/models'", + "output_dir": "$@bundle_root + '/eval'", + "data_list_file_path": "$@bundle_root + '/msd_task09_spleen_folds.json'", + "dataset_dir": "/data/Task09_Spleen", + "use_tensorboard": true, + "finetune": false, + "finetune_model_path": "$@bundle_root + '/models/model.pt'", + "early_stop": false, + "fold": 0, + "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", + "epochs": 100, + "val_interval": 1, + "val_at_start": false, + "sw_overlap": 0.625, + "learning_rate": 0.0001, + "num_patches_per_image": 1, + "input_channels": 1, + "output_classes": 2, + "max_point": 5, + "max_prompt": null, + "max_backprompt": null, + "max_foreprompt": null, + "drop_label_prob": 0.5, + "drop_point_prob": 0.5, + "exclude_background": true, + "use_cfp": true, + "label_set": null, + "val_label_set": "@label_set", + "amp": true, + "train_datalist": "$monai.auto3dseg.utils.datafold_read(datalist=@data_list_file_path, basedir=@dataset_dir, fold=@fold)[0]", + "val_datalist": "$monai.auto3dseg.utils.datafold_read(datalist=@data_list_file_path, basedir=@dataset_dir, fold=@fold)[1]", + "patch_size": [ + 128, + 128, + 128 + ], + "patch_size_valid": "$@patch_size", + "network_def": "$scripts.vista3d.vista_model_registry['vista3d_segresnet_d'](in_channels=@input_channels, image_size=@patch_size)", + "network": "$@network_def.to(@device)", + "loss": { + "_target_": "DiceCELoss", + "include_background": true, + "sigmoid": true, + "smooth_dr": 1e-05, + "smooth_nr": 0, + "squared_pred": true, + "to_onehot_y": false + }, + "optimizer": { + "_target_": "torch.optim.AdamW", + "params": "$@network.parameters()", + "lr": "@learning_rate", + "weight_decay": 1e-05 + }, + "lr_schedule": { + "activate": true, + "lr_scheduler": { + "_target_": "monai.optimizers.WarmupCosineSchedule", + "optimizer": "@optimizer", + "t_total": "$@epochs", + "warmup_steps": 3, + "warmup_multiplier": 0.1 + } + }, + "resample_to_spacing": [ + 1.5, + 1.5, + 1.5 + ], + "train": { + "deterministic_transforms": [ + { + "_target_": "LoadImaged", + "keys": [ + "image", + "label" + ], + "image_only": true, + "ensure_channel_first": true + }, + { + "_target_": "CropForegroundd", + "keys": [ + "image", + "label" + ], + "source_key": "image", + "margin": 10, + "allow_smaller": true, + "start_coord_key": null, + "end_coord_key": null + }, + { + "_target_": "ScaleIntensityRanged", + "keys": "image", + "a_min": -963.8247715525971, + "a_max": 1053.678477684517, + "b_min": 0.0, + "b_max": 1.0, + "clip": true + }, + { + "_target_": "Orientationd", + "keys": [ + "image", + "label" + ], + "axcodes": "RAS" + }, + { + "_target_": "Spacingd", + "keys": [ + "image", + "label" + ], + "pixdim": "$@resample_to_spacing", + "mode": [ + "bilinear", + "nearest" + ] + }, + { + "_target_": "CastToTyped", + "keys": [ + "image", + "label" + ], + "dtype": [ + "$torch.float32", + "$torch.uint8" + ] + }, + { + "_target_": "EnsureTyped", + "keys": [ + "image", + "label" + ], + "track_meta": true + }, + { + "_target_": "SpatialPadd", + "keys": [ + "image", + "label" + ], + "spatial_size": "@patch_size", + "mode": [ + "constant", + "constant" + ] + } + ], + "random_transforms": [ + { + "_target_": "RandCropByLabelClassesd", + "keys": [ + "image", + "label" + ], + "label_key": "label", + "num_classes": "@output_classes", + "spatial_size": "@patch_size", + "num_samples": "@num_patches_per_image", + "warn": false + }, + { + "_target_": "ResizeWithPadOrCropd", + "keys": [ + "image", + "label" + ], + "spatial_size": "@patch_size" + }, + { + "_target_": "RandScaleIntensityd", + "keys": "image", + "prob": 0.1, + "factors": 0.1 + }, + { + "_target_": "RandShiftIntensityd", + "keys": "image", + "prob": 0.1, + "offsets": 0.1 + } + ], + "inferer": { + "_target_": "SimpleInferer" + }, + "preprocessing": { + "_target_": "Compose", + "transforms": "$@train#deterministic_transforms + @train#random_transforms" + }, + "dataset": { + "_target_": "Dataset", + "data": "@train_datalist", + "transform": "@train#preprocessing" + }, + "dataloader": { + "_target_": "DataLoader", + "dataset": "@train#dataset", + "batch_size": 1, + "shuffle": true, + "num_workers": 4, + "pin_memory": true, + "persistent_workers": true + }, + "handlers": [ + { + "_target_": "CheckpointLoader", + "_disabled_": "$not @finetune", + "load_path": "@finetune_model_path", + "load_dict": { + "model": "@network" + } + }, + { + "_target_": "LrScheduleHandler", + "_disabled_": "$not @lr_schedule#activate", + "lr_scheduler": "@lr_schedule#lr_scheduler", + "print_lr": true + }, + { + "_target_": "ValidationHandler", + "validator": "@validate#evaluator", + "epoch_level": true, + "exec_at_start": "@val_at_start", + "interval": "@val_interval" + }, + { + "_target_": "TensorBoardStatsHandler", + "_disabled_": "$not @use_tensorboard", + "log_dir": "@output_dir", + "tag_name": "train_loss", + "output_transform": "$monai.handlers.from_engine(['loss'], first=True)" + }, + { + "_target_": "StatsHandler", + "tag_name": "train_loss", + "name": "StatsHandler", + "output_transform": "$monai.handlers.from_engine(['loss'], first=True)" + } + ], + "key_metric": { + "train_accuracy": { + "_target_": "ignite.metrics.Accuracy", + "output_transform": "$monai.handlers.from_engine(['pred', 'label'])" + } + }, + "trainer": { + "_target_": "scripts.trainer.Vista3dTrainer", + "max_epochs": "@epochs", + "device": "@device", + "train_data_loader": "@train#dataloader", + "network": "@network", + "loss_function": "@loss", + "optimizer": "@optimizer", + "inferer": "@train#inferer", + "key_train_metric": null, + "train_handlers": "@train#handlers", + "amp": "@amp", + "hyper_kwargs": { + "output_classes": "@output_classes", + "max_point": "@max_point", + "max_prompt": "@max_prompt", + "max_backprompt": "@max_backprompt", + "max_foreprompt": "@max_foreprompt", + "drop_label_prob": "@drop_label_prob", + "drop_point_prob": "@drop_point_prob", + "exclude_background": "@exclude_background", + "use_cfp": "@use_cfp", + "label_set": "@label_set", + "patch_size": "@patch_size", + "user_prompt": false + } + } + }, + "validate": { + "preprocessing": { + "_target_": "Compose", + "transforms": "$@train#deterministic_transforms" + }, + "postprocessing": { + "_target_": "Compose", + "transforms": [ + { + "_target_": "AsDiscreted", + "keys": "pred", + "threshold": 0.0 + } + ] + }, + "dataset": { + "_target_": "Dataset", + "data": "$@val_datalist", + "transform": "@validate#preprocessing" + }, + "dataloader": { + "_target_": "DataLoader", + "dataset": "@validate#dataset", + "batch_size": 1, + "shuffle": false, + "num_workers": 4 + }, + "inferer": { + "_target_": "scripts.inferer.Vista3dInferer", + "roi_size": "@patch_size_valid", + "overlap": "@sw_overlap", + "use_cfp": "@use_cfp" + }, + "handlers": [ + { + "_target_": "EarlyStopHandler", + "_disabled_": "$not @early_stop", + "trainer": null, + "patience": 2, + "score_function": "$scripts.score_function", + "min_delta": 0.01 + }, + { + "_target_": "TensorBoardStatsHandler", + "_disabled_": "$not @use_tensorboard", + "log_dir": "@output_dir", + "iteration_log": false + }, + { + "_target_": "StatsHandler", + "iteration_log": false, + "name": "StatsHandler" + }, + { + "_target_": "CheckpointSaver", + "save_dir": "@ckpt_dir", + "save_dict": { + "model": "@network" + }, + "save_key_metric": true, + "key_metric_filename": "model.pt" + } + ], + "key_metric": { + "val_mean_dice": { + "_target_": "MeanDice", + "include_background": false, + "output_transform": "$monai.handlers.from_engine(['pred', 'label'])", + "num_classes": "@output_classes" + } + }, + "additional_metrics": { + "val_accuracy": { + "_target_": "ignite.metrics.Accuracy", + "output_transform": "$monai.handlers.from_engine(['pred', 'label'])" + } + }, + "evaluator": { + "_target_": "scripts.evaluator.Vista3dEvaluator", + "device": "@device", + "val_data_loader": "@validate#dataloader", + "network": "@network", + "inferer": "@validate#inferer", + "postprocessing": "@validate#postprocessing", + "key_val_metric": "@validate#key_metric", + "additional_metrics": null, + "val_handlers": "@validate#handlers", + "amp": true, + "hyper_kwargs": { + "output_classes": "@output_classes", + "drop_label_prob": "@drop_label_prob", + "drop_point_prob": "@drop_point_prob", + "exclude_background": "@exclude_background", + "use_cfp": "@use_cfp", + "label_set": "@label_set", + "user_prompt": false + } + } + }, + "initialize": [ + "$monai.utils.set_determinism(seed=0)" + ], + "run": [ + "$@validate#handlers#0.set_trainer(trainer=@train#trainer) if @early_stop else None", + "$@train#trainer.add_event_handler(ignite.engine.Events.ITERATION_COMPLETED, ignite.handlers.TerminateOnNan())", + "$@train#trainer.run()" + ] +} diff --git a/models/vista3d/configs/train_continual.json b/models/vista3d/configs/train_continual.json new file mode 100644 index 00000000..887c6cfe --- /dev/null +++ b/models/vista3d/configs/train_continual.json @@ -0,0 +1,109 @@ +{ + "data_list_file_path": "$@bundle_root + '/msd_task09_spleen_folds.json'", + "dataset_dir": "/data/Task09_Spleen", + "finetune": true, + "val_at_start": true, + "finetune_model_path": "$@bundle_root + '/models/model.pt'", + "n_train_samples": 10, + "n_val_samples": 10, + "val_interval": 40, + "learning_rate": 0.0001, + "lr_schedule#activate": false, + "loss#smooth_dr": 0.01, + "loss#smooth_nr": 0.0001, + "train_dataset_cache_rate": 1.0, + "val_dataset_cache_rate": 1.0, + "num_cache_workers": 4, + "label_mappings": { + "default": [ + [ + 1, + 2 + ], + [ + 2, + 254 + ] + ] + }, + "patch_size": [ + 160, + 160, + 160 + ], + "label_set": "$[0] + list(x[1] for x in @label_mappings#default)", + "val_label_set": "$[0] + list(x[0] for x in @label_mappings#default)", + "num_classes": 255, + "output_classes": "$len(@label_set)", + "optimizer": { + "_target_": "Novograd", + "lr": "@learning_rate", + "params": "$@network.parameters()" + }, + "show_cache_progress": true, + "resample_to_spacing": [ + 1.5, + 1.5, + 1.5 + ], + "cache_cls_idx": { + "activate": true, + "indices_key": "$'label_cls_indices' if @cache_cls_idx#activate else None" + }, + "train#random_transforms": [ + { + "_target_": "ClassesToIndicesd", + "_disabled_": "$not @cache_cls_idx#activate", + "keys": "label", + "num_classes": "@num_classes", + "indices_postfix": "_cls_indices", + "max_samples_per_class": "$int(10 * @epochs)" + }, + { + "_target_": "RandCropByLabelClassesd", + "keys": [ + "image", + "label" + ], + "label_key": "label", + "num_classes": "@num_classes", + "spatial_size": "@patch_size", + "num_samples": "@num_patches_per_image", + "ratios": "$tuple(float(i>=0) for i in range(@num_classes))", + "indices_key": "$@cache_cls_idx#indices_key", + "warn": false + }, + { + "_target_": "scripts.monai_trans_utils.RelabelD", + "keys": "label", + "label_mappings": "@label_mappings", + "dtype": "$torch.uint8" + } + ], + "train#handlers#0#strict": false, + "train#dataset": { + "_target_": "CacheDataset", + "data": "$@train_datalist[:@n_train_samples]", + "transform": "@train#preprocessing", + "cache_rate": "@train_dataset_cache_rate", + "hash_as_key": true, + "num_workers": "@num_cache_workers", + "progress": "@show_cache_progress" + }, + "validate#dataset": { + "_target_": "CacheDataset", + "data": "$@val_datalist[:@n_val_samples]", + "transform": "@validate#preprocessing", + "cache_rate": "@val_dataset_cache_rate", + "hash_as_key": true, + "num_workers": "@num_cache_workers", + "progress": "@show_cache_progress" + }, + "validate#preprocessing#transforms": "$@train#deterministic_transforms + [@valid_remap]", + "valid_remap": { + "_target_": "scripts.monai_trans_utils.RelabelD", + "keys": "label", + "label_mappings": "${'default': [[c, i] for i, c in enumerate(@val_label_set)]}", + "dtype": "$torch.uint8" + } +} diff --git a/models/vista3d/docs/README.md b/models/vista3d/docs/README.md new file mode 100644 index 00000000..92c12220 --- /dev/null +++ b/models/vista3d/docs/README.md @@ -0,0 +1,215 @@ +# Model Overview +Vista3D model train/inference pipeline + +## Training configuration +The training was performed with the following: +- GPU: at least 16GB GPU memory +- Actual Model Input: 128 x 128 x 128 +- AMP: True +- Optimizer: Adam +- Learning Rate: 1e-2 +- Loss: BCE loss and L1 loss + +## Data +Note that VISTA3D is trained from a huge collection of datasets and cannot be simply reproduced in this bundle. + +The spleen Task from the Medical Segmentation Decathalon is selected as an example to show how to do train, continuous learning and evaluate. Users can find more details on the datasets at http://medicaldecathlon.com/. + +To train with other datasets, users need to provide a json data split for training and continuous learning (`msd_task09_spleen_folds.json` is an example for reference). The data split should meet the following format with a 5-fold split ('testing' labels are optional): +``` +{ + "training": [ + {"image": "img0001.nii.gz", "label": "label0001.nii.gz", "fold": 0}, + {"image": "img0002.nii.gz", "label": "label0002.nii.gz", "fold": 2}, + ... + ], + "testing": [ + {"image": "img0003.nii.gz", "label": "label0003.nii.gz"}, + {"image": "img0004.nii.gz", "label": "label0004.nii.gz"}, + ... + ] +} +``` + +### Input +1 channel +- List of 3D CT patches + +### Output +In Training Mode: Training loss + +In Evaluation Mode: Segmentation + +#### Training Loss + +#### Validation Accuracy + +## MONAI Bundle Commands +In addition to the Pythonic APIs, a few command line interfaces (CLI) are provided to interact with the bundle. The CLI supports flexible use cases, such as overriding configs at runtime and predefining arguments in a file. + +For more details usage instructions, visit the [MONAI Bundle Configuration Page](https://docs.monai.io/en/latest/config_syntax.html). + + +#### Execute training: + +``` +python -m monai.bundle run --config_file configs/train.json +``` + +Please note that if the default dataset path is not modified with the actual path in the bundle config files, you can also override it by using `--dataset_dir`: + +``` +python -m monai.bundle run --config_file configs/train.json --dataset_dir +``` + +#### Execute finetune: + +``` +python -m monai.bundle run --config_file configs/train.json --finetune True +``` + +Please note that the path of model weights is "/models/model.pt", you can also override it by using `--finetune_model_path`: + +``` +python -m monai.bundle run --config_file configs/train.json --finetune True --finetune_model_path +``` + +#### Enable early stop in training: + +``` +python -m monai.bundle run --config_file configs/train.json --early_stop True +``` + +#### Override the `train` config to execute multi-GPU training: + +``` +torchrun --standalone --nnodes=1 --nproc_per_node=2 -m monai.bundle run --config_file "['configs/train.json','configs/multi_gpu_train.json']" +``` + + +#### Execute continual learning +When finetuning with new class names, please update `configs/train_continual.json`'s `label_mappings` accordingly. + +The current label mapping `[[1, 2], [2, 254]]` indicates that training labels' class indices `1` and `2`, are mapped +to the VISTA model's class `2` and `254` respectively (format `[[src_class_0, dst_class_0], [src_class_1, dst_class_1], ...]`). +Since `254` is not used by VISTA, it is therefore indicating +training with a new class (the training label's class `2` will be trained as VISTA class `254`). + +`label_set` is used to identify the VISTA model classes for providing training prompts. +`val_label_set` is used to identify the original training label classes for computing foreground/background mask during validation. + +The default configs for both variables are derived from the `label_mappings` config and include `[0]`: +``` +"label_set": "$[0] + list(x[1] for x in @label_mappings#default)" +"val_label_set": "$[0] + list(x[0] for x in @label_mappings#default)" +``` + + +Single-GPU: +``` +python -m monai.bundle run \ + --config_file="['configs/train.json','configs/train_continual.json']" --epochs=320 --learning_rate=0.005 +``` + +Multi-GPU: +``` +torchrun --nnodes=1 --nproc_per_node=8 -m monai.bundle run \ + --config_file="['configs/train.json','configs/train_continual.json','configs/multi_gpu_train.json']" --epochs=320 --learning_rate=0.005 +``` + +The patch size parameter is defined in `configs/train_continual.json`: `"patch_size": [160, 160, 160]`, and this works for the use cases +of extending the current model to segment a few novel classes. Finetuning all supported classes may require large GPU memory and carefully designed +multi-stage training processes. + +Changing `patch_size` to a smaller value such as `"patch_size": [128, 128, 128]` used in `configs/train.json` would reduce the training memory footprint. + +#### Execute evaluation +`n_train_samples` and `n_val_samples` are used to specify the number of samples to use for training and validation respectively. + +`configs/data.yaml` shows potential configurations for each specific dataset for evaluation. + +Single-GPU: +``` +python -m monai.bundle run \ + --config_file="['configs/train.json','configs/train_continual.json','configs/evaluate.json','configs/data.yaml']" +``` + +Multi-GPU: +``` +torchrun --nnodes=1 --nproc_per_node=8 -m monai.bundle run \ + --config_file="['configs/train.json','configs/train_continual.json','configs/evaluate.json','configs/mgpu_evaluate.json','configs/data.yaml']" +``` + + +#### Execute inference: +Notice the VISTA3d bundle requires at least one prompt for segmentation. It supports label prompt, which is the index of the class for automatic segmentation. +It also supports point click prompts for binary segmentation. User can provide both prompts at the same time. To segment an image, set the input_dict to +: +``` +"input_dict": "$[{'image': '/data/Task09_Spleen/imagesTs/spleen_15.nii.gz', 'label_prompt':[1]}]", +"input_dict": "$[{'image': '/data/Task09_Spleen/imagesTs/spleen_15.nii.gz', 'points':[[138,245,18], [271,343,27]], 'point_labels':[1,0]}]" +``` +- The input_dict must contain the absolute path to the nii image file, and must contain at least one prompt. The keys are "label_prompt", "points" and "point_labels". +- label_prompt is in the format of [B], points is [1, N, 3], point_labels is [1, N]. B is number of foreground object. **B must be 1 if label_prompt and points are provided together** +- N is number of click points, 3 is x,y,z coordinates **IN THE ORIGINAL IMAGE SPACE**. The inferer only supports SINGLE OBJECT point click segmentatation. +- point_labels 0 means background, 1 means foreground, -1 means ignoring this point. +- label_prompt and points key can be missing, but cannot be missing at the same time. +- points and point_labels must pe provided together. +- The label_prompt can perform multiple foreground object segmentation, e.g. [2,3,4,5] means segment those classes. Point prompts must NOT be provided. +- For segment everything, use label_prompt: list(set([i+1 for i in range(132)]) - set([22, 23, 15, 25, 19, 2, 26, 27, 28, 29, 117])) +- The point prompts for "Kidney", "Lung", "Bone" (class index [2, 20, 21]) are not allowed since those prompts will be divided into sub-categories (e.g. left kidney and right kidney). Use point prompts for the sub-categories as defined in the inference.json. +``` +python -m monai.bundle run --config_file configs/inference.json +``` + +#### Execute batch inference for segmenting everything +``` +python -m monai.bundle run --config_file="['configs/inference.json', 'configs/batch_inference.json']" --input_dir="/data/Task09_Spleen/imagesTr" --output_dir="./eval_task09" +``` + +`configs/batch_inference.json` by default runs the segment everything workflow (classes defined by `everything_labels`) on all (`*.nii.gz`) files in `input_dir`. +This default is overridable by changing the input folder `input_dir`, or the input image name suffix `input_suffix`, or directly setting the list of filenames `input_list`. + +Set `"postprocessing#transforms#0#_disabled_": false` to move the postprocessing to cpu to reduce the GPU memory footprint. + +## Automatic segmentation label prompts : +The mapping between organ name and label prompt is in the [json file](labels.json) + + +## Fast Point Window Inference: +When user click a point, there is no need to perform whole image sliding window inference. Set "use_point_window" to true in the inference.json to enable this function. +A window centered at the clicked points will be used for inference. All values outside of the window will set to be "NaN" unless "prev_mask" is passed to the inferer. +If no point click exists, this function will not be used. Notice if "use_point_window" is true and user provided point clicks, there will be obvious cut-off box artefacts. + +# References +- Roth, H., Farag, A., Turkbey, E. B., Lu, L., Liu, J., & Summers, R. M. (2016). Data From Pancreas-CT (Version 2) [Data set]. The Cancer Imaging Archive. https://doi.org/10.7937/K9/TCIA.2016.tNB1kqBU + +- J. Ma et al., "AbdomenCT-1K: Is Abdominal Organ Segmentation a Solved Problem?," in IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. 44, no. 10, pp. 6695-6714, 1 Oct. 2022, doi: 10.1109/TPAMI.2021.3100536. + +- JI YUANFENG. (2022). Amos: A large-scale abdominal multi-organ benchmark for versatile medical image segmentation [Data set]. Zenodo. https://doi.org/10.5281/zenodo.7155725 + +- Antonelli, M., Reinke, A., Bakas, S. et al. The Medical Segmentation Decathlon. Nat Commun 13, 4128 (2022). https://doi.org/10.1038/s41467-022-30695-9 + +- Rister, B., Yi, D., Shivakumar, K. et al. CT-ORG, a new dataset for multiple organ segmentation in computed tomography. Sci Data 7, 381 (2020). https://doi.org/10.1038/s41597-020-00715-8 + +- Jakob Wasserthal. (2022). Dataset with segmentations of 104 important anatomical structures in 1204 CT images (1.0) [Data set]. Zenodo. https://doi.org/10.5281/zenodo.6802614 + +- Gibson, E., Giganti, F., Hu, Y., Bonmati, E., Bandula, S., Gurusamy, K., Davidson, B., Pereira, S. P., Clarkson, M. J., & Barratt, D. C. (2018). Multi-organ Abdominal CT Reference Standard Segmentations (1.0) [Data set]. Zenodo. https://doi.org/10.5281/zenodo.1169361 + +- Multi-Atlas Labeling Beyond the Cranial Vault - Workshop and Challenge https://www.synapse.org/#!Synapse:syn3193805/wiki/217753 + + +# License +Copyright (c) MONAI Consortium + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/models/vista3d/docs/data_license.txt b/models/vista3d/docs/data_license.txt new file mode 100644 index 00000000..5cffccb1 --- /dev/null +++ b/models/vista3d/docs/data_license.txt @@ -0,0 +1,6 @@ +Third Party Licenses +----------------------------------------------------------------------- + +/*********************************************************************/ +i. Medical Segmentation Decathlon + http://medicaldecathlon.com/ diff --git a/models/vista3d/docs/labels.json b/models/vista3d/docs/labels.json new file mode 100644 index 00000000..dcdd73dd --- /dev/null +++ b/models/vista3d/docs/labels.json @@ -0,0 +1,137 @@ +{ + "liver": 1, + "kidney": 2, + "spleen": 3, + "pancreas": 4, + "right kidney": 5, + "aorta": 6, + "inferior vena cava": 7, + "right adrenal gland": 8, + "left adrenal gland": 9, + "gallbladder": 10, + "esophagus": 11, + "stomach": 12, + "duodenum": 13, + "left kidney": 14, + "bladder": 15, + "prostate or uterus": 16, + "portal vein and splenic vein": 17, + "rectum": 18, + "small bowel": 19, + "lung": 20, + "bone": 21, + "brain": 22, + "lung tumor": 23, + "pancreatic tumor": 24, + "hepatic vessel": 25, + "hepatic tumor": 26, + "colon cancer primaries": 27, + "left lung upper lobe": 28, + "left lung lower lobe": 29, + "right lung upper lobe": 30, + "right lung middle lobe": 31, + "right lung lower lobe": 32, + "vertebrae L5": 33, + "vertebrae L4": 34, + "vertebrae L3": 35, + "vertebrae L2": 36, + "vertebrae L1": 37, + "vertebrae T12": 38, + "vertebrae T11": 39, + "vertebrae T10": 40, + "vertebrae T9": 41, + "vertebrae T8": 42, + "vertebrae T7": 43, + "vertebrae T6": 44, + "vertebrae T5": 45, + "vertebrae T4": 46, + "vertebrae T3": 47, + "vertebrae T2": 48, + "vertebrae T1": 49, + "vertebrae C7": 50, + "vertebrae C6": 51, + "vertebrae C5": 52, + "vertebrae C4": 53, + "vertebrae C3": 54, + "vertebrae C2": 55, + "vertebrae C1": 56, + "trachea": 57, + "left iliac artery": 58, + "right iliac artery": 59, + "left iliac vena": 60, + "right iliac vena": 61, + "colon": 62, + "left rib 1": 63, + "left rib 2": 64, + "left rib 3": 65, + "left rib 4": 66, + "left rib 5": 67, + "left rib 6": 68, + "left rib 7": 69, + "left rib 8": 70, + "left rib 9": 71, + "left rib 10": 72, + "left rib 11": 73, + "left rib 12": 74, + "right rib 1": 75, + "right rib 2": 76, + "right rib 3": 77, + "right rib 4": 78, + "right rib 5": 79, + "right rib 6": 80, + "right rib 7": 81, + "right rib 8": 82, + "right rib 9": 83, + "right rib 10": 84, + "right rib 11": 85, + "right rib 12": 86, + "left humerus": 87, + "right humerus": 88, + "left scapula": 89, + "right scapula": 90, + "left clavicula": 91, + "right clavicula": 92, + "left femur": 93, + "right femur": 94, + "left hip": 95, + "right hip": 96, + "sacrum": 97, + "left gluteus maximus": 98, + "right gluteus maximus": 99, + "left gluteus medius": 100, + "right gluteus medius": 101, + "left gluteus minimus": 102, + "right gluteus minimus": 103, + "left autochthon": 104, + "right autochthon": 105, + "left iliopsoas": 106, + "right iliopsoas": 107, + "left atrial appendage": 108, + "brachiocephalic trunk": 109, + "left brachiocephalic vein": 110, + "right brachiocephalic vein": 111, + "left common carotid artery": 112, + "right common carotid artery": 113, + "costal cartilages": 114, + "heart": 115, + "left kidney cyst": 116, + "right kidney cyst": 117, + "prostate": 118, + "pulmonary vein": 119, + "skull": 120, + "spinal cord": 121, + "sternum": 122, + "left subclavian artery": 123, + "right subclavian artery": 124, + "superior vena cava": 125, + "thyroid gland": 126, + "vertebrae S1": 127, + "bone lesion": 128, + "kidney mass": 129, + "liver tumor": 130, + "vertebrae L6": 131, + "airway": 132, + "FDG-avid lesion": 133, + "lung nodule": 134, + "lumbar spine": 135 +} diff --git a/models/vista3d/large_files.yml b/models/vista3d/large_files.yml new file mode 100644 index 00000000..6a1727bf --- /dev/null +++ b/models/vista3d/large_files.yml @@ -0,0 +1,3 @@ +large_files: + - path: "models/model.pt" + url: "https://drive.google.com/file/d/1eLIxQwnxGsjggxiVjdcAyNvJ5DYtqmdc/view?usp=sharing" diff --git a/models/vista3d/msd_task09_spleen_folds.json b/models/vista3d/msd_task09_spleen_folds.json new file mode 100644 index 00000000..30d0f353 --- /dev/null +++ b/models/vista3d/msd_task09_spleen_folds.json @@ -0,0 +1,271 @@ +{ + "testing": [ + { + "image": "imagesTs/spleen_15.nii.gz" + }, + { + "image": "imagesTs/spleen_23.nii.gz" + }, + { + "image": "imagesTs/spleen_1.nii.gz" + }, + { + "image": "imagesTs/spleen_42.nii.gz" + }, + { + "image": "imagesTs/spleen_50.nii.gz" + }, + { + "image": "imagesTs/spleen_54.nii.gz" + }, + { + "image": "imagesTs/spleen_37.nii.gz" + }, + { + "image": "imagesTs/spleen_58.nii.gz" + }, + { + "image": "imagesTs/spleen_39.nii.gz" + }, + { + "image": "imagesTs/spleen_48.nii.gz" + }, + { + "image": "imagesTs/spleen_35.nii.gz" + }, + { + "image": "imagesTs/spleen_11.nii.gz" + }, + { + "image": "imagesTs/spleen_7.nii.gz" + }, + { + "image": "imagesTs/spleen_30.nii.gz" + }, + { + "image": "imagesTs/spleen_43.nii.gz" + }, + { + "image": "imagesTs/spleen_51.nii.gz" + }, + { + "image": "imagesTs/spleen_36.nii.gz" + }, + { + "image": "imagesTs/spleen_55.nii.gz" + }, + { + "image": "imagesTs/spleen_57.nii.gz" + }, + { + "image": "imagesTs/spleen_34.nii.gz" + } + ], + "training": [ + { + "fold": 0, + "image": "imagesTr/spleen_19.nii.gz", + "label": "labelsTr/spleen_19.nii.gz" + }, + { + "fold": 0, + "image": "imagesTr/spleen_31.nii.gz", + "label": "labelsTr/spleen_31.nii.gz" + }, + { + "fold": 0, + "image": "imagesTr/spleen_52.nii.gz", + "label": "labelsTr/spleen_52.nii.gz" + }, + { + "fold": 0, + "image": "imagesTr/spleen_40.nii.gz", + "label": "labelsTr/spleen_40.nii.gz" + }, + { + "fold": 0, + "image": "imagesTr/spleen_3.nii.gz", + "label": "labelsTr/spleen_3.nii.gz" + }, + { + "fold": 0, + "image": "imagesTr/spleen_17.nii.gz", + "label": "labelsTr/spleen_17.nii.gz" + }, + { + "fold": 0, + "image": "imagesTr/spleen_21.nii.gz", + "label": "labelsTr/spleen_21.nii.gz" + }, + { + "fold": 0, + "image": "imagesTr/spleen_33.nii.gz", + "label": "labelsTr/spleen_33.nii.gz" + }, + { + "fold": 1, + "image": "imagesTr/spleen_9.nii.gz", + "label": "labelsTr/spleen_9.nii.gz" + }, + { + "fold": 1, + "image": "imagesTr/spleen_29.nii.gz", + "label": "labelsTr/spleen_29.nii.gz" + }, + { + "fold": 1, + "image": "imagesTr/spleen_46.nii.gz", + "label": "labelsTr/spleen_46.nii.gz" + }, + { + "fold": 1, + "image": "imagesTr/spleen_25.nii.gz", + "label": "labelsTr/spleen_25.nii.gz" + }, + { + "fold": 1, + "image": "imagesTr/spleen_13.nii.gz", + "label": "labelsTr/spleen_13.nii.gz" + }, + { + "fold": 1, + "image": "imagesTr/spleen_62.nii.gz", + "label": "labelsTr/spleen_62.nii.gz" + }, + { + "fold": 1, + "image": "imagesTr/spleen_27.nii.gz", + "label": "labelsTr/spleen_27.nii.gz" + }, + { + "fold": 1, + "image": "imagesTr/spleen_44.nii.gz", + "label": "labelsTr/spleen_44.nii.gz" + }, + { + "fold": 2, + "image": "imagesTr/spleen_56.nii.gz", + "label": "labelsTr/spleen_56.nii.gz" + }, + { + "fold": 2, + "image": "imagesTr/spleen_60.nii.gz", + "label": "labelsTr/spleen_60.nii.gz" + }, + { + "fold": 2, + "image": "imagesTr/spleen_2.nii.gz", + "label": "labelsTr/spleen_2.nii.gz" + }, + { + "fold": 2, + "image": "imagesTr/spleen_53.nii.gz", + "label": "labelsTr/spleen_53.nii.gz" + }, + { + "fold": 2, + "image": "imagesTr/spleen_41.nii.gz", + "label": "labelsTr/spleen_41.nii.gz" + }, + { + "fold": 2, + "image": "imagesTr/spleen_22.nii.gz", + "label": "labelsTr/spleen_22.nii.gz" + }, + { + "fold": 2, + "image": "imagesTr/spleen_14.nii.gz", + "label": "labelsTr/spleen_14.nii.gz" + }, + { + "fold": 2, + "image": "imagesTr/spleen_18.nii.gz", + "label": "labelsTr/spleen_18.nii.gz" + }, + { + "fold": 3, + "image": "imagesTr/spleen_20.nii.gz", + "label": "labelsTr/spleen_20.nii.gz" + }, + { + "fold": 3, + "image": "imagesTr/spleen_32.nii.gz", + "label": "labelsTr/spleen_32.nii.gz" + }, + { + "fold": 3, + "image": "imagesTr/spleen_16.nii.gz", + "label": "labelsTr/spleen_16.nii.gz" + }, + { + "fold": 3, + "image": "imagesTr/spleen_12.nii.gz", + "label": "labelsTr/spleen_12.nii.gz" + }, + { + "fold": 3, + "image": "imagesTr/spleen_63.nii.gz", + "label": "labelsTr/spleen_63.nii.gz" + }, + { + "fold": 3, + "image": "imagesTr/spleen_28.nii.gz", + "label": "labelsTr/spleen_28.nii.gz" + }, + { + "fold": 3, + "image": "imagesTr/spleen_24.nii.gz", + "label": "labelsTr/spleen_24.nii.gz" + }, + { + "fold": 3, + "image": "imagesTr/spleen_59.nii.gz", + "label": "labelsTr/spleen_59.nii.gz" + }, + { + "fold": 4, + "image": "imagesTr/spleen_47.nii.gz", + "label": "labelsTr/spleen_47.nii.gz" + }, + { + "fold": 4, + "image": "imagesTr/spleen_8.nii.gz", + "label": "labelsTr/spleen_8.nii.gz" + }, + { + "fold": 4, + "image": "imagesTr/spleen_6.nii.gz", + "label": "labelsTr/spleen_6.nii.gz" + }, + { + "fold": 4, + "image": "imagesTr/spleen_61.nii.gz", + "label": "labelsTr/spleen_61.nii.gz" + }, + { + "fold": 4, + "image": "imagesTr/spleen_10.nii.gz", + "label": "labelsTr/spleen_10.nii.gz" + }, + { + "fold": 4, + "image": "imagesTr/spleen_38.nii.gz", + "label": "labelsTr/spleen_38.nii.gz" + }, + { + "fold": 4, + "image": "imagesTr/spleen_45.nii.gz", + "label": "labelsTr/spleen_45.nii.gz" + }, + { + "fold": 4, + "image": "imagesTr/spleen_26.nii.gz", + "label": "labelsTr/spleen_26.nii.gz" + }, + { + "image": "imagesTr/spleen_49.nii.gz", + "label": "labelsTr/spleen_49.nii.gz", + "fold": 0 + } + ] +} diff --git a/models/vista3d/scripts/__init__.py b/models/vista3d/scripts/__init__.py new file mode 100644 index 00000000..b1765e41 --- /dev/null +++ b/models/vista3d/scripts/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator +# from .multi_gpu_supervised_trainer import create_multigpu_supervised_evaluator, create_multigpu_supervised_trainer + +from . import vista3d +from .early_stop_score_function import score_function diff --git a/models/vista3d/scripts/early_stop_score_function.py b/models/vista3d/scripts/early_stop_score_function.py new file mode 100644 index 00000000..350f3ffe --- /dev/null +++ b/models/vista3d/scripts/early_stop_score_function.py @@ -0,0 +1,15 @@ +import os + +import torch +import torch.distributed as dist + + +def score_function(engine): + val_metric = engine.state.metrics["val_mean_dice"] + if dist.is_initialized(): + device = torch.device("cuda:" + os.environ["LOCAL_RANK"]) + val_metric = torch.tensor([val_metric]).to(device) + dist.all_reduce(val_metric, op=dist.ReduceOp.SUM) + val_metric /= dist.get_world_size() + return val_metric.item() + return val_metric diff --git a/models/vista3d/scripts/evaluator.py b/models/vista3d/scripts/evaluator.py new file mode 100644 index 00000000..669073b6 --- /dev/null +++ b/models/vista3d/scripts/evaluator.py @@ -0,0 +1,292 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence + +import numpy as np +import torch +from monai.config import IgniteInfo +from monai.engines.evaluator import SupervisedEvaluator +from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch +from monai.inferers import Inferer, SimpleInferer +from monai.transforms import Transform +from monai.utils import ForwardMode, RankFilter, min_version, optional_import +from monai.utils.enums import CommonKeys as Keys +from torch.utils.data import DataLoader + +rearrange, _ = optional_import("einops", name="rearrange") + +if TYPE_CHECKING: + from ignite.engine import Engine, EventEnum + from ignite.metrics import Metric +else: + Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") + Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric") + EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum") + +__all__ = ["Vista3dEvaluator"] + + +class Vista3dEvaluator(SupervisedEvaluator): + """ + Supervised detection evaluation method with image and label, inherits from ``SupervisedEvaluator`` and ``Workflow``. + Args: + device: an object representing the device on which to run. + val_data_loader: Ignite engine use data_loader to run, must be Iterable, typically be torch.DataLoader. + network: detector to evaluate in the evaluator, should be regular PyTorch `torch.nn.Module`. + epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`. + non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously + with respect to the host. For other cases, this argument has no effect. + prepare_batch: function to parse expected data (usually `image`, `label` and other network args) + from `engine.state.batch` for every iteration, for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. + iteration_update: the callable function for every iteration, expect to accept `engine` + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. + inferer: inference method that execute model forward on input data, like: SlidingWindow, etc. + postprocessing: execute additional transformation for the model output data. + Typically, several Tensor based transforms composed by `Compose`. + key_val_metric: compute metric when every iteration completed, and save average value to + engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the + checkpoint into files. + additional_metrics: more Ignite metrics that also attach to Ignite Engine. + metric_cmp_fn: function to compare current key metric with previous best key metric value, + it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update + `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`. + val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: + CheckpointHandler, StatsHandler, etc. + amp: whether to enable auto-mixed-precision evaluation, default is False. + mode: model forward mode during evaluation, should be 'eval' or 'train', + which maps to `model.eval()` or `model.train()`, default to 'eval'. + event_names: additional custom ignite events that will register to the engine. + new events can be a list of str or `ignite.engine.events.EventEnum`. + event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html + #ignite.engine.engine.Engine.register_events. + decollate: whether to decollate the batch-first data to a list of data after model computation, + recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. + default to `True`. + to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for + `device`, `non_blocking`. + amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. + """ + + def __init__( + self, + device: torch.device, + val_data_loader: Iterable | DataLoader, + network: torch.nn.Module, + epoch_length: int | None = None, + non_blocking: bool = False, + prepare_batch: Callable = default_prepare_batch, + iteration_update: Callable[[Engine, Any], Any] | None = None, + inferer: Inferer | None = None, + postprocessing: Transform | None = None, + key_val_metric: dict[str, Metric] | None = None, + additional_metrics: dict[str, Metric] | None = None, + metric_cmp_fn: Callable = default_metric_cmp_fn, + val_handlers: Sequence | None = None, + amp: bool = False, + mode: ForwardMode | str = ForwardMode.EVAL, + event_names: list[str | EventEnum | type[EventEnum]] | None = None, + event_to_attr: dict | None = None, + decollate: bool = True, + to_kwargs: dict | None = None, + amp_kwargs: dict | None = None, + hyper_kwargs: dict | None = None, + ) -> None: + super().__init__( + device=device, + val_data_loader=val_data_loader, + network=network, + epoch_length=epoch_length, + non_blocking=non_blocking, + prepare_batch=prepare_batch, + iteration_update=iteration_update, + postprocessing=postprocessing, + key_val_metric=key_val_metric, + additional_metrics=additional_metrics, + metric_cmp_fn=metric_cmp_fn, + val_handlers=val_handlers, + amp=amp, + mode=mode, + event_names=event_names, + event_to_attr=event_to_attr, + decollate=decollate, + to_kwargs=to_kwargs, + amp_kwargs=amp_kwargs, + ) + + self.network = network + self.device = device + self.inferer = SimpleInferer() if inferer is None else inferer + self.hyper_kwargs = hyper_kwargs + self.logger.addFilter(RankFilter()) + + def transform_points(self, point, affine): + """transform point to the coordinates of the transformed image + point: numpy array [bs, N, 3] + """ + bs, n = point.shape[:2] + point = np.concatenate((point, np.ones((bs, n, 1))), axis=-1) + point = rearrange(point, "b n d -> d (b n)") + point = affine @ point + point = rearrange(point, "d (b n)-> b n d", b=bs)[:, :, :3] + return point + + def check_prompts_format(self, label_prompt, points, point_labels): + """check the format of user prompts + label_prompt: [1,2,3,4,...,B] List of tensors + points: [[[x,y,z], [x,y,z], ...]] List of coordinates of a single object + point_labels: [[1,1,0,...]] List of scalar that matches number of points + """ + # check prompt is given + if label_prompt is None and points is None: + everything_labels = self.hyper_kwargs.get("everything_labels", None) + if everything_labels is not None: + label_prompt = [torch.tensor(_) for _ in everything_labels] + return label_prompt, points, point_labels + else: + raise ValueError("Prompt must be given for inference.") + # check label_prompt + if label_prompt is not None: + if isinstance(label_prompt, list): + if not np.all([len(_) == 1 for _ in label_prompt]): + raise ValueError("Label prompt must be a list of single scalar, [1,2,3,4,...,].") + if not np.all([(x < 255).item() for x in label_prompt]): + raise ValueError("Current bundle only supports label prompt smaller than 255.") + else: + raise ValueError("Label prompt must be a list, [1,2,3,4,...,].") + # check points + if points is not None: + if point_labels is None: + raise ValueError("Point labels must be given if points are given.") + if not np.all([len(_) == 3 for _ in points]): + raise ValueError("Points must be three dimensional (x,y,z) in the shape of [[x,y,z],...,[x,y,z]].") + if len(points) != len(point_labels): + raise ValueError("Points must match point labels.") + if not np.all([_ in [-1, 0, 1, 2, 3] for _ in point_labels]): + raise ValueError("Point labels can only be -1,0,1 and 2,3 for special flags.") + if label_prompt is not None and points is not None: + if len(label_prompt) != 1: + raise ValueError("Label prompt can only be a single object if provided with point prompts.") + # check point_labels + if point_labels is not None: + if points is None: + raise ValueError("Points must be given if point labels are given.") + return label_prompt, points, point_labels + + def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Tensor]) -> dict: + """ + callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine. + Return below items in a dictionary: + - IMAGE: image Tensor data for model input, already moved to device. + - LABEL: label Tensor data corresponding to the image, already moved to device. + - PRED: prediction result of model. + + Args: + engine: `SupervisedEvaluator` to execute operation for an iteration. + batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data. + + Raises: + ValueError: When ``batchdata`` is None. + + """ + if batchdata is None: + raise ValueError("Must provide batch data for current iteration.") + label_set = engine.hyper_kwargs.get("label_set", None) + # If user provide prompts in the inference, input image must contain original affine. + # the point coordinates are from the original_affine space, while image here is after preprocess transforms. + if engine.hyper_kwargs["user_prompt"]: + inputs, label_prompt, points, point_labels = ( + batchdata["image"], + batchdata.get("label_prompt", None), + batchdata.get("points", None), + batchdata.get("point_labels", None), + ) + labels = None + label_prompt, points, point_labels = self.check_prompts_format(label_prompt, points, point_labels) + inputs = inputs.to(engine.device) + # For N foreground object, label_prompt is [1, N], but the batch number 1 needs to be removed. Convert to [N, 1] + label_prompt = ( + torch.as_tensor([label_prompt]).to(inputs.device)[0].unsqueeze(-1) if label_prompt is not None else None + ) + # For points, the size can only be [1, K, 3], where K is the number of points for this single foreground object. + if points is not None: + points = torch.as_tensor([points]) + points = self.transform_points( + points, np.linalg.inv(inputs.affine[0]) @ inputs.meta["original_affine"][0].numpy() + ) + points = torch.from_numpy(points).to(inputs.device) + point_labels = torch.as_tensor([point_labels]).to(inputs.device) if point_labels is not None else None + + # If validation with ground truth label available. + else: + inputs, labels = engine.prepare_batch( + batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs + ) + # create label prompt, this should be consistent with the label prompt used for training. + if label_set is None: + output_classes = engine.hyper_kwargs["output_classes"] + label_set = np.arange(output_classes).tolist() + label_prompt = torch.tensor(label_set).to(engine.state.device).unsqueeze(-1) + # point prompt is generated withing vista3d,provide empty points + points = torch.zeros(label_prompt.shape[0], 1, 3) + point_labels = -1 + torch.zeros(label_prompt.shape[0], 1) + if engine.hyper_kwargs["drop_point_prob"] > 0.99: + # automatic only validation + points = None + point_labels = None + if engine.hyper_kwargs["drop_label_prob"] > 0.99: + # point only validation + label_prompt = None + # this validation label set should be consistent with 'labels.unique()', used to generate fg/bg points + val_label_set = engine.hyper_kwargs.get("val_label_set", label_set) + + # put iteration outputs into engine.state + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: labels} + # execute forward computation + with engine.mode(engine.network): + if engine.amp: + with torch.cuda.amp.autocast(**engine.amp_kwargs): + engine.state.output[Keys.PRED] = engine.inferer( + inputs=inputs, + network=engine.network, + point_coords=points, + point_labels=point_labels, + class_vector=label_prompt, + labels=labels, + label_set=val_label_set, + ) + else: + engine.state.output[Keys.PRED] = engine.inferer( + inputs=inputs, + network=engine.network, + point_coords=points, + point_labels=point_labels, + class_vector=label_prompt, + labels=labels, + label_set=val_label_set, + ) + # Add dim 0 for decollate batch + engine.state.output["label_prompt"] = label_prompt.unsqueeze(0) if label_prompt is not None else None + engine.state.output["points"] = points.unsqueeze(0) if points is not None else None + engine.state.output["point_labels"] = point_labels.unsqueeze(0) if point_labels is not None else None + engine.fire_event(IterationEvents.FORWARD_COMPLETED) + engine.fire_event(IterationEvents.MODEL_COMPLETED) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return engine.state.output diff --git a/models/vista3d/scripts/inferer.py b/models/vista3d/scripts/inferer.py new file mode 100644 index 00000000..98fc95a1 --- /dev/null +++ b/models/vista3d/scripts/inferer.py @@ -0,0 +1,132 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import List, Union + +import torch +from monai.inferers.inferer import Inferer +from torch import Tensor + +from .monai_utils import sliding_window_inference +from .utils import point_based_window_inferer + + +class Vista3dInferer(Inferer): + """ + Vista3D Inferer + + Args: + roi_size: the sliding window patch size. + overlap: sliding window overlap ratio. + use_cfp: use class prompt for point head. + """ + + def __init__(self, roi_size, overlap, use_cfp, use_point_window=False, sw_batch_size=1) -> None: + Inferer.__init__(self) + self.roi_size = roi_size + self.overlap = overlap + self.use_cfp = use_cfp + self.sw_batch_size = sw_batch_size + self.use_point_window = use_point_window + self.sliding_window_inferer = point_based_window_inferer if use_point_window else sliding_window_inference + + def __call__( + self, + inputs: Union[List[Tensor], Tensor], + network, + point_coords, + point_labels, + class_vector, + labels=None, + label_set=None, + prev_mask=None, + ): + """ + Unified callable function API of Inferers. + Notice: The point_based_window_inferer currently only supports SINGLE OBJECT INFERENCE with B=1. + It only used in interactive segmentation. + + Args: + inputs: input tensor images. + network: vista3d model. + point_coords: point click coordinates. [B, N, 3]. + point_labels: point click labels (0 for negative, 1 for positive) [B, N]. + class_vector: class vector of length B. + labels: groundtruth labels. Used for sampling validation points. + label_set: [0,1,2,3,...,output_classes]. + prev_mask: [1, B, H, W, D], THE VALUE IS BEFORE SIGMOID! + + """ + sliding_window_inferer = ( + point_based_window_inferer + if (self.use_point_window and point_coords is not None) + else sliding_window_inference + ) + prompt_class = copy.deepcopy(class_vector) + if class_vector is not None: + # Check if network has attribute 'point_head' directly or within its 'module' + if hasattr(network, "point_head"): + point_head = network.point_head + elif hasattr(network, "module") and hasattr(network.module, "point_head"): + point_head = network.module.point_head + else: + raise AttributeError("Network does not have attribute 'point_head'.") + + if torch.any(class_vector > point_head.last_supported): + class_vector = None + if isinstance(inputs, list): + device = inputs[0].device + else: + device = inputs.device + try: + val_outputs = sliding_window_inferer( + inputs=inputs, + roi_size=self.roi_size, + sw_batch_size=self.sw_batch_size, + transpose=True, + predictor=network, + mode="gaussian", + sw_device=device, + device=device, + overlap=self.overlap, + point_coords=point_coords, + point_labels=point_labels, + class_vector=class_vector, + prompt_class=prompt_class, + prev_mask=prev_mask, + labels=labels, + label_set=label_set, + use_cfp=self.use_cfp, + ) + except Exception: + val_outputs = None + torch.cuda.empty_cache() + val_outputs = sliding_window_inferer( + inputs=inputs, + roi_size=self.roi_size, + sw_batch_size=self.sw_batch_size, + transpose=True, + predictor=network, + mode="gaussian", + sw_device=device, + device="cpu", + overlap=self.overlap, + point_coords=point_coords, + point_labels=point_labels, + class_vector=class_vector, + prompt_class=prompt_class, + prev_mask=prev_mask, + labels=labels, + label_set=label_set, + use_cfp=self.use_cfp, + ) + return val_outputs diff --git a/models/vista3d/scripts/monai_trans_utils.py b/models/vista3d/scripts/monai_trans_utils.py new file mode 100644 index 00000000..a66816cf --- /dev/null +++ b/models/vista3d/scripts/monai_trans_utils.py @@ -0,0 +1,317 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +import os +from collections.abc import Hashable, Mapping + +import numpy as np +import torch +from monai.config import DtypeLike, KeysCollection +from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor +from monai.transforms import MapLabelValue +from monai.transforms.transform import MapTransform +from monai.utils import look_up_option, min_version, optional_import +from monai.utils.type_conversion import convert_data_type, convert_to_cupy, convert_to_dst_type + +measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version) +cp, has_cp = optional_import("cupy") + + +def get_largest_connected_component_point(img: NdarrayTensor, point_coords=None, point_labels=None) -> NdarrayTensor: + """ + Gets the largest connected component mask of an image. img is before post process! And will include NaN values. + Args: + img: [1, B, H, W, D] + point_coords [B, N, 3] + point_labels [B, N] + """ + outs = torch.zeros_like(img) + for c in range(len(point_coords)): + if not ((point_labels[c] == 3).any() or (point_labels[c] == 1).any()): + continue + coords = point_coords[c, point_labels[c] == 3].tolist() + point_coords[c, point_labels[c] == 1].tolist() + not_nan_mask = ~torch.isnan(img[0, c]) + img_ = torch.nan_to_num(img[0, c] > 0, 0) + img_, *_ = convert_data_type(img_, np.ndarray) + label = measure.label + features = label(img_, connectivity=3) + pos_mask = torch.from_numpy(img_).to(img.device) > 0 + # if num features less than max desired, nothing to do. + features = torch.from_numpy(features).to(img.device) + # generate a map with all pos points + idx = [] + for p in coords: + idx.append(features[round(p[0]), round(p[1]), round(p[2])].item()) + idx = list(set(idx)) + for i in idx: + if i == 0: + continue + outs[0, c] += features == i + outs = outs > 0 + # find negative mean value + fill_in = img[0, c][torch.logical_and(~outs[0, c], not_nan_mask)].mean() + img[0, c][torch.logical_and(pos_mask, ~outs[0, c])] = fill_in + return img + + +def get_largest_connected_component_mask( + img_pos: NdarrayTensor, + img_neg: NdarrayTensor, + connectivity: int | None = None, + num_components: int = 1, + point_coords=None, + point_labels=None, + margins=3, +) -> NdarrayTensor: + """ + Gets the largest connected component mask of an image. + + Args: + img: Image to get largest connected component from. Shape is (spatial_dim1 [, spatial_dim2, ...]) + connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor. + Accepted values are ranging from 1 to input.ndim. If ``None``, a full + connectivity of ``input.ndim`` is used. for more details: + https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.label. + num_components: The number of largest components to preserve. + """ + # use skimage/cucim.skimage and np/cp depending on whether packages are + # available and input is non-cpu torch.tensor + cucim_skimage, has_cucim = optional_import("cucim.skimage") + + use_cp = has_cp and has_cucim and isinstance(img_pos, torch.Tensor) and img_pos.device != torch.device("cpu") + if use_cp: + img_pos_ = convert_to_cupy(img_pos.short()) # type: ignore + img_neg_ = convert_to_cupy(img_neg.short()) # type: ignore + label = cucim_skimage.measure.label + lib = cp + else: + if not has_measure: + raise RuntimeError("Skimage.measure required.") + img_pos_, *_ = convert_data_type(img_pos, np.ndarray) + img_neg_, *_ = convert_data_type(img_neg, np.ndarray) + label = measure.label + lib = np + + # features will be an image -- 0 for background and then each different + # feature will have its own index. + # features, num_features = label(img_, connectivity=connectivity, return_num=True) + + features_pos, num_features = label(img_pos_, connectivity=3, return_num=True) + features_neg, num_features = label(img_neg_, connectivity=3, return_num=True) + + # if num features less than max desired, nothing to do. + outs = np.zeros_like(img_pos_) + for bs in range(point_coords.shape[0]): + for i, p in enumerate(point_coords[bs]): + if point_labels[bs, i] == 1 or point_labels[bs, i] == 3: + features = features_pos + elif point_labels[bs, i] == 0 or point_labels[bs, i] == 2: + features = features_neg + else: + # if -1 padding point, skip + continue + p = p.round().int() + for margin in range(margins): + l, r = max(p[0].item() - margin, 0), min(p[0].item() + margin + 1, features.shape[-3]) + t, d = max(p[1].item() - margin, 0), min(p[1].item() + margin + 1, features.shape[-2]) + f, b = max(p[2].item() - margin, 0), min(p[2].item() + margin + 1, features.shape[-1]) + index = features_pos[bs, 0, l:r, t:d, f:b].max() + if index > 0: + outs[[bs]] += lib.isin(features[[bs]], index) + break + outs[outs > 1] = 1 + outs = convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0] + return outs + + +class VistaPostTransform(MapTransform): + def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + dataset_transforms: a dictionary specifies the transform for corresponding dataset: + key: dataset name, value: list of data transforms. + dataset_key: key to get the dataset name from the data dictionary, default to "dataset_name". + allow_missing_keys: don't raise exception if key is missing. + + """ + super().__init__(keys, allow_missing_keys) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + for keys in self.keys: + if keys in data: + pred = data[keys] + object_num = pred.shape[0] + device = pred.device + if data.get("label_prompt", None) is None and data.get("points", None) is not None: + pred = get_largest_connected_component_point( + pred.unsqueeze(0), + point_coords=data.get("points").to(device), + point_labels=data.get("point_labels").to(device), + )[0] + pred[pred < 0] = 0.0 + # if it's multichannel, perform argmax + if object_num > 1: + # concate background channel. Make sure user did not provide 0 as prompt. + is_bk = torch.all(pred <= 0, dim=0, keepdim=True) + pred = pred.argmax(0).unsqueeze(0).float() + 1.0 + pred[is_bk] = 0.0 + else: + # AsDiscrete will remove NaN + # pred = monai.transforms.AsDiscrete(threshold=0.5)(pred) + pred[pred > 0] = 1.0 + if "label_prompt" in data and data["label_prompt"] is not None: + pred += 0.5 # inplace mapping to avoid cloning pred + for i in range(1, object_num + 1): + frac = i + 0.5 + pred[pred == frac] = data["label_prompt"][i - 1].to(pred.dtype) + pred[pred == 0.5] = 0.0 + data[keys] = pred + return data + + +def get_name_to_index_mapping(bundle_root): + """get the label name to index mapping""" + name_to_index_mapping = {} + metadata_path = os.path.join(bundle_root, "configs/metadata.json") + if not os.path.isfile(metadata_path): + return name_to_index_mapping + with open(metadata_path, "r") as f: + metadata = json.load(f) + labels = metadata.get("network_data_format", {}).get("outputs", {}).get("pred", {}).get("channel_def") + if labels is None: + return name_to_index_mapping + name_to_index_mapping = {v.lower(): int(k) for k, v in labels.items()} + return name_to_index_mapping + + +def convert_name_to_index(name_to_index_mapping, label_prompt): + """convert the label name to index""" + if label_prompt is not None and isinstance(label_prompt, list): + converted_label_prompt = [] + # for new class, add to the mapping + for l in label_prompt: + if isinstance(l, str) and not l.isdigit(): + if l.lower() not in name_to_index_mapping: + name_to_index_mapping[l.lower()] = len(name_to_index_mapping) + for l in label_prompt: + if isinstance(l, (int, str)): + converted_label_prompt.append( + name_to_index_mapping.get(l.lower(), int(l) if l.isdigit() else 0) if isinstance(l, str) else int(l) + ) + else: + converted_label_prompt.append(l) + return converted_label_prompt + return label_prompt + + +class VistaPreTransform(MapTransform): + def __init__( + self, + keys: KeysCollection, + allow_missing_keys: bool = False, + special_index=(25, 26, 27, 28, 29, 117), + subclass=None, + bundle_root=None, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + dataset_transforms: a dictionary specifies the transform for corresponding dataset: + key: dataset name, value: list of data transforms. + dataset_key: key to get the dataset name from the data dictionary, default to "dataset_name". + allow_missing_keys: don't raise exception if key is missing. + special_index: the class index that need to be handled differently. + """ + super().__init__(keys, allow_missing_keys) + self.special_index = special_index + self.subclass = subclass + self.name_to_index_mapping = get_name_to_index_mapping(bundle_root) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + label_prompt = data.get("label_prompt", None) + point_labels = data.get("point_labels", None) + # convert the label name to index if needed + label_prompt = convert_name_to_index(self.name_to_index_mapping, label_prompt) + try: + # The evaluator will check prompt. The invalid prompt will be skipped here and captured by evaluator. + if self.subclass is not None and label_prompt is not None: + _label_prompt = [] + subclass_keys = list(map(int, self.subclass.keys())) + for i in range(len(label_prompt)): + if label_prompt[i] in subclass_keys: + _label_prompt.extend(self.subclass[str(label_prompt[i])]) + else: + _label_prompt.append(label_prompt[i]) + data["label_prompt"] = _label_prompt + + if label_prompt is not None and point_labels is not None: + if label_prompt[0] in self.special_index: + point_labels = np.array(point_labels) + point_labels[point_labels == 0] = 2 + point_labels[point_labels == 1] = 3 + point_labels = point_labels.tolist() + data["point_labels"] = point_labels + except Exception: + pass + + return data + + +class RelabelD(MapTransform): + def __init__( + self, + keys: KeysCollection, + label_mappings: dict[str, list[tuple[int, int]]], + dtype: DtypeLike = np.int16, + dataset_key: str = "dataset_name", + allow_missing_keys: bool = False, + ) -> None: + """ + label_mappings[data[dataset_key]] should has the format: [(local label, global label), ...] + + This list of local -> global label mappings will be applied to each input `data[keys]`. + if `data[dataset_key]` is not in `label_mappings`, label_mappings['default']` will be used. + if `label_mappings[data[dataset_key]]` is None, no relabeling will be performed. + + Args: + keys: keys of the corresponding items to be transformed. + label_mappings: a dictionary specifies how local dataset class indices are mapped to the + global class indices, format: + key: dataset name, value: list of (local label, global label) pairs + set `label_mappings={}` to completely skip this transform. + dtype: convert the output data to dtype, default to float32. + dataset_key: key to get the dataset name from the data dictionary, default to "dataset_name". + allow_missing_keys: don't raise exception if key is missing. + + """ + super().__init__(keys, allow_missing_keys) + self.mappers = {} + self.dataset_key = dataset_key + for name, mapping in label_mappings.items(): + self.mappers[name] = MapLabelValue( + orig_labels=[int(pair[0]) for pair in mapping], + target_labels=[int(pair[1]) for pair in mapping], + dtype=dtype, + ) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + d = dict(data) + dataset_name = d.get(self.dataset_key, "default") + _m = look_up_option(dataset_name, self.mappers, default=None) + if _m is None: + return d + for key in self.key_iterator(d): + d[key] = _m(d[key]) + return d diff --git a/models/vista3d/scripts/monai_utils.py b/models/vista3d/scripts/monai_utils.py new file mode 100644 index 00000000..bfe4fd2c --- /dev/null +++ b/models/vista3d/scripts/monai_utils.py @@ -0,0 +1,412 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import itertools +from collections.abc import Callable, Mapping, Sequence +from typing import Any, Iterable + +import numpy as np +import torch +import torch.nn.functional as F +from monai.data.meta_tensor import MetaTensor +from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size +from monai.utils import ( + BlendMode, + PytorchPadMode, + convert_data_type, + convert_to_dst_type, + ensure_tuple, + ensure_tuple_rep, + fall_back_tuple, + look_up_option, + optional_import, + pytorch_after, +) + +tqdm, _ = optional_import("tqdm", name="tqdm") +_nearest_mode = "nearest-exact" if pytorch_after(1, 11) else "nearest" + +__all__ = ["sliding_window_inference"] + + +def sliding_window_inference( + inputs: torch.Tensor | MetaTensor, + roi_size: Sequence[int] | int, + sw_batch_size: int, + predictor: Callable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]], + overlap: Sequence[float] | float = 0.25, + mode: BlendMode | str = BlendMode.CONSTANT, + sigma_scale: Sequence[float] | float = 0.125, + padding_mode: PytorchPadMode | str = PytorchPadMode.CONSTANT, + cval: float = 0.0, + sw_device: torch.device | str | None = None, + device: torch.device | str | None = None, + progress: bool = False, + roi_weight_map: torch.Tensor | None = None, + process_fn: Callable | None = None, + buffer_steps: int | None = None, + buffer_dim: int = -1, + *args: Any, + **kwargs: Any, +) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]: + """ + Sliding window inference on `inputs` with `predictor`. + + The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors. + Each output in the tuple or dict value is allowed to have different resolutions with respect to the input. + e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes + could be ([128,64,256], [64,32,128]). + In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still + an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters + so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension). + + When roi_size is larger than the inputs' spatial size, the input image are padded during inference. + To maintain the same spatial sizes, the output image will be cropped to the original input size. + + Args: + inputs: input image to be processed (assuming NCHW[D]) + roi_size: the spatial window size for inferences. + When its components have None or non-positives, the corresponding inputs dimension will be used. + if the components of the `roi_size` are non-positive values, the transform will use the + corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + sw_batch_size: the batch size to run window slices. + predictor: given input tensor ``patch_data`` in shape NCHW[D], + The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary + with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D']; + where H'W'[D'] represents the output patch's spatial size, M is the number of output channels, + N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128), + the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)). + In this case, the parameter `overlap` and `roi_size` need to be carefully chosen + to ensure the scaled output ROI sizes are still integers. + If the `predictor`'s input and output spatial sizes are different, + we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension. + overlap: Amount of overlap between scans along each spatial dimension, defaults to ``0.25``. + mode: {``"constant"``, ``"gaussian"``} + How to blend output of overlapping windows. Defaults to ``"constant"``. + + - ``"constant``": gives equal weight to all predictions. + - ``"gaussian``": gives less weight to predictions on edges of windows. + + sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``. + Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``. + When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding + spatial dimensions. + padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``} + Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"`` + See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html + cval: fill value for 'constant' padding mode. Default: 0 + sw_device: device for the window data. + By default the device (and accordingly the memory) of the `inputs` is used. + Normally `sw_device` should be consistent with the device where `predictor` is defined. + device: device for the stitched output prediction. + By default the device (and accordingly the memory) of the `inputs` is used. If for example + set to device=torch.device('cpu') the gpu memory consumption is less and independent of the + `inputs` and `roi_size`. Output is on the `device`. + progress: whether to print a `tqdm` progress bar. + roi_weight_map: pre-computed (non-negative) weight map for each ROI. + If not given, and ``mode`` is not `constant`, this map will be computed on the fly. + process_fn: process inference output and adjust the importance map per window + buffer_steps: the number of sliding window iterations along the ``buffer_dim`` + to be buffered on ``sw_device`` before writing to ``device``. + (Typically, ``sw_device`` is ``cuda`` and ``device`` is ``cpu``.) + default is None, no buffering. For the buffer dim, when spatial size is divisible by buffer_steps*roi_size, + (i.e. no overlapping among the buffers) non_blocking copy may be automatically enabled for efficiency. + buffer_dim: the spatial dimension along which the buffers are created. + 0 indicates the first spatial dimension. Default is -1, the last spatial dimension. + args: optional args to be passed to ``predictor``. + kwargs: optional keyword args to be passed to ``predictor``. + + Note: + - input must be channel-first and have a batch dim, supports N-D sliding window. + + """ + buffered = buffer_steps is not None and buffer_steps > 0 + num_spatial_dims = len(inputs.shape) - 2 + if buffered: + if buffer_dim < -num_spatial_dims or buffer_dim > num_spatial_dims: + raise ValueError(f"buffer_dim must be in [{-num_spatial_dims}, {num_spatial_dims}], got {buffer_dim}.") + if buffer_dim < 0: + buffer_dim += num_spatial_dims + overlap = ensure_tuple_rep(overlap, num_spatial_dims) + for o in overlap: + if o < 0 or o >= 1: + raise ValueError(f"overlap must be >= 0 and < 1, got {overlap}.") + compute_dtype = inputs.dtype + + # determine image spatial size and batch size + # Note: all input images must have the same image size and batch size + batch_size, _, *image_size_ = inputs.shape + device = device or inputs.device + sw_device = sw_device or inputs.device + + temp_meta = None + if isinstance(inputs, MetaTensor): + temp_meta = MetaTensor([]).copy_meta_from(inputs, copy_attr=False) + inputs = convert_data_type(inputs, torch.Tensor, wrap_sequence=True)[0] + roi_size = fall_back_tuple(roi_size, image_size_) + + # in case that image size is smaller than roi size + image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims)) + pad_size = [] + for k in range(len(inputs.shape) - 1, 1, -1): + diff = max(roi_size[k - 2] - inputs.shape[k], 0) + half = diff // 2 + pad_size.extend([half, diff - half]) + if any(pad_size): + inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval) + if "labels" in kwargs.keys() and kwargs["labels"] is not None: + kwargs["labels"] = F.pad( + kwargs["labels"], pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval + ) + if "prev_mask" in kwargs.keys() and kwargs["prev_mask"] is not None: + kwargs["prev_mask"] = F.pad( + kwargs["prev_mask"], pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval + ) + if "point_coords" in kwargs.keys() and kwargs["point_coords"] is not None: + kwargs["point_coords"] = kwargs["point_coords"] + torch.tensor( + [pad_size[-2], pad_size[-4], pad_size[-6]] + ).to(kwargs["point_coords"].device) + + # Store all slices + scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap) + slices = dense_patch_slices(image_size, roi_size, scan_interval, return_slice=not buffered) + + num_win = len(slices) # number of windows per image + total_slices = num_win * batch_size # total number of windows + windows_range: Iterable + if not buffered: + non_blocking = False + windows_range = range(0, total_slices, sw_batch_size) + else: + slices, n_per_batch, b_slices, windows_range = _create_buffered_slices( + slices, batch_size, sw_batch_size, buffer_dim, buffer_steps + ) + non_blocking, _ss = torch.cuda.is_available(), -1 + for x in b_slices[:n_per_batch]: + if x[1] < _ss: # detect overlapping slices + non_blocking = False + break + _ss = x[2] + + # Create window-level importance map + valid_patch_size = get_valid_patch_size(image_size, roi_size) + if valid_patch_size == roi_size and (roi_weight_map is not None): + importance_map_ = roi_weight_map + else: + try: + valid_p_size = ensure_tuple(valid_patch_size) + importance_map_ = compute_importance_map( + valid_p_size, mode=mode, sigma_scale=sigma_scale, device=sw_device, dtype=compute_dtype + ) + if len(importance_map_.shape) == num_spatial_dims and not process_fn: + importance_map_ = importance_map_[None, None] # adds batch, channel dimensions + except Exception as e: + raise RuntimeError( + f"patch size {valid_p_size}, mode={mode}, sigma_scale={sigma_scale}, device={device}\n" + "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'." + ) from e + importance_map_ = convert_data_type(importance_map_, torch.Tensor, device=sw_device, dtype=compute_dtype)[0] + + # stores output and count map + output_image_list, count_map_list, sw_device_buffer, b_s, b_i = [], [], [], 0, 0 # type: ignore + # for each patch + for slice_g in tqdm(windows_range) if progress else windows_range: + slice_range = range(slice_g, min(slice_g + sw_batch_size, b_slices[b_s][0] if buffered else total_slices)) + unravel_slice = [ + [slice(idx // num_win, idx // num_win + 1), slice(None)] + list(slices[idx % num_win]) + for idx in slice_range + ] + if sw_batch_size > 1: + win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device) + else: + win_data = inputs[unravel_slice[0]].to(sw_device) + + kwargs["patch_coords"] = unravel_slice[0] + seg_prob_out = predictor(win_data, *args, **kwargs) # batched patch + + # convert seg_prob_out to tuple seg_tuple, this does not allocate new memory. + dict_keys, seg_tuple = _flatten_struct(seg_prob_out) + if process_fn: + seg_tuple, w_t = process_fn(seg_tuple, win_data, importance_map_) + else: + w_t = importance_map_ + if len(w_t.shape) == num_spatial_dims: + w_t = w_t[None, None] + w_t = w_t.to(dtype=compute_dtype, device=sw_device) + if buffered: + c_start, c_end = b_slices[b_s][1:] + if not sw_device_buffer: + k = seg_tuple[0].shape[1] # len(seg_tuple) > 1 is currently ignored + sp_size = list(image_size) + sp_size[buffer_dim] = c_end - c_start + sw_device_buffer = [torch.zeros(size=[1, k, *sp_size], dtype=compute_dtype, device=sw_device)] + for p, s in zip(seg_tuple[0], unravel_slice): + offset = s[buffer_dim + 2].start - c_start + s[buffer_dim + 2] = slice(offset, offset + roi_size[buffer_dim]) + s[0] = slice(0, 1) + sw_device_buffer[0][s] += p * w_t + b_i += len(unravel_slice) + if b_i < b_slices[b_s][0]: + continue + else: + sw_device_buffer = list(seg_tuple) + + for ss in range(len(sw_device_buffer)): + b_shape = sw_device_buffer[ss].shape + seg_chns, seg_shape = b_shape[1], b_shape[2:] + z_scale = None + if not buffered and seg_shape != roi_size: + z_scale = [out_w_i / float(in_w_i) for out_w_i, in_w_i in zip(seg_shape, roi_size)] + w_t = F.interpolate(w_t, seg_shape, mode=_nearest_mode) + if len(output_image_list) <= ss: + output_shape = [batch_size, seg_chns] + output_shape += [int(_i * _z) for _i, _z in zip(image_size, z_scale)] if z_scale else list(image_size) + # allocate memory to store the full output and the count for overlapping parts + new_tensor: Callable = torch.empty if non_blocking else torch.zeros # type: ignore + output_image_list.append(new_tensor(output_shape, dtype=compute_dtype, device=device)) + count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device)) + w_t_ = w_t.to(device) + for __s in slices: + if z_scale is not None: + __s = tuple(slice(int(_si.start * z_s), int(_si.stop * z_s)) for _si, z_s in zip(__s, z_scale)) + count_map_list[-1][(slice(None), slice(None), *__s)] += w_t_ + if buffered: + o_slice = [slice(None)] * len(inputs.shape) + o_slice[buffer_dim + 2] = slice(c_start, c_end) + img_b = b_s // n_per_batch # image batch index + o_slice[0] = slice(img_b, img_b + 1) + if non_blocking: + output_image_list[0][o_slice].copy_(sw_device_buffer[0], non_blocking=non_blocking) + else: + output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device) + else: + sw_device_buffer[ss] *= w_t + sw_device_buffer[ss] = sw_device_buffer[ss].to(device) + _compute_coords(unravel_slice, z_scale, output_image_list[ss], sw_device_buffer[ss]) + sw_device_buffer = [] + if buffered: + b_s += 1 + + if non_blocking: + torch.cuda.current_stream().synchronize() + + # account for any overlapping sections + for ss in range(len(output_image_list)): + output_image_list[ss] /= count_map_list.pop(0) + + # remove padding if image_size smaller than roi_size + if any(pad_size): + for ss, output_i in enumerate(output_image_list): + zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)] + final_slicing: list[slice] = [] + for sp in range(num_spatial_dims): + si = num_spatial_dims - sp - 1 + slice_dim = slice( + int(round(pad_size[sp * 2] * zoom_scale[si])), + int(round((pad_size[sp * 2] + image_size_[si]) * zoom_scale[si])), + ) + final_slicing.insert(0, slice_dim) + output_image_list[ss] = output_i[(slice(None), slice(None), *final_slicing)] + + final_output = _pack_struct(output_image_list, dict_keys) + if temp_meta is not None: + final_output = convert_to_dst_type(final_output, temp_meta, device=device)[0] + else: + final_output = convert_to_dst_type(final_output, inputs, device=device)[0] + + return final_output # type: ignore + + +def _create_buffered_slices(slices, batch_size, sw_batch_size, buffer_dim, buffer_steps): + """rearrange slices for buffering""" + slices_np = np.asarray(slices) + slices_np = slices_np[np.argsort(slices_np[:, buffer_dim, 0], kind="mergesort")] + slices = [tuple(slice(c[0], c[1]) for c in i) for i in slices_np] + slices_np = slices_np[:, buffer_dim] + + _, _, _b_lens = np.unique(slices_np[:, 0], return_counts=True, return_index=True) + b_ends = np.cumsum(_b_lens).tolist() # possible buffer flush boundaries + x = [0, *b_ends][:: min(len(b_ends), int(buffer_steps))] + if x[-1] < b_ends[-1]: + x.append(b_ends[-1]) + n_per_batch = len(x) - 1 + windows_range = [ + range(b * x[-1] + x[i], b * x[-1] + x[i + 1], sw_batch_size) + for b in range(batch_size) + for i in range(n_per_batch) + ] + b_slices = [] + for _s, _r in enumerate(windows_range): + s_s = slices_np[windows_range[_s - 1].stop % len(slices) if _s > 0 else 0, 0] + s_e = slices_np[(_r.stop - 1) % len(slices), 1] + b_slices.append((_r.stop, s_s, s_e)) # buffer index, slice start, slice end + windows_range = itertools.chain(*windows_range) # type: ignore + return slices, n_per_batch, b_slices, windows_range + + +def _compute_coords(coords, z_scale, out, patch): + """sliding window batch spatial scaling indexing for multi-resolution outputs.""" + for original_idx, p in zip(coords, patch): + idx_zm = list(original_idx) # 4D for 2D image, 5D for 3D image + if z_scale: + for axis in range(2, len(idx_zm)): + idx_zm[axis] = slice( + int(original_idx[axis].start * z_scale[axis - 2]), int(original_idx[axis].stop * z_scale[axis - 2]) + ) + out[idx_zm] += p + + +def _get_scan_interval( + image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: Sequence[float] +) -> tuple[int, ...]: + """ + Compute scan interval according to the image size, roi size and overlap. + Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0, + use 1 instead to make sure sliding window works. + + """ + if len(image_size) != num_spatial_dims: + raise ValueError(f"len(image_size) {len(image_size)} different from spatial dims {num_spatial_dims}.") + if len(roi_size) != num_spatial_dims: + raise ValueError(f"len(roi_size) {len(roi_size)} different from spatial dims {num_spatial_dims}.") + + scan_interval = [] + for i, o in zip(range(num_spatial_dims), overlap): + if roi_size[i] == image_size[i]: + scan_interval.append(int(roi_size[i])) + else: + interval = int(roi_size[i] * (1 - o)) + scan_interval.append(interval if interval > 0 else 1) + return tuple(scan_interval) + + +def _flatten_struct(seg_out): + dict_keys = None + seg_probs: tuple[torch.Tensor, ...] + if isinstance(seg_out, torch.Tensor): + seg_probs = (seg_out,) + elif isinstance(seg_out, Mapping): + dict_keys = sorted(seg_out.keys()) # track predictor's output keys + seg_probs = tuple(seg_out[k] for k in dict_keys) + else: + seg_probs = ensure_tuple(seg_out) + return dict_keys, seg_probs + + +def _pack_struct(seg_out, dict_keys=None): + if dict_keys is not None: + return dict(zip(dict_keys, seg_out)) + if isinstance(seg_out, (list, tuple)) and len(seg_out) == 1: + return seg_out[0] + return ensure_tuple(seg_out) diff --git a/models/vista3d/scripts/trainer.py b/models/vista3d/scripts/trainer.py new file mode 100644 index 00000000..ad1e06c0 --- /dev/null +++ b/models/vista3d/scripts/trainer.py @@ -0,0 +1,217 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence + +import numpy as np +import torch +from monai.config import IgniteInfo +from monai.engines.trainer import Trainer +from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch +from monai.inferers import Inferer, SimpleInferer +from monai.transforms import Transform +from monai.utils import RankFilter, min_version, optional_import +from monai.utils.enums import CommonKeys as Keys +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader + +from .utils import generate_prompt_pairs + +if TYPE_CHECKING: + from ignite.engine import Engine, EventEnum + from ignite.metrics import Metric +else: + Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine") + Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric") + EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum") + +__all__ = ["Vista3dTrainer"] + + +class Vista3dTrainer(Trainer): + """ + Supervised detection training method with image and label, inherits from ``Trainer`` and ``Workflow``. + Args: + device: an object representing the device on which to run. + max_epochs: the total epoch number for trainer to run. + train_data_loader: Ignite engine use data_loader to run, must be Iterable or torch.DataLoader. + detector: detector to train in the trainer, should be regular PyTorch `torch.nn.Module`. + optimizer: the optimizer associated to the detector, should be regular PyTorch optimizer from `torch.optim` + or its subclass. + epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`. + non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously + with respect to the host. For other cases, this argument has no effect. + prepare_batch: function to parse expected data (usually `image`,`box`, `label` and other detector args) + from `engine.state.batch` for every iteration, for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html. + iteration_update: the callable function for every iteration, expect to accept `engine` + and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`. + if not provided, use `self._iteration()` instead. for more details please refer to: + https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html. + inferer: inference method that execute model forward on input data, like: SlidingWindow, etc. + postprocessing: execute additional transformation for the model output data. + Typically, several Tensor based transforms composed by `Compose`. + key_train_metric: compute metric when every iteration completed, and save average value to + engine.state.metrics when epoch completlabel_set = np.arange(output_classes).tolist(). + key_train_metric is the main metric to compare and save the checkpoint into files. + additional_metrics: more Ignite metrics that also attach to Ignite Engine. + metric_cmp_fn: function to compare current key metric with previous best key metric value, + it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update + `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`. + train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: + CheckpointHandler, StatsHandler, etc. + amp: whether to enable auto-mixed-precision training, default is False. + event_names: additional custom ignite events that will register to the engine. + new events can be a list of str or `ignite.engine.events.EventEnum`. + event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html + #ignite.engine.engine.Engine.register_events. + decollate: whether to decollate the batch-first data to a list of data after model computation, + recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. + default to `True`. + optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None. + more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. + to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for + `device`, `non_blocking`. + amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. + """ + + def __init__( + self, + device: torch.device, + max_epochs: int, + train_data_loader: Iterable | DataLoader, + network: torch.nn.Module, + optimizer: Optimizer, + loss_function: Callable, + epoch_length: int | None = None, + non_blocking: bool = False, + prepare_batch: Callable = default_prepare_batch, + iteration_update: Callable[[Engine, Any], Any] | None = None, + inferer: Inferer | None = None, + postprocessing: Transform | None = None, + key_train_metric: dict[str, Metric] | None = None, + additional_metrics: dict[str, Metric] | None = None, + metric_cmp_fn: Callable = default_metric_cmp_fn, + train_handlers: Sequence | None = None, + amp: bool = False, + event_names: list[str | EventEnum] | None = None, + event_to_attr: dict | None = None, + decollate: bool = True, + optim_set_to_none: bool = False, + to_kwargs: dict | None = None, + amp_kwargs: dict | None = None, + hyper_kwargs: dict | None = None, + ) -> None: + super().__init__( + device=device, + max_epochs=max_epochs, + data_loader=train_data_loader, + epoch_length=epoch_length, + non_blocking=non_blocking, + prepare_batch=prepare_batch, + iteration_update=iteration_update, + postprocessing=postprocessing, + key_metric=key_train_metric, + additional_metrics=additional_metrics, + metric_cmp_fn=metric_cmp_fn, + handlers=train_handlers, + amp=amp, + event_names=event_names, + event_to_attr=event_to_attr, + decollate=decollate, + to_kwargs=to_kwargs, + amp_kwargs=amp_kwargs, + ) + + self.network = network + self.optimizer = optimizer + self.loss_function = loss_function + self.inferer = SimpleInferer() if inferer is None else inferer + self.optim_set_to_none = optim_set_to_none + self.hyper_kwargs = hyper_kwargs + self.logger.addFilter(RankFilter()) + + def _iteration(self, engine, batchdata: dict[str, torch.Tensor]): + """ + Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine. + Return below items in a dictionary: + - IMAGE: image Tensor data for model input, already moved to device. + Args: + engine: `Vista3DTrainer` to execute operation for an iteration. + batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data. + Raises: + ValueError: When ``batchdata`` is None. + """ + + if batchdata is None: + raise ValueError("Must provide batch data for current iteration.") + + inputs, labels = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs) + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: labels} + + label_set = engine.hyper_kwargs["label_set"] + output_classes = engine.hyper_kwargs["output_classes"] + if label_set is None: + label_set = np.arange(output_classes).tolist() + label_prompt, point, point_label, prompt_class, _ = generate_prompt_pairs( + labels, + label_set, + image_size=engine.hyper_kwargs["patch_size"], + max_point=engine.hyper_kwargs["max_point"], + max_prompt=engine.hyper_kwargs["max_prompt"], + max_backprompt=engine.hyper_kwargs["max_backprompt"], + max_foreprompt=engine.hyper_kwargs["max_foreprompt"], + drop_label_prob=engine.hyper_kwargs["drop_label_prob"], + drop_point_prob=engine.hyper_kwargs["drop_point_prob"], + include_background=not engine.hyper_kwargs["exclude_background"], + ) + + def _compute_pred_loss(): + outputs = engine.network( + input_images=inputs, + point_coords=point, + point_labels=point_label, + class_vector=label_prompt, + use_cfp=engine.hyper_kwargs["use_cfp"], + ) + # engine.state.output[Keys.PRED] = outputs + engine.fire_event(IterationEvents.FORWARD_COMPLETED) + loss, loss_n = torch.tensor(0.0, device=engine.state.device), torch.tensor(0.0, device=engine.state.device) + for id in range(len(prompt_class)): + loss += engine.loss_function(outputs[[id]].float(), labels == prompt_class[id]) + loss_n += 1.0 + loss /= max(loss_n, 1.0) + engine.state.output[Keys.LOSS] = loss + outputs = None + torch.cuda.empty_cache() + engine.fire_event(IterationEvents.LOSS_COMPLETED) + + engine.network.train() + engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none) + + if engine.amp and engine.scaler is not None: + with torch.cuda.amp.autocast(**engine.amp_kwargs): + _compute_pred_loss() + engine.scaler.scale(engine.state.output[Keys.LOSS]).backward() + engine.fire_event(IterationEvents.BACKWARD_COMPLETED) + engine.scaler.step(engine.optimizer) + engine.scaler.update() + else: + _compute_pred_loss() + engine.state.output[Keys.LOSS].backward() + engine.fire_event(IterationEvents.BACKWARD_COMPLETED) + engine.optimizer.step() + engine.fire_event(IterationEvents.MODEL_COMPLETED) + return engine.state.output diff --git a/models/vista3d/scripts/utils.py b/models/vista3d/scripts/utils.py new file mode 100644 index 00000000..1e831da6 --- /dev/null +++ b/models/vista3d/scripts/utils.py @@ -0,0 +1,470 @@ +import copy +import random + +import monai +import numpy as np +import torch +import torch.nn.functional as F +from monai.utils import ensure_tuple_rep + +ENABLE_SPECIAL = True +SPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128) +MERGE_LIST = { + 1: [25, 26], # hepatic tumor and vessel merge into liver + 4: [24], # pancreatic tumor merge into pancreas + 132: [57], # overlap with trachea merge into airway +} + + +def get_point_label(id): + # [B, N] + if id in SPECIAL_INDEX and ENABLE_SPECIAL: + return 2, 3 + else: + return 0, 1 + + +def convert_point_label(point_label, label_set=None): + if label_set is None or not ENABLE_SPECIAL: + return point_label + assert point_label.shape[0] == len(label_set) + for i in range(len(label_set)): + if label_set[i] in SPECIAL_INDEX: + for j in range(len(point_label[i])): + point_label[i, j] = point_label[i, j] + 2 if point_label[i, j] > -1 else point_label[i, j] + return point_label + + +def sample_points_patch_val( + labels, + patch_coords, + label_set, + prev_mask, + class_vector, + use_center=True, + mapped_label_set=None, + max_ppoint=1, + max_npoint=0, + **kwargs, +): + """ + Sample points for patch during sliding window validation. The prev_mask is only used for auto + interactive. + This function is called within vista3d.py and will use largested cc combine, do not use for iterative point evaluation. + """ + # only in validation when labels of the whole image is provided, sample points for every position + _, point_coords, point_labels, _ = generate_prompt_pairs_val( + labels[patch_coords], + label_set, + max_ppoint=max_ppoint, + max_npoint=max_npoint, + device=labels.device, + use_center=use_center, + ) + point_labels = convert_point_label(point_labels, label_set) + return point_coords, point_labels, torch.tensor(label_set).to(point_coords.device).unsqueeze(-1) + + +def erode3d(input_tensor, erosion=3): + # Define the structuring element + erosion = ensure_tuple_rep(erosion, 3) + structuring_element = torch.ones(1, 1, erosion[0], erosion[1], erosion[2]).to(input_tensor.device) + + # Pad the input tensor to handle border pixels + input_padded = F.pad( + input_tensor.float().unsqueeze(0).unsqueeze(0), + (erosion[2] // 2, erosion[2] // 2, erosion[1] // 2, erosion[1] // 2, erosion[0] // 2, erosion[0] // 2), + mode="constant", + value=1.0, + ) + + # Apply erosion operation + output = F.conv3d(input_padded, structuring_element, padding=0) + + # Set output values based on the minimum value within the structuring element + output = torch.where(output == torch.sum(structuring_element), 1.0, 0.0) + + return output.squeeze(0).squeeze(0) + + +def generate_prompt_pairs_val(labels, label_set=None, max_ppoint=1, max_npoint=0, device="cpu", use_center=False): + """ + Args: + labels: torch.tensor from dataload, [1,1,H,W,D] + label_set: the label list for the specific dataset + Returns: + label_prompt: [b, 1] + point: [b, N, 3] + point_label: [b, N] + prompt_class: [b, 1], exactly the same with label_prompt for label indexing for training lloss. + + """ + # class label number + assert labels.shape[0] == 1, "only support batch size 1" + labels = labels[0, 0] + label_prompt = torch.tensor(label_set).to(device).unsqueeze(-1) + unique_labels = labels.unique().cpu().numpy().tolist() + _point = [] + _point_label = [] + num_n = max_npoint + num_p = max_ppoint + for id in label_set: + if id in unique_labels: + plabels = labels == int(id) + nlabels = ~plabels + _plabels = erode3d(plabels) + # _plabels = monai.transforms.utils.get_largest_connected_component_mask(_plabels) + plabelpoints = torch.nonzero(_plabels).to(device) + if len(plabelpoints) == 0: + plabelpoints = torch.nonzero(plabels).to(device) + nlabelpoints = torch.nonzero(nlabels).to(device) + if use_center: + pmean = plabelpoints.float().mean(0) + pdis = ((plabelpoints - pmean) ** 2).sum(-1) + _, sorted_indices = torch.sort(pdis) + _point.append( + torch.stack( + [plabelpoints[sorted_indices[i]] for i in range(min(len(plabelpoints), num_p))] + + random.choices(nlabelpoints, k=min(len(nlabelpoints), num_n)) + + [torch.tensor([0, 0, 0], device=device)] + * (num_p + num_n - min(len(plabelpoints), num_p) - min(len(nlabelpoints), num_n)) + ) + ) + _point_label.append( + torch.tensor( + [1] * min(len(plabelpoints), num_p) + + [0.0] * min(len(nlabelpoints), num_n) + + [-1] * (num_p + num_n - min(len(plabelpoints), num_p) - min(len(nlabelpoints), num_n)) + ).to(device) + ) + + else: + _point.append( + torch.stack( + random.choices(plabelpoints, k=min(len(plabelpoints), num_p)) + + random.choices(nlabelpoints, k=min(len(nlabelpoints), num_n)) + + [torch.tensor([0, 0, 0], device=device)] + * (num_p + num_n - min(len(plabelpoints), num_p) - min(len(nlabelpoints), num_n)) + ) + ) + _point_label.append( + torch.tensor( + [1] * min(len(plabelpoints), num_p) + + [0.0] * min(len(nlabelpoints), num_n) + + [-1] * (num_p + num_n - min(len(plabelpoints), num_p) - min(len(nlabelpoints), num_n)) + ).to(device) + ) + else: + # pad the background labels + _point.append(torch.zeros(num_p + num_n, 3).to(device)) # all 0 + _point_label.append(torch.zeros(num_p + num_n).to(device) - 1) # -1 not a point + point = torch.stack(_point) + point_label = torch.stack(_point_label) + prompt_class = copy.deepcopy(label_prompt) + return label_prompt, point, point_label, prompt_class + + +def generate_prompt_pairs( + labels, + label_set=None, + image_size=None, + max_prompt=None, + max_foreprompt=None, + max_backprompt=1, + max_point=20, + include_background=True, + drop_label_prob=0.2, + drop_point_prob=0.2, + convert_to_disc=False, + radius=2, + metric_class=None, + ignore_labelset=False, + point_sampler=None, +): + """ + Args: + labels: torch.tensor from dataload, [1,1,H,W,D] + label_set: the label list for the specific dataset + total_prompt: int, number of total prompt + max_point: maximum number of points for each object + include_background: if include label=0 into training prompt. May casue issue in partial label + trainig. + metric_class: validation dice of each class. Must be the same dim with label_set + Returns: + label_prompt: [b, 1] + point: [b, N, 3] + point_label: [b, N] + prompt_class: [b, 1], exactly the same with label_prompt for label indexing for training lloss. + + """ + # class label number + assert labels.shape[0] == 1, "only support batch size 1" + labels = labels[0, 0] + point_mask = None + device = labels.device + unique_labels = labels.unique() + if include_background: + unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set))) + else: + unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set)) - {0}) + background_labels = list(set(label_set) - set(unique_labels)) + # during training, balance background and foreground prompts + if max_backprompt is not None: + if len(background_labels) > max_backprompt: + random.shuffle(background_labels) + background_labels = background_labels[:max_backprompt] + + if max_foreprompt is not None: + if len(unique_labels) > max_foreprompt: + random.shuffle(unique_labels) + unique_labels = unique_labels[:max_foreprompt] + + if max_prompt is not None: + if len(unique_labels) + len(background_labels) > max_prompt: + if len(unique_labels) > max_prompt: + # unique_labels = random.sample(unique_labels, max_prompt) + if metric_class is None: + prob = np.ones(len(unique_labels)) + else: + prob = ( + 1 - metric_class[np.array(unique_labels).astype(int)] + if len(label_set) == len(metric_class) + else 1 - metric_class[np.array(unique_labels).astype(int) - 1] + ) + prob = [w / sum(prob) for w in prob] + unique_labels = np.random.choice(unique_labels, size=max_prompt, replace=False, p=prob).tolist() + background_labels = [] + else: + background_labels = random.sample(background_labels, max_prompt - len(unique_labels)) + _point = [] + _point_label = [] + num_p = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))) + 1) + num_n = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2)))) + for id in unique_labels: + neg_id, pos_id = get_point_label(id) + plabels = labels == int(id) + nlabels = ~plabels + plabelpoints = torch.nonzero(plabels) + nlabelpoints = torch.nonzero(nlabels) + _point.append( + torch.stack( + random.choices(plabelpoints, k=min(len(plabelpoints), num_p)) + + random.choices(nlabelpoints, k=min(len(nlabelpoints), num_n)) + + [torch.tensor([0, 0, 0], device=device)] + * (num_p + num_n - min(len(plabelpoints), num_p) - min(len(nlabelpoints), num_n)) + ) + ) + _point_label.append( + torch.tensor( + [pos_id] * min(len(plabelpoints), num_p) + + [neg_id] * min(len(nlabelpoints), num_n) + + [-1] * (num_p + num_n - min(len(plabelpoints), num_p) - min(len(nlabelpoints), num_n)) + ).to(device) + ) + for _id in background_labels: + # pad the background labels + _point.append(torch.zeros(num_p + num_n, 3).to(device)) # all 0 + _point_label.append(torch.zeros(num_p + num_n).to(device) - 1) # -1 not a point + label_prompt = torch.tensor(unique_labels + background_labels).unsqueeze(-1).to(device).long() + point = torch.stack(_point) + point_label = torch.stack(_point_label) + prompt_class = copy.deepcopy(label_prompt) + if random.uniform(0, 1) < drop_label_prob and len(unique_labels) > 0: + label_prompt = None + # drop out the padded + pad = len(background_labels) + point = point[: len(point) - pad] + point_label = point_label[: len(point_label) - pad] + prompt_class = prompt_class[: len(prompt_class) - pad] + else: + if random.uniform(0, 1) < drop_point_prob: + point = None + point_label = None + if point is not None and convert_to_disc: + point_mask = convert_points_to_disc(image_size, point, point_label, radius=radius) + return label_prompt, point, point_label, prompt_class, point_mask + + +def get_gaussian_ball(image_size, radius=None): + if radius is None: + radius = image_size[0] // 3 + row_array = torch.arange(start=0, end=image_size[0], step=1, dtype=torch.float32) + col_array = torch.arange(start=0, end=image_size[1], step=1, dtype=torch.float32) + z_array = torch.arange(start=0, end=image_size[2], step=1, dtype=torch.float32) + coord_rows, coord_cols, coord_z = torch.meshgrid(z_array, col_array, row_array, indexing="ij") + coords = torch.stack((coord_rows, coord_cols, coord_z), dim=0) + center = ( + torch.tensor([image_size[0] // 2, image_size[1] // 2, image_size[2] // 2]) + .to(coords.device) + .unsqueeze(-1) + .unsqueeze(-1) + .unsqueeze(-1) + ) + ball = torch.exp(-((((coords - center) ** 2).sum(0) / (2 * radius**2)) ** 2)) + return ball + + +def convert_points_to_disc(image_size, point, point_label, radius=2, disc=False): + # [b, N, 3], [b, N] + # generate masks [b,2,h,w,d] + if not torch.is_tensor(point): + point = torch.from_numpy(point) + masks = torch.zeros([point.shape[0], 2, image_size[0], image_size[1], image_size[2]], device=point.device) + row_array = torch.arange(start=0, end=image_size[0], step=1, dtype=torch.float32, device=point.device) + col_array = torch.arange(start=0, end=image_size[1], step=1, dtype=torch.float32, device=point.device) + z_array = torch.arange(start=0, end=image_size[2], step=1, dtype=torch.float32, device=point.device) + coord_rows, coord_cols, coord_z = torch.meshgrid(z_array, col_array, row_array, indexing="ij") + # [1,3,h,w,d] -> [b, 2, 3, h,w,d] + coords = ( + torch.stack((coord_rows, coord_cols, coord_z), dim=0) + .unsqueeze(0) + .unsqueeze(0) + .repeat(point.shape[0], 2, 1, 1, 1, 1) + ) + for b in range(point.shape[0]): + for n in range(point.shape[1]): + if point_label[b, n] > -1: + channel = 0 if (point_label[b, n] == 0 or point_label[b, n] == 2) else 1 + if disc: + masks[b, channel] += ( + torch.pow(coords[b, channel] - point[b, n].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1), 2).sum(0) + < radius**2 + ) + else: + masks[b, channel] += torch.exp( + -torch.pow(coords[b, channel] - point[b, n].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1), 2).sum(0) + / (2 * radius**2) + ) + # masks[masks>1] = 1 + return masks + + +def get_window_idx_c(p, roi, s): + if p - roi // 2 < 0: + l, r = 0, roi + elif p + roi // 2 > s: + l, r = s - roi, s + else: + l, r = int(p) - roi // 2, int(p) + roi // 2 + return l, r + + +def get_window_idx(p, roi, s, center_only=True, margin=5): + l, r = get_window_idx_c(p, roi, s) + if center_only: + return [l], [r] + left_most = max(0, p - roi + margin) + right_most = min(s, p + roi - margin) + left = [left_most, right_most - roi, l] + right = [left_most + roi, right_most, r] + return left, right + + +def pad_previous_mask(inputs, roi_size, padvalue=0): + pad_size = [] + for k in range(len(inputs.shape) - 1, 1, -1): + diff = max(roi_size[k - 2] - inputs.shape[k], 0) + half = diff // 2 + pad_size.extend([half, diff - half]) + if any(pad_size): + inputs = torch.nn.functional.pad(inputs, pad=pad_size, mode="constant", value=padvalue) + return inputs, pad_size + + +def point_based_window_inferer( + inputs, + roi_size, + sw_batch_size, + predictor, + mode, + overlap, + sw_device, + device, + point_coords, + point_labels, + class_vector, + prompt_class, + prev_mask, + point_mask=None, + point_start=0, + **kwargs, +): + """ + Point based window inferer, crop a patch centered at the point, and perform inference. + Different patches are combined with gaussian weighted weights. + + Args: + predictor: partial(infer_wrapper, model). infer_wrapper transpose the model output. + The model output is [B, 1, H, W, D] which needs to be transposed to [1, B, H, W, D] + point_coords: [B, N, 3] + point_labels: [B, N] + class_vector: [B] + prev_mask: [1, B, H, W, D], THE VALUE IS BEFORE SIGMOID! + Returns: + stitched_output: [1, B, H, W, D]. The value is before sigmoid. + Notice: The function currently only supports SINGLE OBJECT INFERENCE with B=1. + """ + assert point_coords.shape[0] == 1, "Only supports single object point click" + image, pad = pad_previous_mask(copy.deepcopy(inputs), roi_size) + point_coords = point_coords + torch.tensor([pad[-2], pad[-4], pad[-6]]).to(point_coords.device) + prev_mask = pad_previous_mask(copy.deepcopy(prev_mask), roi_size)[0] if prev_mask is not None else None + stitched_output = None + center_only = True + for p in point_coords[0][point_start:]: + lx_, rx_ = get_window_idx(p[0], roi_size[0], image.shape[-3], center_only=center_only, margin=5) + ly_, ry_ = get_window_idx(p[1], roi_size[1], image.shape[-2], center_only=center_only, margin=5) + lz_, rz_ = get_window_idx(p[2], roi_size[2], image.shape[-1], center_only=center_only, margin=5) + for i in range(len(lx_)): + for j in range(len(ly_)): + for k in range(len(lz_)): + lx, rx, ly, ry, lz, rz = lx_[i], rx_[i], ly_[j], ry_[j], lz_[k], rz_[k] + unravel_slice = [ + slice(None), + slice(None), + slice(int(lx), int(rx)), + slice(int(ly), int(ry)), + slice(int(lz), int(rz)), + ] + batch_image = image[unravel_slice] + output = predictor( + batch_image, + point_coords=point_coords, + point_labels=point_labels, + class_vector=class_vector, + prompt_class=prompt_class, + patch_coords=unravel_slice, + prev_mask=prev_mask, + **kwargs, + ) + if stitched_output is None: + stitched_output = torch.zeros( + [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu" + ) + stitched_mask = torch.zeros( + [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu" + ) + stitched_output[unravel_slice] += output.to("cpu") + stitched_mask[unravel_slice] = 1 + # if stitched_mask is 0, then NaN value + stitched_output = stitched_output / stitched_mask + # revert padding + stitched_output = stitched_output[ + :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1] + ] + stitched_mask = stitched_mask[ + :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1] + ] + if prev_mask is not None: + prev_mask = prev_mask[ + :, + :, + pad[4] : image.shape[-3] - pad[5], + pad[2] : image.shape[-2] - pad[3], + pad[0] : image.shape[-1] - pad[1], + ] + prev_mask = prev_mask.to("cpu") + # for un-calculated place, use previous mask + stitched_output[stitched_mask < 1] = prev_mask[stitched_mask < 1] + + if not hasattr(stitched_output, "meta"): + stitched_output = monai.data.MetaTensor(stitched_output, affine=inputs.meta["affine"], meta=inputs.meta) + return stitched_output diff --git a/models/vista3d/scripts/vista3d/__init__.py b/models/vista3d/scripts/vista3d/__init__.py new file mode 100644 index 00000000..79662a36 --- /dev/null +++ b/models/vista3d/scripts/vista3d/__init__.py @@ -0,0 +1 @@ +from .build_vista3d import vista_model_registry diff --git a/models/vista3d/scripts/vista3d/build_vista3d.py b/models/vista3d/scripts/vista3d/build_vista3d.py new file mode 100755 index 00000000..85f531bf --- /dev/null +++ b/models/vista3d/scripts/vista3d/build_vista3d.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .modeling import VISTA3D2, ClassMappingClassify, PointMappingSAM, SegResNetDS2 + + +def build_vista3d_segresnet_decoder(encoder_embed_dim=48, in_channels=1, image_size=(96, 96, 96)): + segresnet = SegResNetDS2( + in_channels=in_channels, + blocks_down=(1, 2, 2, 4, 4), + norm="instance", + out_channels=encoder_embed_dim, + init_filters=encoder_embed_dim, + dsdepth=1, + ) + point_head = PointMappingSAM(feature_size=encoder_embed_dim, n_classes=512, last_supported=132) + class_head = ClassMappingClassify(n_classes=512, feature_size=encoder_embed_dim, use_mlp=True) + vista = VISTA3D2( + image_encoder=segresnet, class_head=class_head, point_head=point_head, feature_size=encoder_embed_dim + ) + return vista + + +vista_model_registry = {"vista3d_segresnet_d": build_vista3d_segresnet_decoder} diff --git a/models/vista3d/scripts/vista3d/modeling/__init__.py b/models/vista3d/scripts/vista3d/modeling/__init__.py new file mode 100755 index 00000000..c8165d51 --- /dev/null +++ b/models/vista3d/scripts/vista3d/modeling/__init__.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .class_head import ClassMappingClassify +from .point_head import PointMappingSAM +from .segresnetds import SegResNetDS2 +from .vista3d import VISTA3D2 diff --git a/models/vista3d/scripts/vista3d/modeling/class_head.py b/models/vista3d/scripts/vista3d/modeling/class_head.py new file mode 100644 index 00000000..46207613 --- /dev/null +++ b/models/vista3d/scripts/vista3d/modeling/class_head.py @@ -0,0 +1,51 @@ +import monai +import torch +import torch.nn as nn + + +class ClassMappingClassify(nn.Module): + def __init__(self, n_classes, feature_size, use_mlp=False): + super().__init__() + self.use_mlp = use_mlp + if use_mlp: + self.mlp = nn.Sequential( + nn.Linear(feature_size, feature_size), + nn.InstanceNorm1d(1), + nn.GELU(), + nn.Linear(feature_size, feature_size), + ) + self.class_embeddings = nn.Embedding(n_classes, feature_size) + self.image_post_mapping = nn.Sequential( + monai.networks.blocks.UnetrBasicBlock( + spatial_dims=3, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name="instance", + res_block=True, + ), + monai.networks.blocks.UnetrBasicBlock( + spatial_dims=3, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name="instance", + res_block=True, + ), + ) + + def forward(self, src, class_vector): + b, c, h, w, d = src.shape + src = self.image_post_mapping(src) + class_embedding = self.class_embeddings(class_vector) + if self.use_mlp: + class_embedding = self.mlp(class_embedding) + # [b,1,feat] @ [1,feat,dim], batch dimension become class_embedding batch dimension. + masks = [] + for i in range(b): + mask = (class_embedding @ src[[i]].view(1, c, h * w * d)).view(-1, 1, h, w, d) + masks.append(mask) + masks = torch.cat(masks, 1) + return masks, class_embedding diff --git a/models/vista3d/scripts/vista3d/modeling/point_head.py b/models/vista3d/scripts/vista3d/modeling/point_head.py new file mode 100644 index 00000000..36258361 --- /dev/null +++ b/models/vista3d/scripts/vista3d/modeling/point_head.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +import numpy as np +import torch +import torch.nn as nn +from monai.utils import optional_import + +from .sam_blocks import MLP, PositionEmbeddingRandom, TwoWayTransformer + +rearrange, _ = optional_import("einops", name="rearrange") + + +class PointMappingSAM(nn.Module): + def __init__(self, feature_size, max_prompt=32, num_add_mask_tokens=2, n_classes=512, last_supported=132): + super().__init__() + transformer_dim = feature_size + self.max_prompt = max_prompt + self.feat_downsample = nn.Sequential( + nn.Conv3d(in_channels=feature_size, out_channels=feature_size, kernel_size=3, stride=2, padding=1), + nn.InstanceNorm3d(feature_size), + nn.GELU(), + nn.Conv3d(in_channels=feature_size, out_channels=transformer_dim, kernel_size=3, stride=1, padding=1), + nn.InstanceNorm3d(feature_size), + ) + + self.mask_downsample = nn.Conv3d(in_channels=2, out_channels=2, kernel_size=3, stride=2, padding=1) + + self.transformer = TwoWayTransformer(depth=2, embedding_dim=transformer_dim, mlp_dim=512, num_heads=4) + self.pe_layer = PositionEmbeddingRandom(transformer_dim // 2) + self.point_embeddings = nn.ModuleList([nn.Embedding(1, transformer_dim), nn.Embedding(1, transformer_dim)]) + self.not_a_point_embed = nn.Embedding(1, transformer_dim) + self.special_class_embed = nn.Embedding(1, transformer_dim) + self.mask_tokens = nn.Embedding(1, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose3d(transformer_dim, transformer_dim, kernel_size=3, stride=2, padding=1, output_padding=1), + nn.InstanceNorm3d(transformer_dim), + nn.GELU(), + nn.Conv3d(transformer_dim, transformer_dim, kernel_size=3, stride=1, padding=1), + ) + + self.output_hypernetworks_mlps = MLP(transformer_dim, transformer_dim, transformer_dim, 3) + + # MultiMask output + self.num_add_mask_tokens = num_add_mask_tokens + self.output_add_hypernetworks_mlps = nn.ModuleList( + [MLP(transformer_dim, transformer_dim, transformer_dim, 3) for i in range(self.num_add_mask_tokens)] + ) + # class embedding + self.n_classes = n_classes + self.last_supported = last_supported + self.class_embeddings = nn.Embedding(n_classes, feature_size) + self.zeroshot_embed = nn.Embedding(1, transformer_dim) + self.supported_embed = nn.Embedding(1, transformer_dim) + + def forward(self, out, point_coords, point_labels, class_vector=None): + # downsample out + out_low = self.feat_downsample(out) + out_shape = out.shape[-3:] + out = None + torch.cuda.empty_cache() + # embed points + points = point_coords + 0.5 # Shift to center of pixel + point_embedding = self.pe_layer.forward_with_coords(points, out_shape) + point_embedding[point_labels == -1] = 0.0 + point_embedding[point_labels == -1] += self.not_a_point_embed.weight + point_embedding[point_labels == 0] += self.point_embeddings[0].weight + point_embedding[point_labels == 1] += self.point_embeddings[1].weight + point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight + point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight + output_tokens = self.mask_tokens.weight + + output_tokens = output_tokens.unsqueeze(0).expand(point_embedding.size(0), -1, -1) + if class_vector is None: + tokens_all = torch.cat( + ( + output_tokens, + point_embedding, + self.supported_embed.weight.unsqueeze(0).expand(point_embedding.size(0), -1, -1), + ), + dim=1, + ) + # tokens_all = torch.cat((output_tokens, point_embedding), dim=1) + else: + class_embeddings = [] + for i in class_vector: + if i > self.last_supported: + class_embeddings.append(self.zeroshot_embed.weight) + else: + class_embeddings.append(self.supported_embed.weight) + class_embeddings = torch.stack(class_embeddings) + tokens_all = torch.cat((output_tokens, point_embedding, class_embeddings), dim=1) + # cross attention + masks = [] + max_prompt = self.max_prompt + for i in range(int(np.ceil(tokens_all.shape[0] / max_prompt))): + # remove variables in previous for loops to save peak memory for self.transformer + src, upscaled_embedding, hyper_in = None, None, None + torch.cuda.empty_cache() + idx = (i * max_prompt, min((i + 1) * max_prompt, tokens_all.shape[0])) + tokens = tokens_all[idx[0] : idx[1]] + src = torch.repeat_interleave(out_low, tokens.shape[0], dim=0) + pos_src = torch.repeat_interleave(self.pe_layer(out_low.shape[-3:]).unsqueeze(0), tokens.shape[0], dim=0) + b, c, h, w, d = src.shape + hs, src = self.transformer(src, pos_src, tokens) + mask_tokens_out = hs[:, :1, :] + hyper_in = self.output_hypernetworks_mlps(mask_tokens_out) + src = src.transpose(1, 2).view(b, c, h, w, d) + upscaled_embedding = self.output_upscaling(src) + b, c, h, w, d = upscaled_embedding.shape + masks.append((hyper_in @ upscaled_embedding.view(b, c, h * w * d)).view(b, -1, h, w, d)) + masks = torch.vstack(masks) + return masks diff --git a/models/vista3d/scripts/vista3d/modeling/sam_blocks.py b/models/vista3d/scripts/vista3d/modeling/sam_blocks.py new file mode 100644 index 00000000..402f9fb0 --- /dev/null +++ b/models/vista3d/scripts/vista3d/modeling/sam_blocks.py @@ -0,0 +1,292 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Optional, Tuple, Type + +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward(self, image_embedding: Tensor, image_pe: Tensor, point_embedding: Tensor) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w, d = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer(queries=queries, keys=keys, query_pe=point_embedding, key_pe=image_pe) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward(self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__(self, embedding_dim: int, num_heads: int, downsample_rate: int = 1) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((3, num_pos_feats))) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + # [bs=1,N=2,2] @ [2,128] + # [bs=1, N=2, 128] + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + # [bs=1, N=2, 128+128=256] + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w, d = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w, d), device=device, dtype=torch.float32) + x_embed = grid.cumsum(dim=0) - 0.5 + y_embed = grid.cumsum(dim=1) - 0.5 + z_embed = grid.cumsum(dim=2) - 0.5 + x_embed = x_embed / h + y_embed = y_embed / w + z_embed = z_embed / d + pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1)) + return pe.permute(3, 0, 1, 2) # C x H x W + + def forward_with_coords(self, coords_input: torch.Tensor, image_size: Tuple[int, int]) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[0] + coords[:, :, 1] = coords[:, :, 1] / image_size[1] + coords[:, :, 2] = coords[:, :, 2] / image_size[2] + return self._pe_encoding(coords.to(torch.float)) # B x N x C + + +class MLPBlock(nn.Module): + def __init__(self, embedding_dim: int, mlp_dim: int, act: Type[nn.Module] = nn.GELU) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +class MLP(nn.Module): + def __init__( + self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/models/vista3d/scripts/vista3d/modeling/segresnetds.py b/models/vista3d/scripts/vista3d/modeling/segresnetds.py new file mode 100644 index 00000000..44686e8a --- /dev/null +++ b/models/vista3d/scripts/vista3d/modeling/segresnetds.py @@ -0,0 +1,488 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Callable +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +from monai.networks.blocks.upsample import UpSample +from monai.networks.layers.factories import Act, Conv, Norm, split_args +from monai.networks.layers.utils import get_act_layer, get_norm_layer +from monai.utils import UpsampleMode, has_option + +__all__ = ["SegResNetDS2"] + + +def scales_for_resolution(resolution: tuple | list, n_stages: int | None = None): + """ + A helper function to compute a schedule of scale at different downsampling levels, + given the input resolution. + + .. code-block:: python + + scales_for_resolution(resolution=[1,1,5], n_stages=5) + + Args: + resolution: input image resolution (in mm) + n_stages: optionally the number of stages of the network + """ + + ndim = len(resolution) + res = np.array(resolution) + if not all(res > 0): + raise ValueError("Resolution must be positive") + + nl = np.floor(np.log2(np.max(res) / res)).astype(np.int32) + scales = [tuple(np.where(2**i >= 2**nl, 1, 2)) for i in range(max(nl))] + if n_stages and n_stages > max(nl): + scales = scales + [(2,) * ndim] * (n_stages - max(nl)) + else: + scales = scales[:n_stages] + return scales + + +def aniso_kernel(scale: tuple | list): + """ + A helper function to compute kernel_size, padding and stride for the given scale + + Args: + scale: scale from a current scale level + """ + kernel_size = [3 if scale[k] > 1 else 1 for k in range(len(scale))] + padding = [k // 2 for k in kernel_size] + return kernel_size, padding, scale + + +class SegResBlock(nn.Module): + """ + Residual network block used SegResNet based on `3D MRI brain tumor segmentation using autoencoder regularization + `_. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + norm: tuple | str, + kernel_size: tuple | int = 3, + act: tuple | str = "relu", + ) -> None: + """ + Args: + spatial_dims: number of spatial dimensions, could be 1, 2 or 3. + in_channels: number of input channels. + norm: feature normalization type and arguments. + kernel_size: convolution kernel size. Defaults to 3. + act: activation type and arguments. Defaults to ``RELU``. + """ + super().__init__() + + if isinstance(kernel_size, (tuple, list)): + padding = tuple(k // 2 for k in kernel_size) + else: + padding = kernel_size // 2 # type: ignore + + self.norm1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels) + self.act1 = get_act_layer(act) + self.conv1 = Conv[Conv.CONV, spatial_dims]( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=1, + padding=padding, + bias=False, + ) + + self.norm2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=in_channels) + self.act2 = get_act_layer(act) + self.conv2 = Conv[Conv.CONV, spatial_dims]( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=1, + padding=padding, + bias=False, + ) + + def forward(self, x): + identity = x + x = self.conv1(self.act1(self.norm1(x))) + x = self.conv2(self.act2(self.norm2(x))) + x += identity + return x + + +class SegResEncoder(nn.Module): + """ + SegResEncoder based on the econder structure in `3D MRI brain tumor segmentation using autoencoder regularization + `_. + + Args: + spatial_dims: spatial dimension of the input data. Defaults to 3. + init_filters: number of output channels for initial convolution layer. Defaults to 32. + in_channels: number of input channels for the network. Defaults to 1. + out_channels: number of output channels for the network. Defaults to 2. + act: activation type and arguments. Defaults to ``RELU``. + norm: feature normalization type and arguments. Defaults to ``BATCH``. + blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``. + head_module: optional callable module to apply to the final features. + anisotropic_scales: optional list of scale for each scale level. + """ + + def __init__( + self, + spatial_dims: int = 3, + init_filters: int = 32, + in_channels: int = 1, + act: tuple | str = "relu", + norm: tuple | str = "batch", + blocks_down: tuple = (1, 2, 2, 4), + head_module: nn.Module | None = None, + anisotropic_scales: tuple | None = None, + ): + super().__init__() + + if spatial_dims not in (1, 2, 3): + raise ValueError("`spatial_dims` can only be 1, 2 or 3.") + + # ensure normalization has affine trainable parameters (if not specified) + norm = split_args(norm) + if has_option(Norm[norm[0], spatial_dims], "affine"): + norm[1].setdefault("affine", True) # type: ignore + + # ensure activation is inplace (if not specified) + act = split_args(act) + if has_option(Act[act[0]], "inplace"): + act[1].setdefault("inplace", True) # type: ignore + + filters = init_filters # base number of features + + kernel_size, padding, _ = aniso_kernel(anisotropic_scales[0]) if anisotropic_scales else (3, 1, 1) + self.conv_init = Conv[Conv.CONV, spatial_dims]( + in_channels=in_channels, + out_channels=filters, + kernel_size=kernel_size, + padding=padding, + stride=1, + bias=False, + ) + self.layers = nn.ModuleList() + + for i in range(len(blocks_down)): + level = nn.ModuleDict() + + kernel_size, padding, stride = aniso_kernel(anisotropic_scales[i]) if anisotropic_scales else (3, 1, 2) + blocks = [ + SegResBlock(spatial_dims=spatial_dims, in_channels=filters, kernel_size=kernel_size, norm=norm, act=act) + for _ in range(blocks_down[i]) + ] + level["blocks"] = nn.Sequential(*blocks) + + if i < len(blocks_down) - 1: + level["downsample"] = Conv[Conv.CONV, spatial_dims]( + in_channels=filters, + out_channels=2 * filters, + bias=False, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + else: + level["downsample"] = nn.Identity() + + self.layers.append(level) + filters *= 2 + + self.head_module = head_module + self.in_channels = in_channels + self.blocks_down = blocks_down + self.init_filters = init_filters + self.norm = norm + self.act = act + self.spatial_dims = spatial_dims + + def _forward(self, x: torch.Tensor) -> list[torch.Tensor]: + outputs = [] + x = self.conv_init(x) + + for level in self.layers: + x = level["blocks"](x) + outputs.append(x) + x = level["downsample"](x) + + if self.head_module is not None: + outputs = self.head_module(outputs) + + return outputs + + def forward(self, x: torch.Tensor) -> list[torch.Tensor]: + return self._forward(x) + + +class SegResNetDS2(nn.Module): + """ + SegResNetDS based on `3D MRI brain tumor segmentation using autoencoder regularization + `_. + It is similar to https://docs.monai.io/en/stable/networks.html#segresnet, with several + improvements including deep supervision and non-isotropic kernel support. + + Args: + spatial_dims: spatial dimension of the input data. Defaults to 3. + init_filters: number of output channels for initial convolution layer. Defaults to 32. + in_channels: number of input channels for the network. Defaults to 1. + out_channels: number of output channels for the network. Defaults to 2. + act: activation type and arguments. Defaults to ``RELU``. + norm: feature normalization type and arguments. Defaults to ``BATCH``. + blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``. + blocks_up: number of upsample blocks (optional). + dsdepth: number of levels for deep supervision. This will be the length of the list of outputs at each scale level. + At dsdepth==1,only a single output is returned. + preprocess: optional callable function to apply before the model's forward pass + resolution: optional input image resolution. When provided, the network will first use non-isotropic kernels to bring + image spacing into an approximately isotropic space. + Otherwise, by default, the kernel size and downsampling is always isotropic. + + """ + + def __init__( + self, + spatial_dims: int = 3, + init_filters: int = 32, + in_channels: int = 1, + out_channels: int = 2, + act: tuple | str = "relu", + norm: tuple | str = "batch", + blocks_down: tuple = (1, 2, 2, 4), + blocks_up: tuple | None = None, + dsdepth: int = 1, + preprocess: nn.Module | Callable | None = None, + upsample_mode: UpsampleMode | str = "deconv", + resolution: tuple | None = None, + ): + super().__init__() + + if spatial_dims not in (1, 2, 3): + raise ValueError("`spatial_dims` can only be 1, 2 or 3.") + + self.spatial_dims = spatial_dims + self.init_filters = init_filters + self.in_channels = in_channels + self.out_channels = out_channels + self.act = act + self.norm = norm + self.blocks_down = blocks_down + self.dsdepth = max(dsdepth, 1) + self.resolution = resolution + self.preprocess = preprocess + + if resolution is not None: + if not isinstance(resolution, (list, tuple)): + raise TypeError("resolution must be a tuple") + elif not all(r > 0 for r in resolution): + raise ValueError("resolution must be positive") + + # ensure normalization had affine trainable parameters (if not specified) + norm = split_args(norm) + if has_option(Norm[norm[0], spatial_dims], "affine"): + norm[1].setdefault("affine", True) # type: ignore + + # ensure activation is inplace (if not specified) + act = split_args(act) + if has_option(Act[act[0]], "inplace"): + act[1].setdefault("inplace", True) # type: ignore + + anisotropic_scales = None + if resolution: + anisotropic_scales = scales_for_resolution(resolution, n_stages=len(blocks_down)) + self.anisotropic_scales = anisotropic_scales + + self.encoder = SegResEncoder( + spatial_dims=spatial_dims, + init_filters=init_filters, + in_channels=in_channels, + act=act, + norm=norm, + blocks_down=blocks_down, + anisotropic_scales=anisotropic_scales, + ) + + n_up = len(blocks_down) - 1 + if blocks_up is None: + blocks_up = (1,) * n_up # assume 1 upsample block per level + self.blocks_up = blocks_up + + filters = init_filters * 2**n_up + self.up_layers = nn.ModuleList() + self.up_layers_auto = nn.ModuleList() + + for i in range(n_up): + filters = filters // 2 + kernel_size, _, stride = ( + aniso_kernel(anisotropic_scales[len(blocks_up) - i - 1]) if anisotropic_scales else (3, 1, 2) + ) + + level = nn.ModuleDict() + level_auto = nn.ModuleDict() + level["upsample"] = UpSample( + mode=upsample_mode, + spatial_dims=spatial_dims, + in_channels=2 * filters, + out_channels=filters, + kernel_size=kernel_size, + scale_factor=stride, + bias=False, + align_corners=False, + ) + level_auto["upsample"] = UpSample( + mode=upsample_mode, + spatial_dims=spatial_dims, + in_channels=2 * filters, + out_channels=filters, + kernel_size=kernel_size, + scale_factor=stride, + bias=False, + align_corners=False, + ) + blocks = [ + SegResBlock(spatial_dims=spatial_dims, in_channels=filters, kernel_size=kernel_size, norm=norm, act=act) + for _ in range(blocks_up[i]) + ] + level["blocks"] = nn.Sequential(*blocks) + blocks = [ + SegResBlock(spatial_dims=spatial_dims, in_channels=filters, kernel_size=kernel_size, norm=norm, act=act) + for _ in range(blocks_up[i]) + ] + level_auto["blocks"] = nn.Sequential(*blocks) + if len(blocks_up) - i <= dsdepth: # deep supervision heads + level["head"] = Conv[Conv.CONV, spatial_dims]( + in_channels=filters, out_channels=out_channels, kernel_size=1, bias=True + ) + level_auto["head"] = Conv[Conv.CONV, spatial_dims]( + in_channels=filters, out_channels=out_channels, kernel_size=1, bias=True + ) + else: + level["head"] = nn.Identity() + level_auto["head"] = nn.Identity() + + self.up_layers.append(level) + self.up_layers_auto.append(level_auto) + + if n_up == 0: # in a corner case of flat structure (no downsampling), attache a single head + level = nn.ModuleDict( + { + "upsample": nn.Identity(), + "blocks": nn.Identity(), + "head": Conv[Conv.CONV, spatial_dims]( + in_channels=filters, out_channels=out_channels, kernel_size=1, bias=True + ), + } + ) + level_auto = nn.ModuleDict( + { + "upsample": nn.Identity(), + "blocks": nn.Identity(), + "head": Conv[Conv.CONV, spatial_dims]( + in_channels=filters, out_channels=out_channels, kernel_size=1, bias=True + ), + } + ) + self.up_layers.append(level) + self.up_layers_auto.append(level_auto) + + def shape_factor(self): + """ + Calculate the factors (divisors) that the input image shape must be divisible by + """ + if self.anisotropic_scales is None: + d = [2 ** (len(self.blocks_down) - 1)] * self.spatial_dims + else: + d = list(np.prod(np.array(self.anisotropic_scales[:-1]), axis=0)) + return d + + def is_valid_shape(self, x): + """ + Calculate if the input shape is divisible by the minimum factors for the current network configuration + """ + a = [i % j == 0 for i, j in zip(x.shape[2:], self.shape_factor())] + return all(a) + + def _forward(self, x: torch.Tensor, with_point, with_label) -> Union[None, torch.Tensor, list[torch.Tensor]]: + if self.preprocess is not None: + x = self.preprocess(x) + + if not self.is_valid_shape(x): + raise ValueError(f"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}") + + x_down = self.encoder(x) + + x_down.reverse() + x = x_down.pop(0) + + if len(x_down) == 0: + x_down = [torch.zeros(1, device=x.device, dtype=x.dtype)] + + outputs: list[torch.Tensor] = [] + outputs_auto: list[torch.Tensor] = [] + x_ = x.clone() + if with_point: + i = 0 + for level in self.up_layers: + x = level["upsample"](x) + x = x + x_down[i] + x = level["blocks"](x) + + if len(self.up_layers) - i <= self.dsdepth: + outputs.append(level["head"](x)) + i = i + 1 + + outputs.reverse() + x = x_ + if with_label: + i = 0 + for level in self.up_layers_auto: + x = level["upsample"](x) + x = x + x_down[i] + x = level["blocks"](x) + + if len(self.up_layers) - i <= self.dsdepth: + outputs_auto.append(level["head"](x)) + i = i + 1 + + outputs_auto.reverse() + + # in eval() mode, always return a single final output + if not self.training or len(outputs) == 1: + outputs = outputs[0] if len(outputs) == 1 else outputs + + if not self.training or len(outputs_auto) == 1: + outputs_auto = outputs_auto[0] if len(outputs_auto) == 1 else outputs_auto + + # return a list of DS outputs + return outputs, outputs_auto + + def forward( + self, x: torch.Tensor, with_point=True, with_label=True, **kwargs + ) -> Union[None, torch.Tensor, list[torch.Tensor]]: + return self._forward(x, with_point, with_label) + + def set_auto_grad(self, auto_freeze=False, point_freeze=False): + for param in self.encoder.parameters(): + param.requires_grad = (not auto_freeze) and (not point_freeze) + + for param in self.up_layers_auto.parameters(): + param.requires_grad = not auto_freeze + + for param in self.up_layers.parameters(): + param.requires_grad = not point_freeze diff --git a/models/vista3d/scripts/vista3d/modeling/vista3d.py b/models/vista3d/scripts/vista3d/modeling/vista3d.py new file mode 100644 index 00000000..d9143728 --- /dev/null +++ b/models/vista3d/scripts/vista3d/modeling/vista3d.py @@ -0,0 +1,262 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import monai +import numpy as np +import torch +import torch.nn as nn +from monai.utils import optional_import +from scripts.monai_trans_utils import get_largest_connected_component_mask as lcc +from scripts.utils import convert_points_to_disc, sample_points_patch_val + +rearrange, _ = optional_import("einops", name="rearrange") +NINF_VALUE = -9999 +PINF_VALUE = 9999 + + +class VISTA3D2(nn.Module): + def __init__(self, image_encoder, class_head, point_head, feature_size): + super().__init__() + self.image_encoder = image_encoder + self.class_head = class_head + self.point_head = point_head + self.image_embeddings = None + self.weight_mapper = nn.Sequential( + nn.Linear(feature_size, 4 * feature_size), + nn.GELU(), + nn.InstanceNorm1d(4 * feature_size), + nn.Linear(4 * feature_size, 1), + ) + self.auto_freeze = False + self.point_freeze = False + + def precompute_embedding(self, input_images): + """precompute image embedding, require sliding window inference""" + raise NotImplementedError + + def clear_cache(self): + self.image_embeddings = None + + def get_bs(self, class_vector, point_coords): + if class_vector is None: + assert point_coords is not None, "prompt is required" + return point_coords.shape[0] + else: + return class_vector.shape[0] + + def update_point_to_patch(self, patch_coords, point_coords, point_labels): + """Update point_coords with respect to patch coords. + If point is outside of the patch, remove the coordinates and set label to -1 + """ + patch_ends = [patch_coords[-3].stop, patch_coords[-2].stop, patch_coords[-1].stop] + patch_starts = [patch_coords[-3].start, patch_coords[-2].start, patch_coords[-1].start] + # update point coords + patch_starts = torch.tensor(patch_starts, device=point_coords.device).unsqueeze(0).unsqueeze(0) + patch_ends = torch.tensor(patch_ends, device=point_coords.device).unsqueeze(0).unsqueeze(0) + # [1 N 1] + indices = torch.logical_and( + ((point_coords - patch_starts) > 0).all(2), ((patch_ends - point_coords) > 0).all(2) + ) + # check if it's within patch coords + point_coords = point_coords.clone() - patch_starts + point_labels = point_labels.clone() + if indices.any(): + point_labels[~indices] = -1 + point_coords[~indices] = 0 + # also remove padded points, mainly used for inference. + not_pad_indices = (point_labels != -1).any(0) + point_coords = point_coords[:, not_pad_indices] + point_labels = point_labels[:, not_pad_indices] + else: + point_coords = None + point_labels = None + return point_coords, point_labels + + def connected_components_combine(self, logits, point_logits, point_coords, point_labels, mapping_index, thred=0.5): + """ + Combine auto results with point click response, or combine previous mask with point click response. + For mapping_index with point clicks, NaN values in logits will be replaced with point_logits. Meanwhile, the added/removed + region in point clicks must be updated by the lcc function. + Notice, if a positive point is within logits/prev_mask, the components containing the positive point will be added. + """ + logits = logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits + _logits = logits[mapping_index] + inside = [] + for i in range(_logits.shape[0]): + inside.append( + np.any( + [ + _logits[i, 0, round(p[0].item()), round(p[1].item()), round(p[2].item())].item() > 0 + for p in point_coords[i] + ] + ) + ) + inside = torch.tensor(inside).to(logits.device) + nan_mask = torch.isnan(_logits) + _logits = torch.nan_to_num(_logits, nan=NINF_VALUE).sigmoid() + pos_region = point_logits.sigmoid() > thred + diff_pos = torch.logical_and( + torch.logical_or((_logits <= thred), inside.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)), + pos_region, + ) + diff_neg = torch.logical_and((_logits > thred), ~pos_region) + cc = lcc(diff_pos, diff_neg, point_coords=point_coords, point_labels=point_labels) + # cc is the region that can be updated by point_logits. + cc = cc.to(logits.device) + # Need to replace NaN with point_logits. diff_neg will never lie in nan_mask, only remove unconnected positive region. + uc_pos_region = torch.logical_and(pos_region, ~cc) + fill_mask = torch.logical_and(nan_mask, uc_pos_region) + if fill_mask.any(): + # fill in the mean negative value + point_logits[fill_mask] = -1 + # replace logits nan value and cc with point_logits + cc = torch.logical_or(nan_mask, cc).to(logits.dtype) + logits[mapping_index] *= 1 - cc + logits[mapping_index] += cc * point_logits + return logits + + def gaussian_combine(self, logits, point_logits, point_coords, point_labels, mapping_index, radius): + if radius is None: + radius = min(point_logits.shape[-3:]) // 5 # empirical value 5 + weight = 1 - convert_points_to_disc(point_logits.shape[-3:], point_coords, point_labels, radius=radius).sum( + 1, keepdims=True + ) + weight[weight < 0] = 0 + logits = logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits + logits[mapping_index] *= weight + logits[mapping_index] += (1 - weight) * point_logits + return logits + + def set_auto_grad(self, auto_freeze=False, point_freeze=False): + if auto_freeze != self.auto_freeze: + if hasattr(self.image_encoder, "set_auto_grad"): + self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze) + else: + for param in self.image_encoder.parameters(): + param.requires_grad = (not auto_freeze) and (not point_freeze) + for param in self.class_head.parameters(): + param.requires_grad = not auto_freeze + self.auto_freeze = auto_freeze + + if point_freeze != self.point_freeze: + if hasattr(self.image_encoder, "set_auto_grad"): + self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze) + else: + for param in self.image_encoder.parameters(): + param.requires_grad = (not auto_freeze) and (not point_freeze) + for param in self.point_head.parameters(): + param.requires_grad = not point_freeze + self.point_freeze = point_freeze + + def forward( + self, + input_images, + point_coords=None, + point_labels=None, + class_vector=None, + prompt_class=None, + patch_coords=None, + labels=None, + label_set=None, + prev_mask=None, + radius=None, + val_point_sampler=None, + transpose=False, + **kwargs, + ): + image_size = input_images.shape[-3:] + device = input_images.device + + if point_coords is None and class_vector is None: + # For TRT conversion, no prompts are given. + return NINF_VALUE + torch.zeros([1, 1, *image_size], device=device) + + bs = self.get_bs(class_vector, point_coords) + if patch_coords is not None and point_coords is not None: + """patch_coords is passed from monai_utils.sliding_window_inferer.""" + # Automatic point sample in validation + if labels is not None and label_set is not None: + # if labels is not None, sample from labels for each patch. + if val_point_sampler is None: + val_point_sampler = sample_points_patch_val + point_coords, point_labels, prompt_class = val_point_sampler( + labels, patch_coords, label_set, prev_mask, class_vector + ) + if prompt_class[0].item() == 0: + point_labels[0] = -1 + labels, prev_mask = None, None + # User provided click points in inference + else: + point_coords, point_labels = self.update_point_to_patch(patch_coords, point_coords, point_labels) + + if point_coords is not None and point_labels is not None: + # remove points that used for padding purposes (point_label = -1) + mapping_index = ((point_labels != -1).sum(1) > 0).to(torch.bool) + if mapping_index.any(): + point_coords = point_coords[mapping_index] + point_labels = point_labels[mapping_index] + if prompt_class is not None: + prompt_class = prompt_class[mapping_index] + else: + if self.auto_freeze or (class_vector is None and patch_coords is None): + # if auto_freeze, point prompt must exist to allow loss backward + # in training, class_vector and point cannot both be None due to loss.backward() + mapping_index.fill_(True) + else: + point_coords, point_labels = None, None + + if point_coords is None and class_vector is None: + return NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device) + + if self.image_embeddings is not None and kwargs.get("keep_cache", False) and class_vector is None: + out, out_auto = self.image_embeddings, None + else: + out, out_auto = self.image_encoder( + input_images, with_point=point_coords is not None, with_label=class_vector is not None + ) + input_images = None + + # force releasing memories that set to None + torch.cuda.empty_cache() + + if class_vector is not None: + logits, _ = self.class_head(out_auto, class_vector) + if point_coords is not None: + point_logits = self.point_head(out, point_coords, point_labels, class_vector=prompt_class) + if patch_coords is None: + # during training, using gaussian ball combine + logits = self.gaussian_combine( + logits, point_logits, point_coords, point_labels, mapping_index, radius + ) + else: + # during validation use largest component + logits = self.connected_components_combine( + logits, point_logits, point_coords, point_labels, mapping_index + ) + else: + logits = NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device, dtype=out.dtype) + logits[mapping_index] = self.point_head(out, point_coords, point_labels, class_vector=prompt_class) + if prev_mask is not None and patch_coords is not None: + logits = self.connected_components_combine( + prev_mask[patch_coords].transpose(1, 0).to(logits.device), + logits[mapping_index], + point_coords, + point_labels, + mapping_index, + ) + + if kwargs.get("keep_cache", False) and class_vector is None: + self.image_embeddings = out.detach() + if transpose: + logits = logits.transpose(1, 0) + return logits