diff --git a/pina/plotter.py b/pina/plotter.py index 7eebf63c..adeb8957 100644 --- a/pina/plotter.py +++ b/pina/plotter.py @@ -256,7 +256,7 @@ def plot_loss(self, ) loss = trainer_metrics[metric] epochs = range(len(loss)) - plt.plot(epochs, loss, **kwargs) + plt.plot(epochs, loss.cpu(), **kwargs) # plotting plt.xlabel('epoch')