diff --git a/deepmd/env.py b/deepmd/env.py
index b1d4958ed8..1a8da63f8e 100644
--- a/deepmd/env.py
+++ b/deepmd/env.py
@@ -1,5 +1,9 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+import logging
 import os
+from typing import (
+    Tuple,
+)
 
 import numpy as np
 
@@ -26,3 +30,82 @@
         "low. Please set precision with environmental variable "
         "DP_INTERFACE_PREC." % dp_float_prec
     )
+
+
+def set_env_if_empty(key: str, value: str, verbose: bool = True):
+    """Set environment variable only if it is empty.
+
+    Parameters
+    ----------
+    key : str
+        env variable name
+    value : str
+        env variable value
+    verbose : bool, optional
+        if True action will be logged, by default True
+    """
+    if os.environ.get(key) is None:
+        os.environ[key] = value
+        if verbose:
+            logging.warning(
+                f"Environment variable {key} is empty. Use the default value {value}"
+            )
+
+
+def set_default_nthreads():
+    """Set internal number of threads to default=automatic selection.
+
+    Notes
+    -----
+    `DP_INTRA_OP_PARALLELISM_THREADS` and `DP_INTER_OP_PARALLELISM_THREADS`
+    control configuration of multithreading.
+    """
+    if (
+        "OMP_NUM_THREADS" not in os.environ
+        # for backward compatibility
+        or (
+            "DP_INTRA_OP_PARALLELISM_THREADS" not in os.environ
+            and "TF_INTRA_OP_PARALLELISM_THREADS" not in os.environ
+        )
+        or (
+            "DP_INTER_OP_PARALLELISM_THREADS" not in os.environ
+            and "TF_INTER_OP_PARALLELISM_THREADS" not in os.environ
+        )
+    ):
+        logging.warning(
+            "To get the best performance, it is recommended to adjust "
+            "the number of threads by setting the environment variables "
+            "OMP_NUM_THREADS, DP_INTRA_OP_PARALLELISM_THREADS, and "
+            "DP_INTER_OP_PARALLELISM_THREADS. See "
+            "https://deepmd.rtfd.io/parallelism/ for more information."
+        )
+    if "TF_INTRA_OP_PARALLELISM_THREADS" not in os.environ:
+        set_env_if_empty("DP_INTRA_OP_PARALLELISM_THREADS", "0", verbose=False)
+    if "TF_INTER_OP_PARALLELISM_THREADS" not in os.environ:
+        set_env_if_empty("DP_INTER_OP_PARALLELISM_THREADS", "0", verbose=False)
+
+
+def get_default_nthreads() -> Tuple[int, int]:
+    """Get parallelism settings.
+
+    The method will first read the environment variables with the prefix `DP_`.
+    If not found, it will read the environment variables with the prefix `TF_`
+    for backward compatibility.
+
+    Returns
+    -------
+    Tuple[int, int]
+        number of `DP_INTRA_OP_PARALLELISM_THREADS` and
+        `DP_INTER_OP_PARALLELISM_THREADS`
+    """
+    return int(
+        os.environ.get(
+            "DP_INTRA_OP_PARALLELISM_THREADS",
+            os.environ.get("TF_INTRA_OP_PARALLELISM_THREADS", "0"),
+        )
+    ), int(
+        os.environ.get(
+            "DP_INTER_OP_PARALLELISM_THREADS",
+            os.environ.get("TF_INTER_OP_PARALLELISM_THREADS", "0"),
+        )
+    )
diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index e4c672765b..ee0e7a54cc 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -438,6 +438,12 @@ def warm_up_linear(step, warmup_steps):
             assert sum_prob > 0.0, "Sum of model prob must be larger than 0!"
             self.model_prob = self.model_prob / sum_prob
 
+        # Tensorboard
+        self.enable_tensorboard = training_params.get("tensorboard", False)
+        self.tensorboard_log_dir = training_params.get("tensorboard_log_dir", "log")
+        self.tensorboard_freq = training_params.get("tensorboard_freq", 1)
+        self.enable_profiler = training_params.get("enable_profiler", False)
+
     def run(self):
         fout = (
             open(self.disp_file, mode="w", buffering=1) if self.rank == 0 else None
@@ -448,8 +454,27 @@ def run(self):
         logging.info("Start to train %d steps.", self.num_steps)
         if dist.is_initialized():
             logging.info(f"Rank: {dist.get_rank()}/{dist.get_world_size()}")
+        if self.enable_tensorboard:
+            from torch.utils.tensorboard import (
+                SummaryWriter,
+            )
+
+            writer = SummaryWriter(log_dir=self.tensorboard_log_dir)
+        if self.enable_profiler:
+            prof = torch.profiler.profile(
+                schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
+                on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                    self.tensorboard_log_dir
+                ),
+                record_shapes=True,
+                with_stack=True,
+            )
+            prof.start()
 
         def step(_step_id, task_key="Default"):
+            # PyTorch Profiler
+            if self.enable_profiler:
+                prof.step()
             self.wrapper.train()
             if isinstance(self.lr_exp, dict):
                 _lr = self.lr_exp[task_key]
@@ -654,6 +679,13 @@ def log_loss_valid(_task_key="Default"):
                 with open("checkpoint", "w") as f:
                     f.write(str(self.latest_model))
 
+            # tensorboard
+            if self.enable_tensorboard and _step_id % self.tensorboard_freq == 0:
+                writer.add_scalar(f"{task_key}/lr", cur_lr, _step_id)
+                writer.add_scalar(f"{task_key}/loss", loss, _step_id)
+                for item in more_loss:
+                    writer.add_scalar(f"{task_key}/{item}", more_loss[item], _step_id)
+
         self.t0 = time.time()
         for step_id in range(self.num_steps):
             if step_id < self.start_step:
@@ -691,6 +723,10 @@ def log_loss_valid(_task_key="Default"):
             fout.close()
         if SAMPLER_RECORD:
             fout1.close()
+        if self.enable_tensorboard:
+            writer.close()
+        if self.enable_profiler:
+            prof.stop()
 
     def save_model(self, save_path, lr=0.0, step=0):
         module = self.wrapper.module if dist.is_initialized() else self.wrapper
diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py
index 559dba0167..b51b03fdc2 100644
--- a/deepmd/pt/utils/env.py
+++ b/deepmd/pt/utils/env.py
@@ -4,6 +4,11 @@
 import numpy as np
 import torch
 
+from deepmd.env import (
+    get_default_nthreads,
+    set_default_nthreads,
+)
+
 PRECISION = os.environ.get("PRECISION", "float64")
 GLOBAL_NP_FLOAT_PRECISION = getattr(np, PRECISION)
 GLOBAL_PT_FLOAT_PRECISION = getattr(torch, PRECISION)
@@ -37,3 +42,11 @@
     "double": torch.float64,
 }
 DEFAULT_PRECISION = "float64"
+
+# throw warnings if threads not set
+set_default_nthreads()
+intra_nthreads, inter_nthreads = get_default_nthreads()
+if inter_nthreads > 0:  # the behavior of 0 is not documented
+    torch.set_num_interop_threads(inter_nthreads)
+if intra_nthreads > 0:
+    torch.set_num_threads(intra_nthreads)
diff --git a/deepmd/tf/env.py b/deepmd/tf/env.py
index 993768c4a4..6bc89664c7 100644
--- a/deepmd/tf/env.py
+++ b/deepmd/tf/env.py
@@ -2,7 +2,6 @@
 """Module that sets tensorflow working environment and exports inportant constants."""
 
 import ctypes
-import logging
 import os
 import platform
 from configparser import (
@@ -19,7 +18,6 @@
     TYPE_CHECKING,
     Any,
     Dict,
-    Tuple,
 )
 
 import numpy as np
@@ -31,8 +29,15 @@ from deepmd.env import (
     GLOBAL_ENER_FLOAT_PRECISION,
     GLOBAL_NP_FLOAT_PRECISION,
+)
+from deepmd.env import get_default_nthreads as get_tf_default_nthreads
+from deepmd.env import (
     global_float_prec,
 )
+from deepmd.env import set_default_nthreads as set_tf_default_nthreads
+from deepmd.env import (
+    set_env_if_empty,
+)
 
 if TYPE_CHECKING:
     from types import (
         ModuleType,
@@ -216,26 +221,6 @@ def dlopen_library(module: str, filename: str):
 }
 
 
-def set_env_if_empty(key: str, value: str, verbose: bool = True):
-    """Set environment variable only if it is empty.
-
-    Parameters
-    ----------
-    key : str
-        env variable name
-    value : str
-        env variable value
-    verbose : bool, optional
-        if True action will be logged, by default True
-    """
-    if os.environ.get(key) is None:
-        os.environ[key] = value
-        if verbose:
-            logging.warning(
-                f"Environment variable {key} is empty. Use the default value {value}"
-            )
-
-
 def set_mkl():
     """Tuning MKL for the best performance.
 
@@ -270,44 +255,6 @@ def set_mkl():
         reload(np)
 
 
-def set_tf_default_nthreads():
-    """Set TF internal number of threads to default=automatic selection.
-
-    Notes
-    -----
-    `TF_INTRA_OP_PARALLELISM_THREADS` and `TF_INTER_OP_PARALLELISM_THREADS`
-    control TF configuration of multithreading.
-    """
-    if (
-        "OMP_NUM_THREADS" not in os.environ
-        or "TF_INTRA_OP_PARALLELISM_THREADS" not in os.environ
-        or "TF_INTER_OP_PARALLELISM_THREADS" not in os.environ
-    ):
-        logging.warning(
-            "To get the best performance, it is recommended to adjust "
-            "the number of threads by setting the environment variables "
-            "OMP_NUM_THREADS, TF_INTRA_OP_PARALLELISM_THREADS, and "
-            "TF_INTER_OP_PARALLELISM_THREADS. See "
-            "https://deepmd.rtfd.io/parallelism/ for more information."
-        )
-    set_env_if_empty("TF_INTRA_OP_PARALLELISM_THREADS", "0", verbose=False)
-    set_env_if_empty("TF_INTER_OP_PARALLELISM_THREADS", "0", verbose=False)
-
-
-def get_tf_default_nthreads() -> Tuple[int, int]:
-    """Get TF paralellism settings.
-
-    Returns
-    -------
-    Tuple[int, int]
-        number of `TF_INTRA_OP_PARALLELISM_THREADS` and
-        `TF_INTER_OP_PARALLELISM_THREADS`
-    """
-    return int(os.environ.get("TF_INTRA_OP_PARALLELISM_THREADS", "0")), int(
-        os.environ.get("TF_INTER_OP_PARALLELISM_THREADS", "0")
-    )
-
-
 def get_tf_session_config() -> Any:
     """Configure tensorflow session.
 
diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py
index 31b54b4d76..dbe4881952 100644
--- a/deepmd/utils/argcheck.py
+++ b/deepmd/utils/argcheck.py
@@ -1703,7 +1703,7 @@ def training_args():  # ! modified by Ziyao: data configuration isolated.
     doc_time_training = "Timing durining training."
     doc_profiling = "Profiling during training."
     doc_profiling_file = "Output file for profiling."
-    doc_enable_profiler = "Enable TensorFlow Profiler (available in TensorFlow 2.3) to analyze performance. The log will be saved to `tensorboard_log_dir`."
+    doc_enable_profiler = "Enable TensorFlow Profiler (available in TensorFlow 2.3) or PyTorch Profiler to analyze performance. The log will be saved to `tensorboard_log_dir`."
     doc_tensorboard = "Enable tensorboard"
     doc_tensorboard_log_dir = "The log directory of tensorboard outputs"
     doc_tensorboard_freq = "The frequency of writing tensorboard events."
diff --git a/doc/train/tensorboard.md b/doc/train/tensorboard.md
index 1d6c5f0d68..a6cfdccb68 100644
--- a/doc/train/tensorboard.md
+++ b/doc/train/tensorboard.md
@@ -1,7 +1,7 @@
-# TensorBoard Usage {{ tensorflow_icon }}
+# TensorBoard Usage {{ tensorflow_icon }} {{ pytorch_icon }}
 
 :::{note}
-**Supported backends**: TensorFlow {{ tensorflow_icon }}
+**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}
 :::
 
 TensorBoard provides the visualization and tooling needed for machine learning
diff --git a/doc/troubleshooting/howtoset_num_nodes.md b/doc/troubleshooting/howtoset_num_nodes.md
index 8a9beab857..18b1a133ee 100644
--- a/doc/troubleshooting/howtoset_num_nodes.md
+++ b/doc/troubleshooting/howtoset_num_nodes.md
@@ -22,10 +22,10 @@ Sometimes, `$num_nodes` and the nodes information can be directly given by the H
 
 ## Parallelism between independent operators
 
-For CPU devices, TensorFlow use multiple streams to run independent operators (OP).
+For CPU devices, TensorFlow and PyTorch use multiple streams to run independent operators (OP).
 
 ```bash
-export TF_INTER_OP_PARALLELISM_THREADS=3
+export DP_INTER_OP_PARALLELISM_THREADS=3
 ```
 
 However, for GPU devices, TensorFlow uses only one compute stream and multiple copy streams.
@@ -33,20 +33,35 @@ Note that some of DeePMD-kit OPs do not have GPU support, so it is still encoura
 
 ## Parallelism within an individual operators
 
-For CPU devices, `TF_INTRA_OP_PARALLELISM_THREADS` controls parallelism within TensorFlow native OPs when TensorFlow is built against Eigen.
+For CPU devices, `DP_INTRA_OP_PARALLELISM_THREADS` controls parallelism within TensorFlow (when TensorFlow is built against Eigen) and PyTorch native OPs.
 
 ```bash
-export TF_INTRA_OP_PARALLELISM_THREADS=2
+export DP_INTRA_OP_PARALLELISM_THREADS=2
 ```
 
-`OMP_NUM_THREADS` is threads for OpenMP parallelism. It controls parallelism within TensorFlow native OPs when TensorFlow is built by Intel OneDNN and DeePMD-kit custom CPU OPs.
-It may also control parallelsim for NumPy when NumPy is built against OpenMP, so one who uses GPUs for training should also care this environmental variable.
+`OMP_NUM_THREADS` is the number of threads for OpenMP parallelism.
+It controls parallelism within TensorFlow (when TensorFlow is built upon Intel OneDNN) and PyTorch (when PyTorch is built upon OpenMP) native OPs and DeePMD-kit custom CPU OPs.
+It may also control parallelism for NumPy when NumPy is built against OpenMP, so one who uses GPUs for training should also care this environmental variable.
 
 ```bash
 export OMP_NUM_THREADS=2
 ```
 
-There are several other environmental variables for OpenMP, such as `KMP_BLOCKTIME`. See [Intel documentation](https://www.intel.com/content/www/us/en/developer/articles/technical/maximize-tensorflow-performance-on-cpu-considerations-and-recommendations-for-inference.html) for detailed information.
+There are several other environmental variables for OpenMP, such as `KMP_BLOCKTIME`.
+
+::::{tab-set}
+
+:::{tab-item} TensorFlow {{ tensorflow_icon }}
+
+See [Intel documentation](https://www.intel.com/content/www/us/en/developer/articles/technical/maximize-tensorflow-performance-on-cpu-considerations-and-recommendations-for-inference.html) for detailed information.
+
+:::
+:::{tab-item} PyTorch {{ pytorch_icon }}
+
+See [PyTorch documentation](https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html) for detailed information.
+
+:::
+::::
 
 ## Tune the performance
 
@@ -56,8 +71,8 @@ Here are some empirical examples.
 If you wish to use 3 cores of 2 CPUs on one node, you may set the environmental variables and run DeePMD-kit as follows:
 
 ```bash
 export OMP_NUM_THREADS=3
-export TF_INTRA_OP_PARALLELISM_THREADS=3
-export TF_INTER_OP_PARALLELISM_THREADS=2
+export DP_INTRA_OP_PARALLELISM_THREADS=3
+export DP_INTER_OP_PARALLELISM_THREADS=2
 dp train input.json
 ```
 
@@ -65,8 +80,8 @@ For a node with 128 cores, it is recommended to start with the following variabl
 
 ```bash
 export OMP_NUM_THREADS=16
-export TF_INTRA_OP_PARALLELISM_THREADS=16
-export TF_INTER_OP_PARALLELISM_THREADS=8
+export DP_INTRA_OP_PARALLELISM_THREADS=16
+export DP_INTER_OP_PARALLELISM_THREADS=8
 ```
 
 Again, in general, one should make sure the product of the parallel numbers is less than or equal to the number of cores available.
diff --git a/source/api_cc/include/common.h b/source/api_cc/include/common.h
index 0392747979..72382169f8 100644
--- a/source/api_cc/include/common.h
+++ b/source/api_cc/include/common.h
@@ -144,9 +144,9 @@ void select_map_inv(typename std::vector<VT>::iterator out,
  * @brief Get the number of threads from the environment variable.
  * @details A warning will be thrown if environmental variables are not set.
  * @param[out] num_intra_nthreads The number of intra threads. Read from
- *TF_INTRA_OP_PARALLELISM_THREADS.
+ *DP_INTRA_OP_PARALLELISM_THREADS.
  * @param[out] num_inter_nthreads The number of inter threads. Read from
- *TF_INTER_OP_PARALLELISM_THREADS.
+ *DP_INTER_OP_PARALLELISM_THREADS.
  **/
 void get_env_nthreads(int& num_intra_nthreads, int& num_inter_nthreads);
 
diff --git a/source/api_cc/src/common.cc b/source/api_cc/src/common.cc
index 2923534fb7..d2923c8d9e 100644
--- a/source/api_cc/src/common.cc
+++ b/source/api_cc/src/common.cc
@@ -330,23 +330,36 @@ void deepmd::get_env_nthreads(int& num_intra_nthreads,
   num_intra_nthreads = 0;
   num_inter_nthreads = 0;
   const char* env_intra_nthreads =
-      std::getenv("TF_INTRA_OP_PARALLELISM_THREADS");
+      std::getenv("DP_INTRA_OP_PARALLELISM_THREADS");
   const char* env_inter_nthreads =
+      std::getenv("DP_INTER_OP_PARALLELISM_THREADS");
+  // backward compatibility
+  const char* env_intra_nthreads_tf =
+      std::getenv("TF_INTRA_OP_PARALLELISM_THREADS");
+  const char* env_inter_nthreads_tf =
       std::getenv("TF_INTER_OP_PARALLELISM_THREADS");
   const char* env_omp_nthreads = std::getenv("OMP_NUM_THREADS");
   if (env_intra_nthreads &&
       std::string(env_intra_nthreads) != std::string("") &&
       atoi(env_intra_nthreads) >= 0) {
     num_intra_nthreads = atoi(env_intra_nthreads);
+  } else if (env_intra_nthreads_tf &&
+             std::string(env_intra_nthreads_tf) != std::string("") &&
+             atoi(env_intra_nthreads_tf) >= 0) {
+    num_intra_nthreads = atoi(env_intra_nthreads_tf);
   } else {
-    throw_env_not_set_warning("TF_INTRA_OP_PARALLELISM_THREADS");
+    throw_env_not_set_warning("DP_INTRA_OP_PARALLELISM_THREADS");
   }
   if (env_inter_nthreads &&
       std::string(env_inter_nthreads) != std::string("") &&
       atoi(env_inter_nthreads) >= 0) {
     num_inter_nthreads = atoi(env_inter_nthreads);
+  } else if (env_inter_nthreads_tf &&
+             std::string(env_inter_nthreads_tf) != std::string("") &&
+             atoi(env_inter_nthreads_tf) >= 0) {
+    num_inter_nthreads = atoi(env_inter_nthreads_tf);
   } else {
-    throw_env_not_set_warning("TF_INTER_OP_PARALLELISM_THREADS");
+    throw_env_not_set_warning("DP_INTER_OP_PARALLELISM_THREADS");
   }
   if (!(env_omp_nthreads && std::string(env_omp_nthreads) != std::string("") &&
         atoi(env_omp_nthreads) >= 0)) {
diff --git a/source/tests/tf/test_env.py b/source/tests/tf/test_env.py
index eb1b40e707..cd066b06a5 100644
--- a/source/tests/tf/test_env.py
+++ b/source/tests/tf/test_env.py
@@ -19,8 +19,8 @@ def test_empty(self):
     @mock.patch.dict(
         "os.environ",
         values={
-            "TF_INTRA_OP_PARALLELISM_THREADS": "5",
-            "TF_INTER_OP_PARALLELISM_THREADS": "3",
+            "DP_INTRA_OP_PARALLELISM_THREADS": "5",
+            "DP_INTER_OP_PARALLELISM_THREADS": "3",
        },
     )
     def test_given(self):
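
A minimal sketch (not part of the patch) of how the thread settings introduced above are resolved at runtime, assuming a DeePMD-kit installation that includes this change. The `DP_*` variables are read first and the legacy `TF_*` names remain a fallback; the dictionary at the end lists the training-section keys that `deepmd/pt/train/training.py` reads in this diff, with their defaults.

```python
# Illustration only: example values, not a recommended configuration.
import os

os.environ["TF_INTRA_OP_PARALLELISM_THREADS"] = "4"  # legacy name, still honored
os.environ["DP_INTER_OP_PARALLELISM_THREADS"] = "2"  # new name, takes precedence

from deepmd.env import (
    get_default_nthreads,
    set_default_nthreads,
)

set_default_nthreads()  # may log a warning if OMP_NUM_THREADS etc. are unset
intra, inter = get_default_nthreads()
print(intra, inter)  # -> 4 2: DP_* is read first, then the TF_* fallback

# Training-section keys read by the PyTorch trainer for TensorBoard/profiling
# (defaults shown); profiler traces are written to tensorboard_log_dir.
training_params = {
    "tensorboard": False,
    "tensorboard_log_dir": "log",
    "tensorboard_freq": 1,
    "enable_profiler": False,
}
```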