Skip to content

Commit

Permalink
feat(pt): support training/profiling argument in PT (deepmodeling#3897
Browse files Browse the repository at this point in the history
)

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Added profiling functionality with parameters for enabling profiling
and exporting data to a Chrome trace file.
  
- **Documentation**
- Updated documentation for profiling-related arguments to clarify
export options for performance analysis.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Jun 24, 2024
1 parent bbe5c4b commit 8889a1d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
17 changes: 13 additions & 4 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,8 @@ def warm_up_linear(step, warmup_steps):
self.tensorboard_log_dir = training_params.get("tensorboard_log_dir", "log")
self.tensorboard_freq = training_params.get("tensorboard_freq", 1)
self.enable_profiler = training_params.get("enable_profiler", False)
self.profiling = training_params.get("profiling", False)
self.profiling_file = training_params.get("profiling_file", "timeline.json")

def run(self):
fout = (
Expand All @@ -716,20 +718,22 @@ def run(self):
)

writer = SummaryWriter(log_dir=self.tensorboard_log_dir)
if self.enable_profiler:
if self.enable_profiler or self.profiling:
prof = torch.profiler.profile(
schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(
self.tensorboard_log_dir
),
)
if self.enable_profiler
else None,
record_shapes=True,
with_stack=True,
)
prof.start()

def step(_step_id, task_key="Default"):
# PyTorch Profiler
if self.enable_profiler:
if self.enable_profiler or self.profiling:
prof.step()
self.wrapper.train()
if isinstance(self.lr_exp, dict):
Expand Down Expand Up @@ -1061,8 +1065,13 @@ def log_loss_valid(_task_key="Default"):
fout1.close()
if self.enable_tensorboard:
writer.close()
if self.enable_profiler:
if self.enable_profiler or self.profiling:
prof.stop()
if self.profiling:
prof.export_chrome_trace(self.profiling_file)
log.info(
f"The profiling trace have been saved to: {self.profiling_file}"
)

def save_model(self, save_path, lr=0.0, step=0):
module = (
Expand Down
8 changes: 4 additions & 4 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2339,9 +2339,9 @@ def training_args(): # ! modified by Ziyao: data configuration isolated.
)
doc_disp_training = "Displaying verbose information during training."
doc_time_training = "Timing durining training."
doc_profiling = "Profiling during training."
doc_profiling = "Export the profiling results to the Chrome JSON file for performance analysis, driven by the legacy TensorFlow profiling API or PyTorch Profiler. The output file will be saved to `profiling_file`."
doc_profiling_file = "Output file for profiling."
doc_enable_profiler = "Enable TensorFlow Profiler (available in TensorFlow 2.3) or PyTorch Profiler to analyze performance. The log will be saved to `tensorboard_log_dir`."
doc_enable_profiler = "Export the profiling results to the TensorBoard log for performance analysis, driven by TensorFlow Profiler (available in TensorFlow 2.3) or PyTorch Profiler. The log will be saved to `tensorboard_log_dir`."
doc_tensorboard = "Enable tensorboard"
doc_tensorboard_log_dir = "The log directory of tensorboard outputs"
doc_tensorboard_freq = "The frequency of writing tensorboard events."
Expand Down Expand Up @@ -2397,14 +2397,14 @@ def training_args(): # ! modified by Ziyao: data configuration isolated.
bool,
optional=True,
default=False,
doc=doc_only_tf_supported + doc_profiling,
doc=doc_profiling,
),
Argument(
"profiling_file",
str,
optional=True,
default="timeline.json",
doc=doc_only_tf_supported + doc_profiling_file,
doc=doc_profiling_file,
),
Argument(
"enable_profiler",
Expand Down

0 comments on commit 8889a1d

Please sign in to comment.