Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

format training logging #3397

Merged
merged 3 commits into from
Mar 3, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions deepmd/loggers/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@
)
CFORMATTER = logging.Formatter(
# "%(app_name)s %(levelname)-7s |-> %(name)-45s %(message)s"
"%(app_name)s %(levelname)-7s %(message)s"
"[%(asctime)s] %(app_name)s %(levelname)-7s %(message)s"
)
FFORMATTER_MPI = logging.Formatter(
"[%(asctime)s] %(app_name)s rank:%(rank)-2s %(levelname)-7s %(name)-45s %(message)s"
)
CFORMATTER_MPI = logging.Formatter(
# "%(app_name)s rank:%(rank)-2s %(levelname)-7s |-> %(name)-45s %(message)s"
"%(app_name)s rank:%(rank)-2s %(levelname)-7s %(message)s"
"[%(asctime)s] %(app_name)s rank:%(rank)-2s %(levelname)-7s %(message)s"
)


Expand Down
34 changes: 34 additions & 0 deletions deepmd/loggers/training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (

Check warning on line 2 in deepmd/loggers/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/loggers/training.py#L2

Added line #L2 was not covered by tests
Dict,
Optional,
)


def format_training_message(
    batch: int,
    wall_time: float,
):
    """Build the per-batch wall-time log line.

    Parameters
    ----------
    batch : int
        Current batch (training step) index.
    wall_time : float
        Elapsed wall time in seconds for this reporting interval.
    """
    # Single f-string; width/precision specs match the original output exactly.
    return f"batch {batch:7d}: total wall time = {wall_time:.2f} s"

Check warning on line 13 in deepmd/loggers/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/loggers/training.py#L13

Added line #L13 was not covered by tests


def format_training_message_per_task(
    batch: int,
    task_name: str,
    rmse: Dict[str, float],
    learning_rate: Optional[float],
):
    """Format one log line of per-task training metrics.

    Parameters
    ----------
    batch : int
        Current batch (training step) index.
    task_name : str
        Label prepended to the metrics; an empty string emits no label.
    rmse : Dict[str, float]
        Metric name to value; metrics are printed in sorted-key order.
    learning_rate : Optional[float]
        When not None, appended as ``, lr = ...``.
    """
    prefix = f"{task_name}: " if task_name else task_name
    # Sort keys so the emitted metric order is deterministic across runs.
    metrics = ", ".join(
        f"{key} = {value:8.2e}" for key, value in sorted(rmse.items())
    )
    suffix = "" if learning_rate is None else f", lr = {learning_rate:8.2e}"
    return f"batch {batch:7d}: {prefix}{metrics}{suffix}"
75 changes: 51 additions & 24 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
from deepmd.common import (
symlink_prefix_files,
)
from deepmd.loggers.training import (

Check warning on line 22 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L22

Added line #L22 was not covered by tests
format_training_message,
format_training_message_per_task,
)
from deepmd.pt.loss import (
DenoiseLoss,
EnergyStdLoss,
Expand Down Expand Up @@ -656,33 +660,24 @@
# Log and persist
if _step_id % self.disp_freq == 0:
self.wrapper.eval()
msg = f"step={_step_id}, lr={cur_lr:.2e}"

def log_loss_train(_loss, _more_loss, _task_key="Default"):
results = {}
if not self.multi_task:
suffix = ""
else:
suffix = f"_{_task_key}"
_msg = f"loss{suffix}={_loss:.4f}"
rmse_val = {
item: _more_loss[item]
for item in _more_loss
if "l2_" not in item
}
for item in sorted(rmse_val.keys()):
_msg += f", {item}_train{suffix}={rmse_val[item]:.4f}"
results[item] = rmse_val[item]
return _msg, results
return results

def log_loss_valid(_task_key="Default"):
single_results = {}
sum_natoms = 0
if not self.multi_task:
suffix = ""
valid_numb_batch = self.valid_numb_batch
else:
suffix = f"_{_task_key}"
valid_numb_batch = self.valid_numb_batch[_task_key]
for ii in range(valid_numb_batch):
self.optimizer.zero_grad()
Expand All @@ -707,16 +702,28 @@
single_results.get(k, 0.0) + v * natoms
)
results = {k: v / sum_natoms for k, v in single_results.items()}
_msg = ""
for item in sorted(results.keys()):
_msg += f", {item}_valid{suffix}={results[item]:.4f}"
return _msg, results
return results

if not self.multi_task:
temp_msg, train_results = log_loss_train(loss, more_loss)
msg += "\n" + temp_msg
temp_msg, valid_results = log_loss_valid()
msg += temp_msg
train_results = log_loss_train(loss, more_loss)
valid_results = log_loss_valid()
log.info(

Check warning on line 710 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L708-L710

Added lines #L708 - L710 were not covered by tests
format_training_message_per_task(
batch=_step_id,
task_name="trn",
rmse=train_results,
learning_rate=cur_lr,
)

Check warning on line 716 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L712-L716

Added lines #L712 - L716 were not covered by tests
)
if valid_results is not None:
log.info(
format_training_message_per_task(
batch=_step_id,

Check warning on line 721 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L718-L721

Added lines #L718 - L721 were not covered by tests
task_name="val",
rmse=valid_results,
learning_rate=None,

Check warning on line 724 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L724

Added line #L724 was not covered by tests
)
)

