Skip to content

Commit

Permalink
Support evaluators without losses or metrics during flatboard creation
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 580566774
  • Loading branch information
Qwlouse authored and The kauldron Authors committed Nov 8, 2023
1 parent 3efd760 commit 3068e6a
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions kauldron/train/train_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,8 @@ def get_loss_y_keys(config: config_lib.Config) -> Sequence[str]:
loss_names = {k for k in tree_flatten_with_slash_path(config.train_losses)}
# evaluator losses
for evaluator in config.evals.values():
loss_names |= {k for k in tree_flatten_with_slash_path(evaluator.losses)}
eval_losses = getattr(evaluator, "losses", {})
loss_names |= {k for k in tree_flatten_with_slash_path(eval_losses)}

# If more than one loss, add the total loss
if len(loss_names) > 1:
Expand All @@ -345,7 +346,8 @@ def get_metric_y_keys(config: config_lib.Config) -> Sequence[str]:
metric_names = {k for k in tree_flatten_with_slash_path(config.train_metrics)}
# add evaluator metrics
for evaluator in config.evals.values():
metric_names |= {k for k in tree_flatten_with_slash_path(evaluator.metrics)}
eval_metrics = getattr(evaluator, "metrics", {})
metric_names |= {k for k in tree_flatten_with_slash_path(eval_metrics)}
return [f"metrics/{l.replace('.', '/')}" for l in sorted(metric_names)]


Expand Down

0 comments on commit 3068e6a

Please sign in to comment.