diff --git a/main.py b/main.py
index abe63b7..7437aff 100644
--- a/main.py
+++ b/main.py
@@ -150,8 +150,7 @@ def train(model, data, optimizer, ema, n_epoch=30, start_epoch=0, batch_size=arg
                 logger.histo_summary(name + '/grad', to_np(param.grad), step)
         optimizer.zero_grad()
-        # (loss_p1+loss_p2).backward()
-        loss_p1.backward()
+        (loss_p1+loss_p2).backward()
         optimizer.step()
         for name, param in model.named_parameters():
             if param.requires_grad:
 