Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Probabilisitic UNet Code #226

Open
wants to merge 302 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
302 commits
Select commit Hold shift + click to select a range
2ef525f
Correction to GECO
pchlap Jun 19, 2021
6b7d8c8
hprob
pchlap Jun 27, 2021
18ab7ee
Merge branch 'prob-unet' of github.com:pyplati/platipy into prob-unet
pchlap Jun 29, 2021
a2bd858
Merge branch 'master' into prob-unet
pchlap Jun 29, 2021
5a4ed42
Use sum for loss
pchlap Jun 29, 2021
7ac7629
Allow configure kappa
pchlap Jun 29, 2021
1afbec4
Correct geco
pchlap Jun 29, 2021
62417aa
Update git attributes with binary
pchlap Jun 29, 2021
85df680
update git attributes
pchlap Jun 29, 2021
bb157ed
update git attributes
pchlap Jun 29, 2021
cdf4064
Converting to pytorch-lightning
pchlap Jun 30, 2021
669eb57
Work on migrating to pytorch lightning
pchlap Jul 1, 2021
cd53ecc
Adding CometML to pytorch lightning
pchlap Jul 1, 2021
1e32258
Correction to param passing
pchlap Jul 1, 2021
a8d7740
Able to gen pseudo data directly
pchlap Jul 1, 2021
0cead6f
Allow configuration of batch size
pchlap Jul 2, 2021
53cf69f
Allow config num workers
pchlap Jul 2, 2021
8ae7a5e
Save slices as npy
pchlap Jul 2, 2021
6fff9e6
correct reading slice
pchlap Jul 2, 2021
e0d1548
Working on validation step
pchlap Jul 2, 2021
e4d99c8
Working on validation
pchlap Jul 2, 2021
e20d4b1
Log images
pchlap Jul 2, 2021
7596e41
Work on prob unet
pchlap Jul 3, 2021
3692ace
Update to pseudo generator
pchlap Jul 3, 2021
6c4eada
validation metrics
pchlap Jul 3, 2021
e31763c
Log metrics during validation
pchlap Jul 3, 2021
7c45cd7
updates to vis
pchlap Jul 4, 2021
fda33a3
Loss on top k percentage
pchlap Jul 4, 2021
5da5ba4
minor corrections
pchlap Jul 4, 2021
3077b8d
Recale CT data properly
pchlap Jul 5, 2021
2856692
Able to parse from json file
pchlap Jul 10, 2021
d1fbb87
Make crop to mm configurable
pchlap Jul 10, 2021
b6e5ecf
Updates to prob unet training
pchlap Jul 11, 2021
62755dc
Adapt range for CT
pchlap Jul 11, 2021
94efb7f
Adjustments to GECO
pchlap Jul 16, 2021
e0fcd72
Allow using pre-augmented data
pchlap Jul 17, 2021
7d7863d
Corrections to using augmented data
pchlap Jul 17, 2021
2ec283d
read data correctly when already generated
pchlap Jul 18, 2021
6fbb317
Correct top_k_percentage mask
pchlap Jul 18, 2021
c19809b
Making target spacing configurable
pchlap Jul 19, 2021
cac5578
Improvements to deformable augmentation tool
pchlap Jul 20, 2021
fbd1c03
compute loss in contour area
pchlap Jul 25, 2021
aad9eb0
send config to comet ml
pchlap Jul 25, 2021
47fd5e2
Merge branch 'prob-unet' of github.com:pyplati/platipy into prob-unet
pchlap Jul 25, 2021
803ddec
train prob unet rec initially
pchlap Jul 26, 2021
a490e97
Don't separate out directories for working files
pchlap Jul 31, 2021
2f1ed50
Correct check in union intersection masks
pchlap Jul 31, 2021
0517f98
make contour mask kernel size configurable
pchlap Jul 31, 2021
0862108
Add contour_mask_kernel to class attributes
pchlap Jul 31, 2021
b365933
Extend probnet to 3D
pchlap Aug 13, 2021
0db479b
Prepare nnUnet service example
pchlap Aug 13, 2021
810fe5a
Add updated dataset code
pchlap Aug 13, 2021
3762bc3
Add localiser network
pchlap Aug 14, 2021
6eebade
Trainer to support 3d
pchlap Aug 14, 2021
0a5644d
Merge branch 'prob-unet' of github.com:pyplati/platipy into prob-unet
pchlap Aug 14, 2021
f9fca18
run inference on localise model
pchlap Aug 15, 2021
a88a948
Use localise network to preprocess data
pchlap Aug 15, 2021
b19d76d
Load checkpoint
pchlap Aug 15, 2021
dd0aee3
Separate out localise net and train
pchlap Aug 15, 2021
7e2ae2a
Separate out utils
pchlap Aug 15, 2021
e18291d
Move out some util functions
pchlap Aug 16, 2021
f96abb4
Crop to roi for synth DVF generation
pchlap Aug 16, 2021
6158d56
Merge branch 'synth-dvf-speedup' into prob-unet
pchlap Aug 16, 2021
8d50c50
handle int or list for expand properly
pchlap Aug 16, 2021
af89177
Merge branch 'synth-dvf-speedup' into prob-unet
pchlap Aug 16, 2021
97b9cf8
Use roi expand
pchlap Aug 16, 2021
6f0c1b0
Merge branch 'synth-dvf-speedup' into prob-unet
pchlap Aug 16, 2021
2acff8e
Run localise model as first step of prob unet
pchlap Aug 17, 2021
708e725
More control over image preprocessing
pchlap Aug 18, 2021
a82ac42
Ensure min size of expand roi
pchlap Aug 18, 2021
c77ccf1
Merge branch 'synth-dvf-speedup' into prob-unet
pchlap Aug 18, 2021
4b46e99
Update to augmentation script
pchlap Aug 19, 2021
3c88ee0
Use cropping during augmentation
pchlap Aug 19, 2021
9ef7266
crop before fill holes
pchlap Aug 19, 2021
de0a6fd
Merge branch 'nnunet-service' into prob-unet
pchlap Aug 19, 2021
eb1ed88
Fixes and stuff
pchlap Aug 22, 2021
1da91ae
Work on prob net infer
pchlap Aug 23, 2021
8f8e57e
Add validation function
pchlap Aug 24, 2021
c6095d3
Add coarse dropout to augmentation
pchlap Aug 25, 2021
23c3701
Correct issue in train
pchlap Aug 25, 2021
58b1e4c
Correct circular import
pchlap Aug 25, 2021
aee092b
Able to run on gpu
pchlap Aug 25, 2021
c1cbf63
Correct visualisation
pchlap Aug 25, 2021
dd10015
3d augmentations
pchlap Aug 27, 2021
e70b968
correct bug
pchlap Aug 27, 2021
4363bc5
Correct median blur
pchlap Aug 27, 2021
4eb0ac2
Merge branch 'master' into prob-unet
pchlap Aug 30, 2021
abdc122
Adjust window level in validation
pchlap Aug 31, 2021
4a703fe
Build contour loss into loss function
pchlap Aug 31, 2021
800090f
Merge branch 'prob-unet' of github.com:pyplati/platipy into prob-unet
pchlap Aug 31, 2021
a09c1c9
Correct issue
pchlap Aug 31, 2021
9b39396
Correct another issue
pchlap Aug 31, 2021
6d3132d
Ensure moved to gpu
pchlap Aug 31, 2021
586ca47
Ensure contour loss can be turned off
pchlap Aug 31, 2021
25108ff
Separate results for contour loss
pchlap Aug 31, 2021
94b7d07
Correct issue
pchlap Sep 1, 2021
744bba9
Able to compute loss in multiple masks
pchlap Sep 1, 2021
e7078b9
Correction to converting image
pchlap Sep 1, 2021
94948cd
Fix to prob unet
pchlap Sep 4, 2021
c4f280b
Update to fold path
pchlap Sep 5, 2021
546cbb4
Use mean to compute rec loss
pchlap Sep 5, 2021
5e3f2ad
Correction to include contour loss
pchlap Sep 8, 2021
dfa6d46
Hierarchical probabilistic UNet work
pchlap Sep 15, 2021
38c7e2a
format visualisation
pchlap Sep 15, 2021
efd8e3c
Merge branch 'prob-unet' of github.com:pyplati/platipy into prob-unet
pchlap Sep 15, 2021
05ca2b4
Update to hpunet test
pchlap Sep 15, 2021
dd81b37
Fix double argument
pchlap Sep 18, 2021
b383702
Merge branch 'prob-unet' of github.com:pyplati/platipy into prob-unet
pchlap Sep 18, 2021
e537466
Working on hprobunet
pchlap Sep 18, 2021
fe45869
Separate clamp for contour loss and rec loss
pchlap Sep 21, 2021
26530a6
Tests for prob UNet
pchlap Sep 21, 2021
410ffd4
Update to gitignore
pchlap Sep 21, 2021
8369bc6
Correct call for hprob to loss
pchlap Sep 21, 2021
c3f3816
Corrections to hprob net
pchlap Sep 22, 2021
40f30ff
Correction to train
pchlap Sep 22, 2021
781bf49
Work on hierarchical probabilistic UNet
pchlap Oct 6, 2021
840cf9c
Add dropout to prob unet
pchlap Oct 6, 2021
507f6f7
Add dropout
pchlap Oct 12, 2021
d2d94b0
Merge branch 'master' into prob-unet
pchlap Oct 13, 2021
d4c7f3f
Merge branch 'compress-nrrd' into prob-unet
pchlap Oct 17, 2021
4bb8d86
Working towards supporting multiple structures
pchlap Nov 9, 2021
882ee8d
Localise Net support multiple structures use case
pchlap Nov 9, 2021
b5a4683
Work towards supporting multiple structures in prob UNet
pchlap Nov 10, 2021
ea31bd8
Working on supporting multiple structures
pchlap Nov 10, 2021
6e0d76a
Update attribute name
pchlap Nov 11, 2021
7b4a273
Fix issues with validating multiple structures
pchlap Nov 11, 2021
e34fee6
Add sigmoid to output
pchlap Nov 14, 2021
89450a4
Trac pos weight
pchlap Nov 16, 2021
a0c1151
Merge branch 'master' into prob-unet
pchlap Dec 7, 2021
4da54a2
Update to multiple structure training
pchlap Dec 20, 2021
1ae54f4
Augmentation of fail cases using DVF
pchlap Jan 18, 2022
aa443b9
Merge branch 'master' into aug-fail-cases-gen
pchlap Feb 28, 2022
42bfc46
Work on DVF augmentation
pchlap Feb 28, 2022
5151e90
Merge in augmentation work
pchlap Feb 28, 2022
de27b84
Add contract from target function
pchlap Mar 3, 2022
1b0f692
Merge branch 'master' into prob-unet
pchlap Apr 3, 2022
b3b282b
Merge branch 'prob-unet' of github.com:pyplati/platipy into prob-unet
pchlap Apr 17, 2022
c3b9c98
Merge in augfail
pchlap Apr 17, 2022
9f99fe9
Corrections
pchlap Apr 26, 2022
d81108b
merge main
pchlap Apr 26, 2022
200922c
change default dropout
pchlap May 30, 2022
a3f1ed8
Add code to process LIDC dataset
pchlap Jun 6, 2022
b36788d
Merge branch 'prob-unet' of github.com:pyplati/platipy into prob-unet
pchlap Jun 6, 2022
93addf6
zero if both masks empty
pchlap Jun 7, 2022
4ea3674
Merge branch 'master' into prob-unet
pchlap Jun 7, 2022
3f5ff14
Visualise image on validate
pchlap Jun 7, 2022
7f2bb0f
Fix dataloading
pchlap Jun 7, 2022
c2537a0
LIDC tweaks
pchlap Jun 13, 2022
1edbc18
Correct issue in training step
pchlap Aug 3, 2022
9f1c295
Work on argmax
pchlap Aug 3, 2022
259f5c7
Correct argmax
pchlap Aug 3, 2022
e75e8b9
Change to rec loss summation
pchlap Aug 8, 2022
2824640
Used GED for main train function
pchlap Aug 8, 2022
4880422
Add missing import
pchlap Aug 9, 2022
35508d8
A few prob UNET corrections
pchlap Aug 25, 2022
fa6a30b
Merge branch 'master' into prob-unet
pchlap Aug 31, 2022
4a038ec
Corrections to prob-unet
pchlap Sep 2, 2022
d73467b
Update to train
pchlap Sep 6, 2022
e36acd4
Correct code
pchlap Sep 6, 2022
afd3cd3
set geco step size
pchlap Sep 8, 2022
58975c8
Merge branch 'prob-unet' of github.com:pyplati/platipy into prob-unet
pchlap Sep 8, 2022
007d043
Remove acc grad
pchlap Sep 8, 2022
663ad41
Merge branch 'prob-unet' of github.com:pyplati/platipy into prob-unet
pchlap Sep 8, 2022
81dbd9f
adjustments to prob unet
pchlap Sep 15, 2022
09b704d
Add psDSC metric
pchlap Sep 22, 2022
be80af7
add images to gitignore
pchlap Sep 22, 2022
b322793
Use reduce lr on plateu schedule
pchlap Sep 22, 2022
e77274a
Att early stopping
pchlap Sep 22, 2022
fb2e003
Be more patient
pchlap Sep 23, 2022
dc403c3
Increace patience
pchlap Sep 24, 2022
b007d0f
Don't early stop
pchlap Sep 24, 2022
f213daa
Reenable early stopping
pchlap Sep 27, 2022
43d5019
Actually enable early stopping
pchlap Sep 28, 2022
9c82524
Remove KL clamp
pchlap Sep 29, 2022
1fbe811
Add KL clamp back in
pchlap Sep 29, 2022
9b0f725
Add in dropout probability
pchlap Oct 1, 2022
27a5b77
Turn off dropout for prob part
pchlap Oct 3, 2022
13a6549
Reenable weight agg
pchlap Oct 7, 2022
c2850b3
Merge branch 'master' into prob-unet
pchlap Oct 10, 2022
5ccfab6
Ensure dropout layers are there even if not used
pchlap Oct 17, 2022
b535926
Merge branch 'prob-unet' of github.com:pyplati/platipy into prob-unet
pchlap Oct 17, 2022
72a67fc
Adjustments for early stopping
pchlap Jan 5, 2023
89adbe7
Fix dataset bug
pchlap Jan 5, 2023
a4becf1
Remove debug statements
pchlap Jan 5, 2023
cb4f1b7
Adjust patience
pchlap Jan 6, 2023
0a55d01
Ensure lambda is below zero before early stop
pchlap Jan 8, 2023
9f47ae2
Early stopping tweaks
pchlap Jan 8, 2023
152449e
Cycle learning rate
pchlap Jan 8, 2023
8ddd2b1
Don't cycle momentum
pchlap Jan 8, 2023
1088363
change LR step size
pchlap Jan 9, 2023
19a9ce2
Faster cycle LR
pchlap Jan 9, 2023
9c7e842
Merge branch 'master' into prob-unet
pchlap Jan 13, 2023
6ee7585
Adjustments to learning rate
pchlap Jan 13, 2023
c8b7a09
change batch agg
pchlap Jan 13, 2023
2e6d8d1
Adjust LR schedule
pchlap Jan 18, 2023
bf087bf
Merge branch 'master' into prob-unet
pchlap Jan 31, 2023
f192858
Use Cosine annealing LR
pchlap Mar 15, 2023
069b024
Merge branch 'master' into prob-unet
pchlap May 19, 2023
51c6df2
Support additional training data
pchlap May 22, 2023
673ce65
Merge branch 'master' into prob-unet
pchlap Nov 27, 2023
d75210a
Remove loguru imports
pchlap Nov 27, 2023
5bd7b4d
Support context map inputs to prob UNet
pchlap Dec 4, 2023
d95b1c3
Support context map for augmentation
pchlap Dec 4, 2023
dac6f70
Fix bug
pchlap Dec 4, 2023
0de77d5
augment log to std out
pchlap Dec 4, 2023
eeda24d
Add missing import
pchlap Dec 4, 2023
fe45f69
Fix issues
pchlap Dec 6, 2023
11f2ccd
Add context map for validation and test cases
pchlap Dec 6, 2023
604985a
Resolve ambigious truth value
pchlap Dec 6, 2023
c8ec57e
Correct typo
pchlap Dec 6, 2023
0ddb078
Add missing unsqueeze
pchlap Dec 6, 2023
acad3bb
Resolve ambigous truth value
pchlap Dec 6, 2023
2e6aded
Add missing import
pchlap Dec 6, 2023
45e4bc1
update deprecated cmap call
pchlap Dec 6, 2023
fb74648
Add missing import
pchlap Dec 6, 2023
83550f4
Check none context map glob
pchlap Dec 6, 2023
5549bed
Deal with none cmap
pchlap Dec 6, 2023
0624940
Fix missing augmentations for context map
pchlap Dec 7, 2023
fb47e1a
Add missing return value
pchlap Dec 7, 2023
d65ec7a
init variable
pchlap Dec 7, 2023
f3d6a99
Read the label map in nnUNet service
pchlap Dec 14, 2023
2890943
Save the label mask in the loop
pchlap Dec 14, 2023
48e8ba3
Add HD 95 metric
pchlap Dec 20, 2023
34e2ebb
Update docstring
pchlap Dec 20, 2023
a128cea
Merge branch 'add-hd-95-metric' into prob-unet
pchlap Dec 21, 2023
b56a876
Allow inference using a reference segmentation
pchlap Apr 1, 2024
080aba6
Fix variable passed
pchlap Apr 1, 2024
f20acf6
Pass img
pchlap Apr 1, 2024
ec18dbc
Add background channel
pchlap Apr 1, 2024
7373a2a
Fix var name
pchlap Apr 1, 2024
e9e5a14
Experiment with using structure context
pchlap Apr 14, 2024
c71b950
Use int instead of bool
pchlap Apr 14, 2024
f3bc014
Fix None prior
pchlap Apr 15, 2024
06b7ec6
Fix missing var name
pchlap Apr 15, 2024
8e0491c
Tweak probunet input
pchlap Apr 15, 2024
a1f98aa
cat in probunet
pchlap Apr 15, 2024
a3e775b
Dont set in train loop
pchlap Apr 15, 2024
89c34fe
Correct unit dist
pchlap Apr 15, 2024
eb42774
Move dist to device
pchlap Apr 15, 2024
0a76cb9
Fix move to device
pchlap Apr 15, 2024
926825b
Fix move to device
pchlap Apr 15, 2024
3c12dd3
Fix bracket
pchlap Apr 15, 2024
c908e85
Cat context during validation
pchlap Apr 16, 2024
09346ab
Pass seg during validation
pchlap Apr 16, 2024
967300f
Pass the seg during inference
pchlap Apr 19, 2024
c4daf99
Debug probunet with multi channels
pchlap May 13, 2024
809d501
Prep for multi input experiment
pchlap May 13, 2024
4a8e372
staple mask for validation
pchlap May 14, 2024
de6628e
Allow providing context structure
pchlap Sep 26, 2024
2b9b5e0
Context based inference
pchlap Nov 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Improvements to deformable augmentation tool
pchlap committed Jul 20, 2021

