Skip to content

Commit

Permalink
Fix OOM issue in vista3d bundle and update readme (#657)
Browse files Browse the repository at this point in the history
Fixes # .

### Description
A few sentences describing the changes proposed in this pull request.

### Status
**Ready/Work in progress/Hold**

### Please ensure all the checkboxes:
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Codeformat tests passed locally by running `./runtests.sh
--codeformat`.
- [ ] In-line docstrings updated.
- [ ] Update `version` and `changelog` in `metadata.json` if changing an
existing bundle.
- [ ] Please ensure the naming rules in config files meet our
requirements (please refer to: `CONTRIBUTING.md`).
- [ ] Ensure versions of packages such as `monai`, `pytorch` and `numpy`
are correct in `metadata.json`.
- [ ] Descriptions should be consistent with the content, such as
`eval_metrics` of the provided weights and TorchScript modules.
- [ ] Files larger than 25MB are excluded and replaced by providing
download links in `large_file.yml`.
- [ ] Avoid using path that contains personal information within config
files (such as use `/home/your_name/` for `"bundle_root"`).

---------

Signed-off-by: heyufan1995 <[email protected]>
Signed-off-by: Yiheng Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yiheng Wang <[email protected]>
Co-authored-by: Yiheng Wang <[email protected]>
  • Loading branch information
4 people authored Sep 23, 2024
1 parent 8ff40cb commit 23d2558
Show file tree
Hide file tree
Showing 10 changed files with 205 additions and 313 deletions.
3 changes: 2 additions & 1 deletion models/vista2d/configs/metadata.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{
"schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json",
"version": "0.2.7",
"version": "0.2.8",
"changelog": {
"0.2.8": "remove relative path in readme",
"0.2.7": "enhance readme",
"0.2.6": "update tensorrt benchmark results",
"0.2.5": "add tensorrt benchmark results",
Expand Down
2 changes: 1 addition & 1 deletion models/vista2d/docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

The **VISTA2D** is a cell segmentation training and inference pipeline for cell imaging [[`Blog`](https://developer.nvidia.com/blog/advancing-cell-segmentation-and-morphology-analysis-with-nvidia-ai-foundation-model-vista-2d/)].

A pretrained model was trained on collection of 15K public microscopy images. The data collection and training can be reproduced following the [tutorial](../download_preprocessor/). Alternatively, the model can be retrained on your own dataset. The pretrained vista2d model achieves good performance on diverse set of cell types, microscopy image modalities, and can be further finetuned if necessary. The codebase utilizes several components from other great works including [SegmentAnything](https://github.com/facebookresearch/segment-anything) and [Cellpose](https://www.cellpose.org/), which must be pip installed as dependencies. Vista2D codebase follows MONAI bundle format and its [specifications](https://docs.monai.io/en/stable/mb_specification.html).
A pretrained model was trained on collection of 15K public microscopy images. The data collection and training can be reproduced following the `download_preprocessor/`. Alternatively, the model can be retrained on your own dataset. The pretrained vista2d model achieves good performance on diverse set of cell types, microscopy image modalities, and can be further finetuned if necessary. The codebase utilizes several components from other great works including [SegmentAnything](https://github.com/facebookresearch/segment-anything) and [Cellpose](https://www.cellpose.org/), which must be pip installed as dependencies. Vista2D codebase follows MONAI bundle format and its [specifications](https://docs.monai.io/en/stable/mb_specification.html).

<div align="center"> <img src="https://developer-blogs.nvidia.com/wp-content/uploads/2024/04/magnified-cells-1.png" width="800"/> </div>

Expand Down
13 changes: 11 additions & 2 deletions models/vista3d/configs/inference_trt.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,17 @@
"+imports": [
"$from monai.networks import trt_compile"
],
"trt_args": {
"max_prompt_size": 8,
"head_trt_enabled": false,
"network_trt_args": {
"dynamic_batchsize": "$[1, @inferer#sw_batch_size, @inferer#sw_batch_size]"
},
"network": "$trt_compile(@network_def.to(@device), @bundle_root + '/models/model.pt', args=@trt_args, submodule=['image_encoder.encoder', 'class_head'])"
"network_dev": "$@network_def.to(@device)",
"encoder": "$trt_compile(@network_dev, @bundle_root + '/models/model.pt', args=@network_trt_args, submodule=['image_encoder.encoder'])",
"head_trt_args": {
"dynamic_batchsize": "$[1, 1, @max_prompt_size]",
"fallback": "$True"
},
"head": "$trt_compile(@network_dev, @bundle_root + '/models/model.pt', args=@head_trt_args, submodule=['class_head']) if @head_trt_enabled else @network_dev",
"network": "$None if @encoder is None else @head"
}
3 changes: 2 additions & 1 deletion models/vista3d/configs/metadata.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{
"schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json",
"version": "0.4.8",
"version": "0.4.9",
"changelog": {
"0.4.9": "fix oom issue and update readme",
"0.4.8": "use 0.3 overlap for inference",
"0.4.7": "update tensorrt benchmark results",
"0.4.6": "add tensorrt benchmark result and remove the metric part",
Expand Down
2 changes: 1 addition & 1 deletion models/vista3d/configs/train.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"early_stop": false,
"fold": 0,
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
"epochs": 100,
"epochs": 5,
"val_interval": 1,
"val_at_start": false,
"sw_overlap": 0.625,
Expand Down
2 changes: 1 addition & 1 deletion models/vista3d/configs/train_continual.json
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"num_classes": 255,
"output_classes": "$len(@label_set)",
"optimizer": {
"_target_": "Novograd",
"_target_": "torch.optim.AdamW",
"lr": "@learning_rate",
"params": "[email protected]()"
},
Expand Down
313 changes: 170 additions & 143 deletions models/vista3d/docs/README.md

Large diffs are not rendered by default.

137 changes: 0 additions & 137 deletions models/vista3d/docs/labels.json

This file was deleted.

2 changes: 2 additions & 0 deletions models/vista3d/large_files.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ large_files:
url: "https://developer.download.nvidia.com/assets/Clara/monai/tutorials/model_zoo/model_vista3d.pt"
hash_val: "6ce45a8edde4400c5d28d5e74d7b61d5"
hash_type: "md5"
- path: "docs/labels.json"
url: "https://github.com/Project-MONAI/tutorials/blob/e66be5955d2b4f5959884ca026932762954b19c5/vista_3d/label_dict.json"
41 changes: 15 additions & 26 deletions models/vista3d/scripts/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import torch
from monai.apps.vista3d.inferer import point_based_window_inferer
from monai.inferers import Inferer, sliding_window_inference
from monai.inferers import Inferer, SlidingWindowInfererAdapt
from torch import Tensor


Expand All @@ -33,7 +33,6 @@ def __init__(self, roi_size, overlap, use_point_window=False, sw_batch_size=1) -
self.overlap = overlap
self.sw_batch_size = sw_batch_size
self.use_point_window = use_point_window
self.sliding_window_inferer = point_based_window_inferer if use_point_window else sliding_window_inference

def __call__(
self,
Expand Down Expand Up @@ -62,11 +61,6 @@ def __call__(
prev_mask: [1, B, H, W, D], THE VALUE IS BEFORE SIGMOID!
"""
sliding_window_inferer = (
point_based_window_inferer
if (self.use_point_window and point_coords is not None)
else sliding_window_inference
)
prompt_class = copy.deepcopy(class_vector)
if class_vector is not None:
# Check if network has attribute 'point_head' directly or within its 'module'
Expand All @@ -79,12 +73,14 @@ def __call__(

if torch.any(class_vector > point_head.last_supported):
class_vector = None
if isinstance(inputs, list):
device = inputs[0].device
else:
device = inputs.device
try:
val_outputs = sliding_window_inferer(
val_outputs = None
torch.cuda.empty_cache()
if self.use_point_window and point_coords is not None:
if isinstance(inputs, list):
device = inputs[0].device
else:
device = inputs.device
val_outputs = point_based_window_inferer(
inputs=inputs,
roi_size=self.roi_size,
sw_batch_size=self.sw_batch_size,
Expand All @@ -103,20 +99,13 @@ def __call__(
labels=labels,
label_set=label_set,
)
except Exception:
val_outputs = None
torch.cuda.empty_cache()
val_outputs = sliding_window_inferer(
inputs=inputs,
roi_size=self.roi_size,
sw_batch_size=self.sw_batch_size,
else:
val_outputs = SlidingWindowInfererAdapt(
roi_size=self.roi_size, sw_batch_size=self.sw_batch_size, with_coord=True
)(
inputs,
network,
transpose=True,
with_coord=True,
predictor=network,
mode="gaussian",
sw_device=device,
device="cpu",
overlap=self.overlap,
point_coords=point_coords,
point_labels=point_labels,
class_vector=class_vector,
Expand Down

0 comments on commit 23d2558

Please sign in to comment.