General QOL improvements #143
Conversation
@asbjrnmunk and @jakobamb: if you have any QOL improvements you would like to see that are simple/medium difficulty to implement, add them here.
Regarding measure flops: when should this be run? It is a full forward pass, so it takes some time and would probably be a bit cumbersome to run always. Otherwise sounds good.
@Sllambias for flop counting you can see https://github.com/asbjrnmunk/agave/blob/main/src/count_flops.py :)
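For reference, a one-off FLOP count over a single forward pass can be done with PyTorch's built-in FlopCounterMode (available since PyTorch 2.1). This is only a minimal sketch and not necessarily what count_flops.py does; the model and input shape below are placeholders.

```python
import torch
from torch.utils.flop_counter import FlopCounterMode

# Placeholder model and input; substitute the real network and patch size.
model = torch.nn.Conv3d(1, 32, kernel_size=3, padding=1)
x = torch.randn(1, 1, 64, 64, 64)

flop_counter = FlopCounterMode(display=False)
with flop_counter:
    model(x)  # one full forward pass, so run it once rather than every step

print(f"Total forward-pass FLOPs: {flop_counter.get_total_flops():,}")
```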
Reviews are of course conditioned on passing the automatic tests and the manual tests I've written above.
Really cool! Only minor comments.
Regarding wandb images in classification, I would just plot the input without labels, since the main use case here is finding bugs in preprocessing/data.
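If that route is taken, a minimal sketch of logging a raw input (no label overlay) to wandb could look like the following; the slicing, normalization, and key name are illustrative assumptions, not the PR's actual implementation.

```python
import numpy as np
import wandb

def log_raw_input(volume: np.ndarray, step: int) -> None:
    # Assumes a (C, Z, Y, X) array; take the first channel's middle axial slice.
    mid = volume.shape[1] // 2
    img = volume[0, mid]
    # Scale to uint8 purely for display so preprocessing bugs are easy to spot.
    img = ((img - img.min()) / (img.max() - img.min() + 1e-8) * 255).astype(np.uint8)
    wandb.log({"inputs/raw": wandb.Image(img)}, step=step)
```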
@@ -130,7 +142,7 @@ def __init__(
        self.labels = ["0", "1"]
        self.name += "_BINARY"

-       self.labelarr = np.array(self.labels, dtype=np.uint8)
+       self.labelarr = np.sort(np.array(self.labels, dtype=np.uint8))
Personally I am not a fan of specifying the type in the variable name. If it is really important, I would use a type hint:
s: str = ''
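Applied to the line under review, the suggestion could look something like this (the attribute name is just an illustration, not a requested change):

```python
self.sorted_labels: np.ndarray = np.sort(np.array(self.labels, dtype=np.uint8))
```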
@@ -271,7 +283,7 @@ def _evaluate_folder_segm(self):
        meandict = {}

        for label in self.labels:
-           meandict[label] = {k: [] for k in list(self.metrics.keys()) + self.obj_metrics}
+           meandict[label] = {k: [] for k in list(self.metrics.keys()) + self.obj_metrics + self.surface_metrics}
Again, not a fan of type in name.
surface_labeldict = get_surface_metrics_for_label(gt, pred, label, as_binary=self.as_binary)
for k, v in surface_labeldict.items():
    labeldict[k] = round(v, 4)
    meandict[str(label)][k].append(labeldict[k])
casedict[str(label)] = labeldict
casedict["Prediction:"] = predpath
Again, not a fan of type in name.
else:
    pred = np.where(pred == label, 1, 0).astype(bool)
    gt = np.where(gt == label, 1, 0).astype(bool)
surface_distances = metrics.compute_surface_distances(
Missing newline above.
    mask_pred=pred,
    spacing_mm=spacing,
)
labeldict["Average Surface Distance"] = metrics.compute_surface_dice_at_tolerance(
Missing newline above. And not a fan of type in name.
pred = pred.get_fdata()
gt = gt.get_fdata()
labeldict = {}
if label == 0:
Missing newline above.
spacing = get_nib_spacing(pred)
pred = pred.get_fdata()
gt = gt.get_fdata()
labeldict = {}
Again, not a fan of type in name.
    spacing_mm=spacing,
)
labeldict["Average Surface Distance"] = metrics.compute_surface_dice_at_tolerance(
    surface_distances=surface_distances, tolerance_mm=1
Should we make tolerance a parameter with 1 mm as the default value?
Only if you realistically see this changing, I think. I've had other similar cases with connected-components connectivity, and it hasn't really been the case that I needed to change it (in situations where it wasn't simply easier to import the function in a task-specific script).
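For the record, here is a sketch of what threading the tolerance through with a 1 mm default could look like. It is simplified from the diff above: the voxel spacing is passed in directly instead of being read from the NIfTI header, the background/label-0 handling is omitted, and `metrics` is assumed to be DeepMind's surface-distance package as used in the snippet.

```python
import numpy as np
from surface_distance import metrics  # DeepMind's surface-distance package (assumed)

def get_surface_metrics_for_label(gt, pred, label, spacing, as_binary=False, tolerance_mm=1.0):
    # Binarize the masks for the requested label (or treat them as already binary) ...
    if as_binary:
        pred, gt = pred.astype(bool), gt.astype(bool)
    else:
        pred = np.where(pred == label, 1, 0).astype(bool)
        gt = np.where(gt == label, 1, 0).astype(bool)

    # ... then compute surface distances once and evaluate at the chosen tolerance.
    surface_distances = metrics.compute_surface_distances(
        mask_gt=gt, mask_pred=pred, spacing_mm=spacing
    )
    return {
        "Surface Dice at Tolerance": metrics.compute_surface_dice_at_tolerance(
            surface_distances=surface_distances, tolerance_mm=tolerance_mm
        )
    }
```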
yucca/planning/YuccaPlanner.py (Outdated)
@@ -73,8 +73,9 @@ def plan(self):
        self.determine_transpose()
        self.determine_target_size_from_fixed_size_or_spacing()
        self.validate_target_size()
-       self.drop_keys_from_dict(dict=self.dataset_properties, keys=["original_sizes", "original_spacings"])
+       self.drop_keys_from_dict(
Thank you <3
Todo:
(WILL NOT IMPLEMENT) Add ThroughputMonitor
REASON: No added value and probably moderate overhead. Measure FLOPs does what we want.
Add measure_flops
This is now printed when models are instantiated (also per layer, and including the number of parameters)
Add option to overwrite in inference (so we can e.g. continue from broken runs)
We can now use --overwrite if we wish to overwrite predictions AND the results_json. Otherwise only non-existing cases will be predicted/evaluated.
Remove additional large keys in plans (new_sizes and new_spacings for now), which means they will also not be in the ckpt, in wandb, etc.
Add lm.hparams to Augmenter
Now part of hparams.yaml, the ckpt, and wandb
Recursive Find Function will now let you know if it doesn't find anything (e.g. if you've got a typo in the Planner or Manager specified in the CLI)
I guess this is mainly nice for finding typos and bugs
Evaluator bug fix (correct results, but weird-looking output for models with > 10 classes)
WandB image plots
Very beautiful. Need to figure out what to do exactly with classification. For now it is not enabled.
(WILL NOT IMPLEMENT) Per label metrics during training
REASON: known Torchmetrics bug, waiting for them to fix it: the Dice metric cannot compute the Dice score for each class (Lightning-AI/torchmetrics#2282)
Binary evaluation in inference
Run yucca_evaluation with "--as_binary" to achieve this.
Surface Dice (see: https://github.com/MIC-DKFZ/MedNeXt/blob/c5ed3f38b56d58c80581c75fea856865f42ddb75/nnunet_mednext/evaluation/region_based_evaluation.py#L112)
Run yucca_evaluation with "--surface_eval" to achieve this.
To be fully tested with (along with pytests):