Verified

This commit was signed with the committer’s verified signature.
MichaelHatherly Michael Hatherly
commit cac5578d7d57906d1eb16376ce95eaabaae18680
228 changes: 198 additions & 30 deletions platipy/imaging/generation/augment.py
Original file line number Diff line number Diff line change
@@ -16,8 +16,16 @@
from collections.abc import Iterable
import random

from pathlib import Path

from argparse import ArgumentParser

import SimpleITK as sitk

from loguru import logger

from platipy.imaging import ImageVisualiser

from platipy.imaging.generation.dvf import (
generate_field_shift,
generate_field_expand,
@@ -83,39 +91,21 @@ def apply_augmentation(image, augmentation, masks=[]):
return image_deformed, dvf


def generate_random_augmentation(ct_image, masks):

random.shuffle(masks)
# mask_count = len(masks)
# masks = masks[: random.randint(2, 5)]

# print(len(masks))
augmentation_types = [
{
"class": ShiftAugment,
"args": {"vector_shift": [(-10, 10), (10, 10), (-10, 10)], "gaussian_smooth": (3, 5)},
},
{
"class": ContractAugment,
"args": {
"vector_contract": [(0, 10), (0, 10), (0, 10)],
"gaussian_smooth": (3, 5),
"bone_mask": True,
},
},
{
"class": ExpandAugment,
"args": {
"vector_expand": [(0, 10), (0, 10), (0, 10)],
"gaussian_smooth": (3, 5),
"bone_mask": True,
},
},
]
def generate_random_augmentation(ct_image, masks, augmentation_types):

augmentation = []

probabilities = [a["probability"] for a in augmentation_types]
prob_total = sum(probabilities)
prob_none = 1.0 - prob_total
if prob_none < 0:
prob_none = 0

for mask in masks:
aug = random.choice(augmentation_types)
aug = random.choices(augmentation_types + [None], weights=probabilities + [prob_none])[0]

if aug is None:
continue

aug_class = aug["class"]
aug_args = {}
@@ -203,3 +193,181 @@ def augment(self):
gaussian_smooth=self.gaussian_smooth,
)
return transform, dvf


