diff --git a/train_reader.py b/train_reader.py index b983c42..f760992 100644 --- a/train_reader.py +++ b/train_reader.py @@ -78,11 +78,11 @@ def train(model, optimizer, scheduler, step, train_dataset, eval_dataset, opt, c log += f"train: {curr_loss/opt.eval_freq:.3f} |" log += f"evaluation: {100*dev_em:.2f}EM |" log += f"lr: {scheduler.get_last_lr()[0]:.5f}" - logger.info(log) - curr_loss = 0 + logger.info(log) if tb_logger is not None: tb_logger.add_scalar("Evaluation", dev_em, step) tb_logger.add_scalar("Training", curr_loss / (opt.eval_freq), step) + curr_loss = 0. if opt.is_main and step % opt.save_freq == 0: src.util.save(model, optimizer, scheduler, step, best_dev_em,