Skip to content

Commit

Permalink
add kwarg support to scorer
Browse files Browse the repository at this point in the history
  • Loading branch information
simplymathematics committed Aug 4, 2024
1 parent c7dac03 commit 60a37df
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions deckard/base/scorer/scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Literal, Dict, List
from hydra.utils import call
from hydra.errors import InstantiationException
from omegaconf import DictConfig, OmegaConf, ListConfig
from omegaconf import DictConfig, OmegaConf, ListConfig, ListMergeMode
import numpy as np
import json
from pathlib import Path
Expand Down Expand Up @@ -52,13 +52,11 @@ def __init__(
def __hash__(self):
return int(my_hash(self), 16)

def score(self, ind, dep) -> float:
def score(self, ind, dep, **kwargs) -> float:
args = deepcopy(self.args)
kwargs = deepcopy(self.params)
kwargs = OmegaConf.merge(self.params, kwargs, list_merge_mode=ListMergeMode.REPLACE)
new_args = []
i = 0
for arg in args:
i += 1
if arg in ["y_pred", "y_train", "y_test"]:
new_args.append(dep)
elif arg in ["labels", "y_true", "ground_truth"]:
Expand All @@ -70,7 +68,6 @@ def score(self, ind, dep) -> float:
config.update(kwargs)
try:
result = call(config, *args, **kwargs)

except InstantiationException as e: # pragma: no cover
if "continuous-multioutput" in str(e) or "multiclass-multioutput" in str(e):
new_args = []
Expand Down Expand Up @@ -161,6 +158,7 @@ def __call__(
score_dict_file=None,
labels_file=None,
predictions_file=None,
**kwargs,
):
new_scores = {}
args = list(args)
Expand All @@ -179,7 +177,7 @@ def __call__(
else:
pass
for name, scorer in self:
score = scorer.score(*args)
score = scorer.score(*args, **kwargs)
new_scores[name] = score
if score_dict_file is not None:
scores = self.load(score_dict_file)
Expand Down

0 comments on commit 60a37df

Please sign in to comment.