diff --git a/darwin/torch/dataset.py b/darwin/torch/dataset.py index bede4a58f..6e2cfb566 100644 --- a/darwin/torch/dataset.py +++ b/darwin/torch/dataset.py @@ -126,7 +126,7 @@ def get_target(self, index: int): idx = self.classes.index(tag) target[idx] = 1 else: - target = self.classes.index(tags[0]) + target = torch.tensor([self.classes.index(tags[0])]) return target