From 60a37df2184b69cfcf6cc8900f537fc43d8af1f4 Mon Sep 17 00:00:00 2001 From: Charlie Meyers Date: Sun, 4 Aug 2024 21:41:25 +0200 Subject: [PATCH] add kwarg support to scorer --- deckard/base/scorer/scorer.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/deckard/base/scorer/scorer.py b/deckard/base/scorer/scorer.py index 52429e44..9e3eb9d5 100644 --- a/deckard/base/scorer/scorer.py +++ b/deckard/base/scorer/scorer.py @@ -2,7 +2,7 @@ from typing import Literal, Dict, List from hydra.utils import call from hydra.errors import InstantiationException -from omegaconf import DictConfig, OmegaConf, ListConfig +from omegaconf import DictConfig, OmegaConf, ListConfig, ListMergeMode import numpy as np import json from pathlib import Path @@ -52,13 +52,11 @@ def __init__( def __hash__(self): return int(my_hash(self), 16) - def score(self, ind, dep) -> float: + def score(self, ind, dep, **kwargs) -> float: args = deepcopy(self.args) - kwargs = deepcopy(self.params) + kwargs = OmegaConf.merge(self.params, kwargs, list_merge_mode=ListMergeMode.REPLACE) new_args = [] - i = 0 for arg in args: - i += 1 if arg in ["y_pred", "y_train", "y_test"]: new_args.append(dep) elif arg in ["labels", "y_true", "ground_truth"]: @@ -70,7 +68,6 @@ def score(self, ind, dep) -> float: config.update(kwargs) try: result = call(config, *args, **kwargs) - except InstantiationException as e: # pragma: no cover if "continuous-multioutput" in str(e) or "multiclass-multioutput" in str(e): new_args = [] @@ -161,6 +158,7 @@ def __call__( score_dict_file=None, labels_file=None, predictions_file=None, + **kwargs, ): new_scores = {} args = list(args) @@ -179,7 +177,7 @@ def __call__( else: pass for name, scorer in self: - score = scorer.score(*args) + score = scorer.score(*args, **kwargs) new_scores[name] = score if score_dict_file is not None: scores = self.load(score_dict_file)