Skip to content

Commit

Permalink
fix: fix average training time for restart
Browse files Browse the repository at this point in the history
Fix deepmodeling#4208.

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Oct 14, 2024
1 parent a1f8672 commit 999dd56
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
14 changes: 8 additions & 6 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,8 +889,9 @@ def log_loss_valid(_task_key="Default"):
)
# the first training time is not accurate
if (
_step_id + 1
) > self.disp_freq or self.num_steps < 2 * self.disp_freq:
(_step_id + 1 - self.start_step) > self.disp_freq
or self.num_steps - self.start_step < 2 * self.disp_freq
):
self.total_train_time += train_time

if fout:
Expand Down Expand Up @@ -981,13 +982,14 @@ def log_loss_valid(_task_key="Default"):
with open("checkpoint", "w") as f:
f.write(str(self.latest_model))

if self.timing_in_training and self.num_steps // self.disp_freq > 0:
if self.num_steps >= 2 * self.disp_freq:
elapsed_batch = self.num_steps - self.start_step
if self.timing_in_training and elapsed_batch // self.disp_freq > 0:
if self.start_step >= 2 * self.disp_freq:
log.info(
"average training time: %.4f s/batch (exclude first %d batches)",
self.total_train_time
/ (
self.num_steps // self.disp_freq * self.disp_freq
elapsed_batch // self.disp_freq * self.disp_freq
- self.disp_freq
),
self.disp_freq,
Expand All @@ -996,7 +998,7 @@ def log_loss_valid(_task_key="Default"):
log.info(
"average training time: %.4f s/batch",
self.total_train_time
/ (self.num_steps // self.disp_freq * self.disp_freq),
/ (elapsed_batch // self.disp_freq * self.disp_freq),
)

if JIT:
Expand Down
20 changes: 15 additions & 5 deletions deepmd/tf/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,8 @@ def train(self, train_data=None, valid_data=None):
fp = open(self.disp_file, "a")

cur_batch = run_sess(self.sess, self.global_step)
start_batch = cur_batch
elapsed_batch = stop_batch - start_batch
is_first_step = True
self.cur_batch = cur_batch
log.info(
Expand Down Expand Up @@ -552,7 +554,10 @@ def train(self, train_data=None, valid_data=None):
)
)
# the first training time is not accurate
if cur_batch > self.disp_freq or stop_batch < 2 * self.disp_freq:
if (
cur_batch - start_batch > self.disp_freq
or elapsed_batch < 2 * self.disp_freq
):
total_train_time += train_time
train_time = 0
wall_time_tic = toc
Expand Down Expand Up @@ -594,18 +599,23 @@ def train(self, train_data=None, valid_data=None):
self.save_checkpoint(cur_batch)
if self.run_opt.is_chief:
fp.close()
if self.timing_in_training and stop_batch // self.disp_freq > 0:
if stop_batch >= 2 * self.disp_freq:
elapsed_batch = stop_batch - start_batch
if self.timing_in_training and elapsed_batch // self.disp_freq > 0:
if elapsed_batch >= 2 * self.disp_freq:
log.info(
"average training time: %.4f s/batch (exclude first %d batches)",
total_train_time
/ (stop_batch // self.disp_freq * self.disp_freq - self.disp_freq),
/ (
elapsed_batch // self.disp_freq * self.disp_freq
- self.disp_freq
),
self.disp_freq,
)
else:
log.info(
"average training time: %.4f s/batch",
total_train_time / (stop_batch // self.disp_freq * self.disp_freq),
total_train_time
/ (elapsed_batch // self.disp_freq * self.disp_freq),
)

if self.profiling and self.run_opt.is_chief:
Expand Down

0 comments on commit 999dd56

Please sign in to comment.