diff --git a/minigpt4/common/optims.py b/minigpt4/common/optims.py index 58327f72..270e66bf 100644 --- a/minigpt4/common/optims.py +++ b/minigpt4/common/optims.py @@ -80,7 +80,7 @@ def step(self, cur_epoch, cur_step): total_cur_step = cur_epoch * self.iters_per_epoch + cur_step if total_cur_step < self.warmup_steps: warmup_lr_schedule( - step=cur_step, + step=total_cur_step, optimizer=self.optimizer, max_step=self.warmup_steps, init_lr=self.warmup_start_lr,