def augment_data(args):

random.seed(args.seed)

augmentation_types = []

if args.enable_shift:
augmentation_types.append(
{
"class": ShiftAugment,
"args": {
"vector_shift": [
tuple(args.shift_x_range),
tuple(args.shift_y_range),
tuple(args.shift_z_range),
],
"gaussian_smooth": tuple(args.shift_smooth_range),
},
"probability": args.shift_probability,
}
)

if args.enable_contract:
augmentation_types.append(
{
"class": ContractAugment,
"args": {
"vector_contract": [
tuple(args.contract_x_range),
tuple(args.contract_y_range),
tuple(args.contract_z_range),
],
"gaussian_smooth": tuple(args.contract_smooth_range),
"bone_mask": args.contract_bone_mask,
},
"probability": args.contract_probability,
}
)

if args.enable_expand:
augmentation_types.append(
{
"class": ExpandAugment,
"args": {
"vector_expand": [
tuple(args.expand_x_range),
tuple(args.expand_y_range),
tuple(args.expand_z_range),
],
"gaussian_smooth": tuple(args.expand_smooth_range),
"bone_mask": args.expand_bone_mask,
},
"probability": args.expand_probability,
}
)

data_dir = Path(args.data_dir)
output_dir = Path(args.output_dir)

cases = [
p.name.replace(".nii.gz", "")
for p in data_dir.glob(args.case_glob)
if not p.name.startswith(".")
]
cases.sort()

