Skip to content

Commit

Permalink
Slightly simplify main.py objective calculation
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Aug 5, 2024
1 parent dabb2b4 commit 241e891
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions project/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
from omegaconf import DictConfig

from project.configs.config import Config
from project.datamodules.image_classification.image_classification import (
ImageClassificationDataModule,
)
from project.experiment import Experiment, setup_experiment
from project.utils.hydra_utils import resolve_dictconfig
from project.utils.utils import print_config
Expand Down Expand Up @@ -132,23 +129,23 @@ def evaluation(experiment: Experiment) -> tuple[str, float | None, dict]:
rich.print("RUN FAILED!")
return "fail", None, {}

returned_results_dict = results[0]
results_dict = results[0].copy()
returned_results_dict = dict(results[0])
results_dict = dict(results[0]).copy()

loss = results_dict.pop(f"{results_type}/loss")
if (
isinstance(experiment.datamodule, ImageClassificationDataModule)
and f"{results_type}/accuracy" in results_dict
):

if f"{results_type}/accuracy" in results_dict:
accuracy: float = results_dict[f"{results_type}/accuracy"]
top5_accuracy: float | None = results_dict.get(f"{results_type}/top5_accuracy")
rich.print(f"{results_type} top1 accuracy: {accuracy:.1%}")
if top5_accuracy is not None:
rich.print(f"{results_type} accuracy: {accuracy:.1%}")

if top5_accuracy := results_dict.get(f"{results_type}/top5_accuracy") is not None:
rich.print(f"{results_type} top5 accuracy: {top5_accuracy:.1%}")
# NOTE: This is the value that is used for HParam sweeps.
error = 1 - accuracy
metric_name = "1-accuracy"
else:
logger.warning("Assuming that the objective to minimize is the loss metric.")
# If 'accuracy' isn't in the results, assume that the loss is the metric to use.
metric_name = "loss"
error = loss

Expand Down

0 comments on commit 241e891

Please sign in to comment.