Skip to content

Commit

Permalink
General QOL improvements (#143)
Browse files Browse the repository at this point in the history
* preemptive bump

* Add 502 and sorted eval

* Remove large keys from Planner, add more documentation and FLOP measurement

* .

* Biiiiig commit

* A lot of minor changes

* add wandb log image

* add fvcore

* Fix log

* moved key dropping to plan config

* Lint and minor changes

* more lint and add aug param dict

* add plans to return

* add empty list

* Fix bug with way too large pickle files

* add mask value to aug params

* add new args

* Remove abs call

* Minor changes to test augs

* updates

* add cval and update illustrations

* Add F1

* Add F1

* add F1 to val

* add self

* remove unused import
  • Loading branch information
Sllambias authored Mar 15, 2024
1 parent 5b94099 commit 69d0317
Show file tree
Hide file tree
Showing 28 changed files with 744 additions and 177 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ dependencies = [
"weave",
"python-dotenv==1.0.0",
"monai",
"flake8-unused-arguments"
"flake8-unused-arguments",
"surface-distance-based-measures@git+https://github.com/google-deepmind/surface-distance",
"fvcore",
]

[project.optional-dependencies]
Expand Down
264 changes: 170 additions & 94 deletions yucca/documentation/illustrations/augmentation_illustrations.ipynb

Large diffs are not rendered by default.

33 changes: 25 additions & 8 deletions yucca/evaluation/YuccaEvaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
auroc,
)
from yucca.evaluation.obj_metrics import get_obj_stats_for_label
from yucca.evaluation.surface_metrics import get_surface_metrics_for_label
from yucca.paths import yucca_raw_data
from weave.monitoring import StreamTable
from tqdm import tqdm
Expand All @@ -32,13 +33,20 @@ def __init__(
self,
labels: list | int,
folder_with_predictions,
use_wandb: bool,
folder_with_ground_truth,
do_object_eval=False,
use_wandb: bool,
as_binary=False,
do_object_eval=False,
do_surface_eval=False,
overwrite: bool = False,
task_type: Literal["segmentation", "classification", "regression"] = "segmentation",
):
self.name = "results"

self.overwrite = overwrite
self.use_wandb = use_wandb
self.task_type = task_type