data = {
case: {
"image": data_dir.joinpath(args.image_glob.format(case=case)),
"label": [p for p in data_dir.glob(args.label_glob.format(case=case))],
}
for case in cases
}

for case in cases:

logger.info(f"Augmenting for case: {case}")

ct_image = sitk.ReadImage(str(data[case]["image"]))

# Get list of structures to generate augmentations off
all_masks = []
all_names = []
for structure_path in data[case]["label"]:

mask = sitk.ReadImage(str(structure_path))

all_masks.append(mask)
all_names.append(structure_path.name.replace(".nii.gz", ""))

# Generate 10 random augmentations per case
for i in range(args.augmentations_per_case):

logger.debug("Generating augmentation")

augmented_case_path = output_dir.joinpath(case, f"augment_{i}")
augmented_case_path.mkdir(exist_ok=True, parents=True)

augmentation = generate_random_augmentation(ct_image, all_masks, augmentation_types)

dvf = None

if len(augmentation) == 0:
logger.debug(
"No augmentations generated, generated image won't differ from original"
)

augmented_image = ct_image
augmented_masks = all_masks
else:

logger.debug("Applying augmentation")
augmented_image, augmented_masks, dvf = apply_augmentation(
ct_image, augmentation, masks=all_masks
)

augmented_image_path = augmented_case_path.joinpath("CT.nii.gz")
sitk.WriteImage(augmented_image, str(augmented_image_path))

