Commit

fix(pt): make PT training step idx consistent with TF
Fix deepmodeling#4206.
Currently, the training step index displayed in TF and in PT has different meanings:
- In TF, step 0 means no training has been performed, and step 1 means one training step has been performed; the maximum displayed step equals the total number of training steps.
- In PT, step 0 already means one training step has been performed, so the maximum displayed step is the total number of training steps minus 1.
This PR corrects the definition of the step index in PT and makes the two consistent. One difference remains: TF shows step 0 while PT starts from step 1. Showing step 0 in PT would require a heavy refactor and is therefore not included in this PR.
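For illustration, a minimal standalone sketch (not deepmd code; num_steps is a made-up value) of what the two conventions mean for a 3-step run:

# Standalone sketch: step indices under the two conventions for num_steps = 3.
num_steps = 3

# TF: step 0 means "no training performed"; after the k-th training step the
# index reads k, so the final index equals num_steps.
tf_indices = list(range(num_steps + 1))

# PT before this fix: step 0 already means "one training step performed",
# so the final index is num_steps - 1.
old_pt_indices = list(range(num_steps))

print(tf_indices)      # [0, 1, 2, 3]
print(old_pt_indices)  # [0, 1, 2]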

Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz committed Oct 15, 2024
1 parent 16172e6 commit 8f41dde
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions deepmd/pt/train/training.py
@@ -769,7 +769,10 @@ def fake_model():
                 raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
 
             # Log and persist
-            if self.display_in_training and _step_id % self.disp_freq == 0:
+            display_step_id = _step_id + 1
+            if self.display_in_training and (
+                display_step_id % self.disp_freq == 0 or display_step_id == 1
+            ):
                 self.wrapper.eval()
 
                 def log_loss_train(_loss, _more_loss, _task_key="Default"):
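Why the new condition also checks display_step_id == 1: with a 1-based display index, the first iteration would otherwise produce no output until step disp_freq. A minimal sketch of the condition's behavior (the loop bound and disp_freq are illustrative values, not deepmd defaults):

# Sketch of the new display condition: 300 steps with disp_freq = 100.
disp_freq = 100
for _step_id in range(300):
    display_step_id = _step_id + 1
    if display_step_id % disp_freq == 0 or display_step_id == 1:
        print(f"log at step {display_step_id}")
# Prints steps 1, 100, 200, 300. The old check (_step_id % disp_freq == 0)
# printed 0, 100, 200 and never displayed the final step, 300.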
@@ -821,7 +824,7 @@ def log_loss_valid(_task_key="Default"):
                     if self.rank == 0:
                         log.info(
                             format_training_message_per_task(
-                                batch=_step_id,
+                                batch=display_step_id,
                                 task_name="trn",
                                 rmse=train_results,
                                 learning_rate=cur_lr,
@@ -830,7 +833,7 @@ def log_loss_valid(_task_key="Default"):
                         if valid_results:
                             log.info(
                                 format_training_message_per_task(
-                                    batch=_step_id,
+                                    batch=display_step_id,
                                     task_name="val",
                                     rmse=valid_results,
                                     learning_rate=None,
@@ -861,7 +864,7 @@ def log_loss_valid(_task_key="Default"):
                     if self.rank == 0:
                         log.info(
                             format_training_message_per_task(
-                                batch=_step_id,
+                                batch=display_step_id,
                                 task_name=_key + "_trn",
                                 rmse=train_results[_key],
                                 learning_rate=cur_lr,
@@ -870,7 +873,7 @@ def log_loss_valid(_task_key="Default"):
                         if valid_results[_key]:
                             log.info(
                                 format_training_message_per_task(
-                                    batch=_step_id,
+                                    batch=display_step_id,
                                     task_name=_key + "_val",
                                     rmse=valid_results[_key],
                                     learning_rate=None,
@@ -883,7 +886,7 @@ def log_loss_valid(_task_key="Default"):
             if self.rank == 0 and self.timing_in_training:
                 log.info(
                     format_training_message(
-                        batch=_step_id,
+                        batch=display_step_id,
                         wall_time=train_time,
                     )
                 )
@@ -899,7 +902,7 @@ def log_loss_valid(_task_key="Default"):
                     self.print_header(fout, train_results, valid_results)
                     self.lcurve_should_print_header = False
                 self.print_on_training(
-                    fout, _step_id, cur_lr, train_results, valid_results
+                    fout, display_step_id, cur_lr, train_results, valid_results
                 )
 
             if (
@@ -921,11 +924,13 @@ def log_loss_valid(_task_key="Default"):
                         f.write(str(self.latest_model))
 
             # tensorboard
-            if self.enable_tensorboard and _step_id % self.tensorboard_freq == 0:
-                writer.add_scalar(f"{task_key}/lr", cur_lr, _step_id)
-                writer.add_scalar(f"{task_key}/loss", loss, _step_id)
+            if self.enable_tensorboard and display_step_id % self.tensorboard_freq == 0:
+                writer.add_scalar(f"{task_key}/lr", cur_lr, display_step_id)
+                writer.add_scalar(f"{task_key}/loss", loss, display_step_id)
                 for item in more_loss:
-                    writer.add_scalar(f"{task_key}/{item}", more_loss[item], _step_id)
+                    writer.add_scalar(
+                        f"{task_key}/{item}", more_loss[item], display_step_id
+                    )
 
         self.t0 = time.time()
         self.total_train_time = 0.0
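Since display_step_id is now also passed as the TensorBoard global_step, the scalar curves line up with the 1-based step numbers printed in the log. A minimal usage sketch with PyTorch's SummaryWriter (the log directory, tag, and loss values are made up for illustration):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="./tb")  # made-up log directory
for _step_id in range(300):
    display_step_id = _step_id + 1
    loss = 1.0 / display_step_id  # fake loss, just for the example
    if display_step_id % 100 == 0:
        # global_step is the 1-based display index, matching the log output
        writer.add_scalar("Default/loss", loss, display_step_id)
writer.close()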