self.metrics = {
"Dice": dice,
"Jaccard": jaccard,
Expand All @@ -52,9 +60,7 @@ def __init__(
"Total Positives Prediction": total_pos_pred,
}
self.obj_metrics = []

self.use_wandb = use_wandb
self.task_type = task_type
self.surface_metrics = []

if self.task_type == "segmentation":
self.metrics = {
Expand Down Expand Up @@ -85,6 +91,12 @@ def __init__(
"_OBJ F1",
]

if do_surface_eval:
self.name += "_SURFACE"
self.surface_metrics = [
"Average Surface Distance",
]

self.metrics_included_in_streamtable = [
"Dice",
"Sensitivity",
Expand Down Expand Up @@ -130,7 +142,7 @@ def __init__(
self.labels = ["0", "1"]
self.name += "_BINARY"

self.labelarr = np.array(self.labels, dtype=np.uint8)
self.labelarr = np.sort(np.array(self.labels, dtype=np.uint8))
self.folder_with_predictions = folder_with_predictions
self.folder_with_ground_truth = folder_with_ground_truth

Expand Down Expand Up @@ -168,7 +180,7 @@ def sanity_checks(self):
print(f"Labels found in dataset.json: {list(dataset_json['labels'].keys())}")

def run(self):
if isfile(self.outpath):
if isfile(self.outpath) and not self.overwrite:
print(f"Evaluation file already present in {self.outpath}. Skipping.")
else:
self.sanity_checks()
Expand Down Expand Up @@ -271,7 +283,7 @@ def _evaluate_folder_segm(self):
meandict = {}

for label in self.labels:
meandict[label] = {k: [] for k in list(self.metrics.keys()) + self.obj_metrics}
meandict[label] = {k: [] for k in list(self.metrics.keys()) + self.obj_metrics + self.surface_metrics}

for case in tqdm(self.pred_subjects, desc="Evaluating"):
casedict = {}
Expand Down Expand Up @@ -312,6 +324,11 @@ def _evaluate_folder_segm(self):
labeldict[k] = round(v, 4)
meandict[str(label)][k].append(labeldict[k])

if self.surface_metrics:
surface_labeldict = get_surface_metrics_for_label(gt, pred, label, as_binary=self.as_binary)
for k, v in surface_labeldict.items():
labeldict[k] = round(v, 4)
meandict[str(label)][k].append(labeldict[k])
casedict[str(label)] = labeldict
casedict["Prediction:"] = predpath
casedict["Ground Truth:"] = gtpath
Expand Down
31 changes: 31 additions & 0 deletions yucca/evaluation/surface_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np
from yucca.utils.nib_utils import get_nib_spacing
from surface_distance import metrics


def get_surface_metrics_for_label(gt, pred, label, as_binary: bool = False):
spacing = get_nib_spacing(pred)
pred = pred.get_fdata()
gt = gt.get_fdata()
labeldict = {}

if label == 0:
labeldict["Average Surface Distance"] = 0
return labeldict
if as_binary:
gt = gt.astype(bool)
pred = pred.astype(bool)
else:
pred = np.where(pred == label, 1, 0).astype(bool)
gt = np.where(gt == label, 1, 0).astype(bool)

surface_distances = metrics.compute_surface_distances(
mask_gt=gt,
mask_pred=pred,
spacing_mm=spacing,
)

labeldict["Average Surface Distance"] = metrics.compute_surface_dice_at_tolerance(
surface_distances=surface_distances, tolerance_mm=1
)
return labeldict
10 changes: 10 additions & 0 deletions yucca/evaluation/training_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from torchmetrics.classification import MulticlassF1Score
from torch import Tensor


class F1(MulticlassF1Score):
def forward(self, input: Tensor, target: Tensor) -> Tensor:
if len(target.shape) == len(input.shape):
assert target.shape[1] == 1
target = target[:, 0]
return super().forward(input, target)
3 changes: 3 additions & 0 deletions yucca/image_processing/transforms/Ghosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def get_params(alpha: Tuple[float], numReps: Tuple[float], axes: Tuple[float]) -
return alpha, numReps, axis

def __motionGhosting__(self, imageVolume, alpha, numReps, axis):
m = min(0, imageVolume.min())
imageVolume += abs(m)
if len(imageVolume.shape) == 3:
assert axis in [0, 1, 2], "Incorrect or no axis"

Expand All @@ -51,6 +53,7 @@ def __motionGhosting__(self, imageVolume, alpha, numReps, axis):
else:
imageVolume[:, 0:-1:numReps] = alpha * imageVolume[:, 0:-1:numReps]
imageVolume = abs(np.fft.ifftn(imageVolume, s=[h, w]))
imageVolume -= m
return imageVolume

def __call__(self, packed_data_dict=None, **unpacked_data_dict):
Expand Down
28 changes: 21 additions & 7 deletions yucca/image_processing/transforms/Masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
data_key="image",
ratio: Union[Iterable[float], float] = 0.25,
token_size: Union[Iterable[Union[float, int]], float, int] = 0.05,
pixel_value: Union[float, int] = 0,
pixel_value: Union[float, int, str] = 0,
):
self.mask = mask
self.data_key = data_key
Expand All @@ -42,7 +42,15 @@ def __init__(
self.pixel_value = pixel_value

@staticmethod
def get_params(input_shape, ratio, token_size):
def get_params(input_shape, data, pixel_value, ratio, token_size):
if isinstance(pixel_value, str):
assert pixel_value in ["min", "max"]
if pixel_value == "min":
pixel_value = data.min()
elif pixel_value == "max":
pixel_value = data.max()
else:
print(f"unrecognized pixel value: got {pixel_value}")
if isinstance(ratio, (tuple, list)):
# If ratio is a list/tuple it's a range of ratios from which me sample uniformly per batch
ratio = np.random.uniform(*ratio)
Expand All @@ -65,9 +73,9 @@ def get_params(input_shape, ratio, token_size):
"token_size is set to a ratio over 0.25 of the image. " "This is not intended and should be reconsidered."
)
token_size = [int(i * np.random.uniform(*token_size)) for i in input_shape]
return ratio, token_size
return pixel_value, ratio, token_size

def __mask__(self, image_volume, ratio, token_size):
def __mask__(self, image_volume, pixel_value, ratio, token_size):
assert len(image_volume.shape[2:]) == len(token_size), (
"mask token size not compatible with input data"
f"mask token is: {token_size} and image is shape: {image_volume.shape[2:]}"
Expand All @@ -84,7 +92,7 @@ def __mask__(self, image_volume, ratio, token_size):
for idx, size in enumerate(token_size):
grid = np.repeat(grid, repeats=size, axis=idx)

image_volume[:, :, grid[*slices] == 0] = self.pixel_value
image_volume[:, :, grid[*slices] == 0] = pixel_value
return image_volume

def __call__(self, packed_data_dict=None, **unpacked_data_dict):
Expand All @@ -94,6 +102,12 @@ def __call__(self, packed_data_dict=None, **unpacked_data_dict):
), f"Incorrect data size or shape.\
\nShould be (b, c, x, y, z) or (b, c, x, y) and is: {data_dict[self.data_key].shape}"
if self.mask:
ratio, token_size = self.get_params(data_dict[self.data_key].shape[2:], self.ratio, self.token_size)
data_dict[self.data_key] = self.__mask__(data_dict[self.data_key], ratio, token_size)
pixel_value, ratio, token_size = self.get_params(
data_dict[self.data_key].shape[2:],
data_dict[self.data_key],
self.pixel_value,
self.ratio,
self.token_size,
)
data_dict[self.data_key] = self.__mask__(data_dict[self.data_key], pixel_value, ratio, token_size)
return data_dict
3 changes: 3 additions & 0 deletions yucca/image_processing/transforms/Ringing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def get_params(cutFreq, axes):
return cutFreq, axis

def __gibbsRinging__(self, imageVolume, numSample, axis):
m = min(0, imageVolume.min())
imageVolume += abs(m)
if len(imageVolume.shape) == 3:
assert axis in [0, 1, 2], "Incorrect or no axis"

Expand Down Expand Up @@ -54,6 +56,7 @@ def __gibbsRinging__(self, imageVolume, numSample, axis):
imageVolume[:, int(np.ceil(h / 2) + np.ceil(numSample / 2)) : h] = 0
imageVolume = abs(np.fft.ifftn(np.fft.ifftshift(imageVolume), s=[w, h]))
imageVolume = imageVolume.conj().T
imageVolume -= m
return imageVolume

def __call__(self, packed_data_dict=None, **unpacked_data_dict):
Expand Down
9 changes: 8 additions & 1 deletion yucca/image_processing/transforms/Spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(
data_key="image",
label_key="label",
crop=False,
cval="min",
patch_size: Tuple[int] = None,
random_crop=True,
p_deform_per_sample=1,
Expand All @@ -41,6 +42,7 @@ def __init__(
self.label_key = label_key
self.skip_label = skip_label
self.do_crop = crop
self.cval = cval
self.patch_size = patch_size
self.random_crop = random_crop

Expand Down Expand Up @@ -98,6 +100,11 @@ def __CropDeformRotateScale__(
):
if not self.do_crop:
patch_size = imageVolume.shape[2:]
if self.cval == "min":
cval = float(imageVolume.min())
else:
cval = self.cval
assert isinstance(cval, (int, float)), f"got {cval} of type {type(cval)}"

coords = create_zero_centered_coordinate_matrix(patch_size)
imageCanvas = np.zeros((imageVolume.shape[0], imageVolume.shape[1], *patch_size), dtype=np.float32)
Expand Down Expand Up @@ -149,7 +156,7 @@ def __CropDeformRotateScale__(
coords,
order=3,
mode="constant",
cval=0.0,
cval=cval,
).astype(imageVolume.dtype)

if not skip_label:
Expand Down
3 changes: 1 addition & 2 deletions yucca/planning/YuccaPlanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ def plan(self):
self.determine_transpose()
self.determine_target_size_from_fixed_size_or_spacing()
self.validate_target_size()
self.drop_keys_from_dict(dict=self.dataset_properties, keys=["original_sizes", "original_spacings"])

self.drop_keys_from_dict(dict=self.dataset_properties, keys=[])
self.populate_plans_file()

save_json(self.plans, self.plans_path, sort_keys=False)
Expand Down
7 changes: 5 additions & 2 deletions yucca/planning/resampling/YuccaPlanner_224x224.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@


class YuccaPlanner_224x224(YuccaPlanner):
def __init__(self, task, preprocessor="YuccaPreprocessor", threads=None, disable_unittests=False, view=None):
super().__init__(task, preprocessor, threads, disable_unittests, view)

def __init__(self, task, preprocessor="YuccaPreprocessor", threads=None, disable_sanity_checks=False, view=None):
super().__init__(
task, preprocessor=preprocessor, threads=threads, disable_sanity_checks=disable_sanity_checks, view=view
)
self.name = str(self.__class__.__name__) + str(view or "")

def determine_target_size_from_fixed_size_or_spacing(self):
Expand Down
3 changes: 3 additions & 0 deletions yucca/preprocessing/YuccaPreprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,10 @@ def analyze_label(self, label):
# we get some (no need to get all) locations of foreground, that we will later use in the
# oversampling of foreground classes
# And we also potentially analyze the connected components of the label
max_foreground_locs = 100000 # limited to save space
foreground_locs = np.array(np.nonzero(label)).T[::10].tolist()
if len(foreground_locs) > max_foreground_locs:
foreground_locs = foreground_locs[:: round(len(foreground_locs) / max_foreground_locs)]
if not self.enable_cc_analysis:
label_cc_n = 0
label_cc_sizes = 0
Expand Down
13 changes: 13 additions & 0 deletions yucca/run/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ def main():
required=False,
)
parser.add_argument("--obj_eval", help="enable object evaluation", action="store_true", default=None, required=False)
parser.add_argument(
"--overwrite",
default=False,
action="store_true",
required=False,
help="Overwrite existing predictions",
)
parser.add_argument("--split_idx", type=int, help="idx of splits to use for training.", default=0)
parser.add_argument(
"--split_data_method", help="Specify splitting method. Either kfold, simple_train_val_split", default="kfold"
Expand All @@ -76,6 +83,8 @@ def main():
help="Specify the parameter for the selected split method. For KFold use an int, for simple_split use a float between 0.0-1.0.",
default=5,
)
parser.add_argument("--surface_eval", help="enable surface evaluation", action="store_true", default=None, required=False)

parser.add_argument(
"--version", "-v", help="version number of the model. Defaults to 0.", default=0, type=int, required=False
)
Expand All @@ -89,6 +98,7 @@ def main():
plan_id = args.pl
checkpoint = args.chk
obj = args.obj_eval
overwrite = args.overwrite
as_binary = args.as_binary
classes = args.c
predpath = args.pred
Expand All @@ -99,6 +109,7 @@ def main():
split_idx = args.split_idx
split_data_method = args.split_data_method
split_data_param = args.split_data_param
surface_eval = args.surface_eval
assert (predpath and gtpath) or source_task, "Either supply BOTH paths or the source task"

if not predpath:
Expand All @@ -123,6 +134,8 @@ def main():
folder_with_predictions=predpath,
folder_with_ground_truth=gtpath,
do_object_eval=obj,
do_surface_eval=surface_eval,
overwrite=overwrite,
as_binary=as_binary,
task_type=task_type,
use_wandb=use_wandb,
Expand Down
Loading

0 comments on commit 69d0317

Please sign in to comment.