vis = ImageVisualiser(image=ct_image, figure_size_in=6)
vis.add_comparison_overlay(augmented_image)
if dvf is not None:
vis.add_vector_overlay(dvf, arrow_scale=1, subsample=(4, 12, 12))
for mask_name, mask, augmented_mask in zip(all_names, all_masks, augmented_masks):
vis.add_contour({f"{mask_name}": mask, f"{mask_name} (augmented)": augmented_mask})

logger.debug(f"Applying augmentation to mask: {mask_name}")
augmented_mask_path = augmented_case_path.joinpath(f"{mask_name}.nii.gz")
sitk.WriteImage(augmented_mask, str(augmented_mask_path))

fig = vis.show()

figure_path = augmented_case_path.joinpath("aug.png")
fig.savefig(figure_path, bbox_inches="tight")


if __name__ == "__main__":

arg_parser = ArgumentParser()
arg_parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed")
arg_parser.add_argument("--data_dir", type=str, default="./data")
arg_parser.add_argument("--output_dir", type=str, default="./augment")
arg_parser.add_argument("--case_glob", type=str, default="images/*.nii.gz")
arg_parser.add_argument("--image_glob", type=str, default="images/{case}.nii.gz")
arg_parser.add_argument("--label_glob", type=str, default="labels/{case}_*.nii.gz")
arg_parser.add_argument(
"--augmentations_per_case",
type=int,
default=10,
help="How many augmented images per case to generate",
)

