Skip to content

Commit

Permalink
Support additional training data
Browse files Browse the repository at this point in the history
  • Loading branch information
pchlap committed May 22, 2023
1 parent 069b024 commit 51c6df2
Showing 1 changed file with 98 additions and 0 deletions.
98 changes: 98 additions & 0 deletions platipy/imaging/cnn/dataload.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,21 @@ class UNetDataModule(pl.LightningDataModule):
def __init__(
self,
data_dir: str = "./data",
data_add_dirs: list = [],
augmented_dir: str = None,
augmented_add_dirs: list = [],
working_dir: str = "./working",
structures=["a", "b", "c"],
observers=["0", "1", "2", "3", "4"],
observers_add=[],
case_glob="images/*.nii.gz",
image_glob="images/{case}.nii.gz",
label_glob="labels/{case}_{structure}_*.nii.gz",
label_add_glob="labels/{case}_{structure}.nii.gz",
augmented_case_glob="{case}/*",
augmented_image_glob="images/{augmented_case}.nii.gz",
augmented_label_glob="labels/{augmented_case}_{structure}_*.nii.gz",
augmented_label_add_glob="labels/{augmented_case}_{structure}_*.nii.gz",
augment_on_fly=True,
fold=0,
k_folds=5,
Expand All @@ -47,15 +52,19 @@ def __init__(
):
super().__init__()
self.data_dir = Path(data_dir)
self.data_add_dirs = [Path(p) for p in data_add_dirs]
self.augmented_dir = augmented_dir
self.augmented_add_dirs = augmented_add_dirs
self.working_dir = Path(working_dir)

self.case_glob = case_glob
self.image_glob = image_glob
self.label_glob = label_glob
self.label_add_glob = label_add_glob
self.augmented_case_glob = augmented_case_glob
self.augmented_image_glob = augmented_image_glob
self.augmented_label_glob = augmented_label_glob
self.augmented_label_add_glob = augmented_label_add_glob

self.augment_on_fly = augment_on_fly
self.fold = fold
Expand All @@ -75,6 +84,7 @@ def __init__(
self.contour_mask_kernel = contour_mask_kernel
self.structures = structures
self.observers = observers
self.observers_add = observers_add

self.crop_using_localise_model = crop_using_localise_model
self.localise_voxel_grid_size = localise_voxel_grid_size
Expand All @@ -96,22 +106,29 @@ def add_model_specific_args(parent_parser):
"""Add arguments used for Data module"""
parser = parent_parser.add_argument_group("Data Loader")
parser.add_argument("--data_dir", type=str, default="./data")
parser.add_argument("--data_add_dirs", nargs="+", type=str, default=[])
parser.add_argument("--augmented_dir", type=str, default=None)
parser.add_argument("--augmented_add_dirs", nargs="+", type=str, default=[])
parser.add_argument("--augment_on_fly", type=bool, default=True)
parser.add_argument("--fold", type=int, default=0)
parser.add_argument("--k_folds", type=int, default=5)
parser.add_argument("--batch_size", type=int, default=5)
parser.add_argument("--num_workers", type=int, default=4)
parser.add_argument("--structures", nargs="+", type=str, default=["a", "b", "c"])
parser.add_argument("--observers", nargs="+", type=str, default=["0", "1", "2", "3", "4"])
parser.add_argument("--observers_add", nargs="+", type=str, default=[])
parser.add_argument("--case_glob", type=str, default="images/*.nii.gz")
parser.add_argument("--image_glob", type=str, default="images/{case}.nii.gz")
parser.add_argument(
"--label_glob", type=str, default="labels/{case}_{structure}_{observer}.nii.gz"
)
parser.add_argument(
"--label_add_glob", type=str, default="labels/{case}_{structure}.nii.gz"
)
parser.add_argument("--augmented_case_glob", type=str, default=None)
parser.add_argument("--augmented_image_glob", type=str, default=None)
parser.add_argument("--augmented_label_glob", type=str, default=None)
parser.add_argument("--augmented_label_add_glob", type=str, default=None)
parser.add_argument("--crop_to_grid_size_xy", type=int, default=128)
parser.add_argument("--intensity_scaling", type=str, default="window")
parser.add_argument("--intensity_window", nargs="+", type=int, default=[-500, 500])
Expand Down Expand Up @@ -209,6 +226,87 @@ def setup(self, stage=None):
for augmented_case in augmented_cases
]

# If observers_add is empty then just add one dummy observer since they are not using
# Multi observer data here
if len(self.observers_add) == 0:
self.observers_add = ["X"]

# Add in the addtional cases, these are only use for training and may only have 1 observer
for data_add_dir in self.data_add_dirs:
self.add_train_cases = []
cases = [
p.name.replace(".nii.gz", "")
for p in data_add_dir.glob(self.case_glob)
if not p.name.startswith(".")
]
self.add_train_cases += cases
train_data += [
{
"id": case,
"image": data_add_dir.joinpath(self.image_glob.format(case=case)),
"observers": {
observer: {
structure: data_add_dir.joinpath(
self.label_add_glob.format(
case=case, structure=structure, observer=observer
)
)
for structure in self.structures
}
for observer in self.observers_add
},
}
for case in cases
]

for case in cases:

case_aug_dir = None
for aug_add_dir in self.augmented_add_dirs:
if Path(aug_add_dir.format(case=case)).exists():

case_aug_dir = Path(aug_add_dir.format(case=case))
else:
print(f"No dir {Path(aug_add_dir.format(case=case))}")

if case_aug_dir is None:
continue

augmented_cases = [
p.name.replace(".nii.gz", "")
for p in case_aug_dir.glob(self.augmented_case_glob.format(case=case))
if not p.name.startswith(".")
]
print(augmented_cases)

train_data += [
{
"id": f"{case}_{augmented_case}",
"image": case_aug_dir.joinpath(
self.augmented_image_glob.format(
case=case, augmented_case=augmented_case
)
),
"observers": {
observer: {
structure: case_aug_dir.joinpath(
self.augmented_label_add_glob.format(
case=case,
augmented_case=augmented_case,
structure=structure,
observer=observer
)
)
for structure in self.structures
}
for observer in self.observers_add
},
}
for augmented_case in augmented_cases
]
print(train_data)
print(len(train_data))

self.validation_data = [
{
"id": case,
Expand Down

0 comments on commit 51c6df2

Please sign in to comment.