diff --git a/lnn/model.py b/lnn/model.py index 1f0d64f..5816c27 100644 --- a/lnn/model.py +++ b/lnn/model.py @@ -670,7 +670,7 @@ def loss_fn(self, losses): f"expected losses from the following {[l.name for l in Loss]}" ) elif isinstance(losses, Loss): - losses = [losses] + losses = {losses: None} elif isinstance(losses, list): losses = {c: None for c in losses} result = list()