arg_parser.add_argument("--enable_shift", type=bool, default=True)
arg_parser.add_argument("--shift_x_range", nargs="+", type=int, default=[-10, 10])
arg_parser.add_argument("--shift_y_range", nargs="+", type=int, default=[-10, 10])
arg_parser.add_argument("--shift_z_range", nargs="+", type=int, default=[-10, 10])
arg_parser.add_argument("--shift_smooth_range", nargs="+", type=int, default=[3, 5])
arg_parser.add_argument("--shift_probability", type=float, default=0.5)

arg_parser.add_argument("--enable_expand", type=bool, default=True)
arg_parser.add_argument("--expand_x_range", nargs="+", type=int, default=[0, 10])
arg_parser.add_argument("--expand_y_range", nargs="+", type=int, default=[0, 10])
arg_parser.add_argument("--expand_z_range", nargs="+", type=int, default=[0, 10])
arg_parser.add_argument("--expand_smooth_range", nargs="+", type=int, default=[3, 5])
arg_parser.add_argument("--expand_bone_mask", type=bool, default=True)
arg_parser.add_argument("--expand_probability", type=float, default=0.5)

arg_parser.add_argument("--enable_contract", type=bool, default=True)
arg_parser.add_argument("--contract_x_range", nargs="+", type=int, default=[0, 10])
arg_parser.add_argument("--contract_y_range", nargs="+", type=int, default=[0, 10])
arg_parser.add_argument("--contract_z_range", nargs="+", type=int, default=[0, 10])
arg_parser.add_argument("--contract_smooth_range", nargs="+", type=int, default=[3, 5])
arg_parser.add_argument("--contract_bone_mask", type=bool, default=True)
arg_parser.add_argument("--contract_probability", type=float, default=0.5)

augment_data(arg_parser.parse_args())