Check warning on line 726 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L726

Added line #L726 was not covered by tests
else:
train_results = {_key: {} for _key in self.model_keys}
valid_results = {_key: {} for _key in self.model_keys}
Expand All @@ -743,13 +750,33 @@
valid_msg[_key], valid_results[_key] = log_loss_valid(
_task_key=_key
)
msg += "\n" + train_msg[_key]
msg += valid_msg[_key]
log.info(
format_training_message_per_task(
batch=_step_id,
task_name=_key + "_trn",

Check warning on line 756 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L755-L756

Added lines #L755 - L756 were not covered by tests
rmse=train_results[_key],
learning_rate=cur_lr,
)
)
if valid_results is not None:
log.info(
format_training_message_per_task(
batch=_step_id,
task_name=_key + "_val",
rmse=valid_results[_key],
learning_rate=None,
)
)

Check warning on line 769 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L765-L769

Added lines #L765 - L769 were not covered by tests

train_time = time.time() - self.t0
self.t0 = time.time()
msg += f", speed={train_time:.2f} s/{self.disp_freq if _step_id else 1} batches"
log.info(msg)
current_time = time.time()
train_time = current_time - self.t0
self.t0 = current_time
log.info(
format_training_message(

Check warning on line 775 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L772-L775

Added lines #L772 - L775 were not covered by tests
batch=_step_id,
wall_time=train_time,
)

Check warning on line 778 in deepmd/pt/train/training.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/train/training.py#L778

Added line #L778 was not covered by tests
)

if fout:
if self.lcurve_should_print_header:
Expand Down
44 changes: 42 additions & 2 deletions deepmd/tf/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
from deepmd.common import (
symlink_prefix_files,
)
from deepmd.loggers.training import (

Check warning on line 26 in deepmd/tf/train/trainer.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/train/trainer.py#L26

Added line #L26 was not covered by tests
format_training_message,
format_training_message_per_task,
)
from deepmd.tf.common import (
data_requirement,
get_precision,
Expand Down Expand Up @@ -774,8 +778,10 @@
test_time = toc - tic
wall_time = toc - wall_time_tic
log.info(
"batch %7d training time %.2f s, testing time %.2f s, total wall time %.2f s"
% (cur_batch, train_time, test_time, wall_time)
format_training_message(
batch=cur_batch,
wall_time=wall_time,
)
)
# the first training time is not accurate
if cur_batch > self.disp_freq or stop_batch < 2 * self.disp_freq:
Expand Down Expand Up @@ -959,6 +965,23 @@
for k in train_results.keys():
print_str += prop_fmt % (train_results[k])
print_str += " %8.1e\n" % cur_lr
log.info(

Check warning on line 968 in deepmd/tf/train/trainer.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/train/trainer.py#L968

Added line #L968 was not covered by tests
format_training_message_per_task(
batch=cur_batch,
task_name="trn",
rmse=train_results,
learning_rate=cur_lr,
)
)
if valid_results is not None:
log.info(

Check warning on line 977 in deepmd/tf/train/trainer.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/train/trainer.py#L976-L977

Added lines #L976 - L977 were not covered by tests
format_training_message_per_task(
batch=cur_batch,
task_name="val",
rmse=valid_results,
learning_rate=None,
)
)
else:
for fitting_key in train_results:
if valid_results[fitting_key] is not None:
Expand All @@ -974,6 +997,23 @@
for k in train_results[fitting_key].keys():
print_str += prop_fmt % (train_results[fitting_key][k])
print_str += " %8.1e\n" % cur_lr_dict[fitting_key]
log.info(

Check warning on line 1000 in deepmd/tf/train/trainer.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/train/trainer.py#L1000

Added line #L1000 was not covered by tests
format_training_message_per_task(
batch=cur_batch,
task_name=f"{fitting_key}_trn",
rmse=train_results[fitting_key],
learning_rate=cur_lr_dict[fitting_key],
)
)
if valid_results is not None:
log.info(

Check warning on line 1009 in deepmd/tf/train/trainer.py

View check run for this annotation

Codecov / codecov/patch

deepmd/tf/train/trainer.py#L1008-L1009

Added lines #L1008 - L1009 were not covered by tests
format_training_message_per_task(
batch=cur_batch,
task_name=f"{fitting_key}_val",
rmse=valid_results[fitting_key],
learning_rate=None,
)
)
fp.write(print_str)
fp.flush()

Expand Down
Loading