From afb440a5057ad27de594595b7d8b32d1d93ac89e Mon Sep 17 00:00:00 2001 From: Anyang Peng <137014849+anyangml@users.noreply.github.com> Date: Wed, 31 Jan 2024 17:20:22 +0800 Subject: [PATCH 1/3] Feat: add pair table model to pytorch (#3192) Migrated from this [PR](https://github.com/dptech-corp/deepmd-pytorch/pull/174). This is to reimplement the PairTab Model in Pytorch. Notes: 1. Different from the tensorflow version, the pytorch version abstracts away all the post energy conversion operations (force, virial). 2. Added extrapolation when `rcut` > `rmax`. The pytorch version overwrite energy beyond extrapolation endpoint to `0`. These features are not available in the tensorflow version. The extrapolation uses a cubic spline form, the 1st order derivation for the starting point is estimated using the last two rows in the user defined table. See example below: ![img_v3_027k_b50c690d-dc2d-4803-bd2c-2e73aa3c73fg](https://github.com/deepmodeling/deepmd-kit/assets/137014849/f3efa4d3-795e-4ff8-acdc-642227f0e19c) ![img_v3_027k_8de38597-ef4e-4e5b-989e-dbd13cc93fag](https://github.com/deepmodeling/deepmd-kit/assets/137014849/493da26d-f01d-4dd0-8520-ea2d84e7b548) ![img_v3_027k_f8268564-3f5d-49e6-91d6-169a61d9347g](https://github.com/deepmodeling/deepmd-kit/assets/137014849/b8ad4d4d-a4a4-40f0-94d1-810006e7175b) ![img_v3_027k_3966ef67-dd5e-4f48-992e-c2763311451g](https://github.com/deepmodeling/deepmd-kit/assets/137014849/27f31e79-13c8-4ce8-9911-b4cc0ac8188c) --------- Co-authored-by: Anyang Peng Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- deepmd/pt/model/model/pair_tab.py | 312 ++++++++++++++++++ deepmd/utils/pair_tab.py | 154 ++++++++- .../tests/common/test_pairtab_preprocess.py | 263 +++++++++++++++ source/tests/pt/test_pairtab.py | 190 +++++++++++ 4 files changed, 914 insertions(+), 5 deletions(-) create mode 100644 deepmd/pt/model/model/pair_tab.py create mode 100644 source/tests/common/test_pairtab_preprocess.py create mode 100644 source/tests/pt/test_pairtab.py diff --git a/deepmd/pt/model/model/pair_tab.py b/deepmd/pt/model/model/pair_tab.py new file mode 100644 index 0000000000..6f0782289a --- /dev/null +++ b/deepmd/pt/model/model/pair_tab.py @@ -0,0 +1,312 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Dict, + List, + Optional, + Union, +) + +import torch +from torch import ( + nn, +) + +from deepmd.model_format import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.utils.pair_tab import ( + PairTab, +) + +from .atomic_model import ( + AtomicModel, +) + + +class PairTabModel(nn.Module, AtomicModel): + """Pairwise tabulation energy model. + + This model can be used to tabulate the pairwise energy between atoms for either + short-range or long-range interactions, such as D3, LJ, ZBL, etc. It should not + be used alone, but rather as one submodel of a linear (sum) model, such as + DP+D3. + + Do not put the model on the first model of a linear model, since the linear + model fetches the type map from the first model. + + At this moment, the model does not smooth the energy at the cutoff radius, so + one needs to make sure the energy has been smoothed to zero. + + Parameters + ---------- + tab_file : str + The path to the tabulation file. + rcut : float + The cutoff radius. + sel : int or list[int] + The maxmum number of atoms in the cut-off radius. + """ + + def __init__( + self, tab_file: str, rcut: float, sel: Union[int, List[int]], **kwargs + ): + super().__init__() + self.tab_file = tab_file + self.rcut = rcut + + self.tab = PairTab(self.tab_file, rcut=rcut) + self.ntypes = self.tab.ntypes + + tab_info, tab_data = self.tab.get() # this returns -> Tuple[np.array, np.array] + self.tab_info = torch.from_numpy(tab_info) + self.tab_data = torch.from_numpy(tab_data) + + # self.model_type = "ener" + # self.model_version = MODEL_VERSION ## this shoud be in the parent class + + if isinstance(sel, int): + self.sel = sel + elif isinstance(sel, list): + self.sel = sum(sel) + else: + raise TypeError("sel must be int or list[int]") + + def get_fitting_output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + name="energy", shape=[1], reduciable=True, differentiable=True + ) + ] + ) + + def get_rcut(self) -> float: + return self.rcut + + def get_sel(self) -> int: + return self.sel + + def distinguish_types(self) -> bool: + # to match DPA1 and DPA2. + return False + + def forward_atomic( + self, + extended_coord, + extended_atype, + nlist, + mapping: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ) -> Dict[str, torch.Tensor]: + self.nframes, self.nloc, self.nnei = nlist.shape + + # this will mask all -1 in the nlist + masked_nlist = torch.clamp(nlist, 0) + + atype = extended_atype[:, : self.nloc] # (nframes, nloc) + pairwise_dr = self._get_pairwise_dist( + extended_coord + ) # (nframes, nall, nall, 3) + pairwise_rr = pairwise_dr.pow(2).sum(-1).sqrt() # (nframes, nall, nall) + + self.tab_data = self.tab_data.reshape( + self.tab.ntypes, self.tab.ntypes, self.tab.nspline, 4 + ) + + # to calculate the atomic_energy, we need 3 tensors, i_type, j_type, rr + # i_type : (nframes, nloc), this is atype. + # j_type : (nframes, nloc, nnei) + j_type = extended_atype[ + torch.arange(extended_atype.size(0))[:, None, None], masked_nlist + ] + + # slice rr to get (nframes, nloc, nnei) + rr = torch.gather(pairwise_rr[:, : self.nloc, :], 2, masked_nlist) + + raw_atomic_energy = self._pair_tabulated_inter(nlist, atype, j_type, rr) + + atomic_energy = 0.5 * torch.sum( + torch.where( + nlist != -1, raw_atomic_energy, torch.zeros_like(raw_atomic_energy) + ), + dim=-1, + ) + + return {"energy": atomic_energy} + + def _pair_tabulated_inter( + self, + nlist: torch.Tensor, + i_type: torch.Tensor, + j_type: torch.Tensor, + rr: torch.Tensor, + ) -> torch.Tensor: + """Pairwise tabulated energy. + + Parameters + ---------- + nlist : torch.Tensor + The unmasked neighbour list. (nframes, nloc) + i_type : torch.Tensor + The integer representation of atom type for all local atoms for all frames. (nframes, nloc) + j_type : torch.Tensor + The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei) + rr : torch.Tensor + The salar distance vector between two atoms. (nframes, nloc, nnei) + + Returns + ------- + torch.Tensor + The masked atomic energy for all local atoms for all frames. (nframes, nloc, nnei) + + Raises + ------ + Exception + If the distance is beyond the table. + + Notes + ----- + This function is used to calculate the pairwise energy between two atoms. + It uses a table containing cubic spline coefficients calculated in PairTab. + """ + rmin = self.tab_info[0] + hh = self.tab_info[1] + hi = 1.0 / hh + + self.nspline = int(self.tab_info[2] + 0.1) + + uu = (rr - rmin) * hi # this is broadcasted to (nframes,nloc,nnei) + + # if nnei of atom 0 has -1 in the nlist, uu would be 0. + # this is to handle the nlist where the mask is set to 0, so that we don't raise exception for those atoms. + uu = torch.where(nlist != -1, uu, self.nspline + 1) + + if torch.any(uu < 0): + raise Exception("coord go beyond table lower boundary") + + idx = uu.to(torch.int) + + uu -= idx + + table_coef = self._extract_spline_coefficient( + i_type, j_type, idx, self.tab_data, self.nspline + ) + table_coef = table_coef.reshape(self.nframes, self.nloc, self.nnei, 4) + ener = self._calcualte_ener(table_coef, uu) + + # here we need to overwrite energy to zero at rcut and beyond. + mask_beyond_rcut = rr >= self.rcut + # also overwrite values beyond extrapolation to zero + extrapolation_mask = rr >= self.tab.rmin + self.nspline * self.tab.hh + ener[mask_beyond_rcut] = 0 + ener[extrapolation_mask] = 0 + + return ener + + @staticmethod + def _get_pairwise_dist(coords: torch.Tensor) -> torch.Tensor: + """Get pairwise distance `dr`. + + Parameters + ---------- + coords : torch.Tensor + The coordinate of the atoms shape of (nframes * nall * 3). + + Returns + ------- + torch.Tensor + The pairwise distance between the atoms (nframes * nall * nall * 3). + + Examples + -------- + coords = torch.tensor([[ + [0,0,0], + [1,3,5], + [2,4,6] + ]]) + + dist = tensor([[ + [[ 0, 0, 0], + [-1, -3, -5], + [-2, -4, -6]], + + [[ 1, 3, 5], + [ 0, 0, 0], + [-1, -1, -1]], + + [[ 2, 4, 6], + [ 1, 1, 1], + [ 0, 0, 0]] + ]]) + """ + return coords.unsqueeze(2) - coords.unsqueeze(1) + + @staticmethod + def _extract_spline_coefficient( + i_type: torch.Tensor, + j_type: torch.Tensor, + idx: torch.Tensor, + tab_data: torch.Tensor, + nspline: int, + ) -> torch.Tensor: + """Extract the spline coefficient from the table. + + Parameters + ---------- + i_type : torch.Tensor + The integer representation of atom type for all local atoms for all frames. (nframes, nloc) + j_type : torch.Tensor + The integer representation of atom type for all neighbour atoms of all local atoms for all frames. (nframes, nloc, nnei) + idx : torch.Tensor + The index of the spline coefficient. (nframes, nloc, nnei) + tab_data : torch.Tensor + The table storing all the spline coefficient. (ntype, ntype, nspline, 4) + nspline : int + The number of splines in the table. + + Returns + ------- + torch.Tensor + The spline coefficient. (nframes, nloc, nnei, 4), shape may be squeezed. + + """ + # (nframes, nloc, nnei) + expanded_i_type = i_type.unsqueeze(-1).expand(-1, -1, j_type.shape[-1]) + + # (nframes, nloc, nnei, nspline, 4) + expanded_tab_data = tab_data[expanded_i_type, j_type] + + # (nframes, nloc, nnei, 1, 4) + expanded_idx = idx.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, -1, 4) + + # handle the case where idx is beyond the number of splines + clipped_indices = torch.clamp(expanded_idx, 0, nspline - 1).to(torch.int64) + + # (nframes, nloc, nnei, 4) + final_coef = torch.gather(expanded_tab_data, 3, clipped_indices).squeeze() + + # when the spline idx is beyond the table, all spline coefficients are set to `0`, and the resulting ener corresponding to the idx is also `0`. + final_coef[expanded_idx.squeeze() > nspline] = 0 + return final_coef + + @staticmethod + def _calcualte_ener(coef: torch.Tensor, uu: torch.Tensor) -> torch.Tensor: + """Calculate energy using spline coeeficients. + + Parameters + ---------- + coef : torch.Tensor + The spline coefficients. (nframes, nloc, nnei, 4) + uu : torch.Tensor + The atom displancemnt used in interpolation and extrapolation (nframes, nloc, nnei) + + Returns + ------- + torch.Tensor + The atomic energy for all local atoms for all frames. (nframes, nloc, nnei) + """ + a3, a2, a1, a0 = torch.unbind(coef, dim=-1) + etmp = (a3 * uu + a2) * uu + a1 # this should be elementwise operations. + ener = etmp * uu + a0 # this energy has the extrapolated value when rcut > rmax + return ener diff --git a/deepmd/utils/pair_tab.py b/deepmd/utils/pair_tab.py index 4451f53379..56f8e618df 100644 --- a/deepmd/utils/pair_tab.py +++ b/deepmd/utils/pair_tab.py @@ -1,7 +1,9 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: LGPL-3.0-or-later +import logging from typing import ( + Optional, Tuple, ) @@ -25,11 +27,11 @@ class PairTab: The columes from 2nd to 4th are for 0-0, 0-1 and 1-1 correspondingly. """ - def __init__(self, filename: str) -> None: + def __init__(self, filename: str, rcut: Optional[float] = None) -> None: """Constructor.""" - self.reinit(filename) + self.reinit(filename, rcut) - def reinit(self, filename: str) -> None: + def reinit(self, filename: str, rcut: Optional[float] = None) -> None: """Initialize the tabulated interaction. Parameters @@ -44,8 +46,8 @@ def reinit(self, filename: str) -> None: """ self.vdata = np.loadtxt(filename) self.rmin = self.vdata[0][0] + self.rmax = self.vdata[-1][0] self.hh = self.vdata[1][0] - self.vdata[0][0] - self.nspline = self.vdata.shape[0] - 1 ncol = self.vdata.shape[1] - 1 n0 = (-1 + np.sqrt(1 + 8 * ncol)) * 0.5 self.ntypes = int(n0 + 0.1) @@ -53,13 +55,155 @@ def reinit(self, filename: str) -> None: "number of volumes provided in %s does not match guessed number of types %d" % (filename, self.ntypes) ) + + # check table data against rcut and update tab_file if needed, table upper boundary is used as rcut if not provided. + self.rcut = rcut if rcut is not None else self.rmax + self._check_table_upper_boundary() + self.nspline = ( + self.vdata.shape[0] - 1 + ) # this nspline is updated based on the expanded table. self.tab_info = np.array([self.rmin, self.hh, self.nspline, self.ntypes]) self.tab_data = self._make_data() + def _check_table_upper_boundary(self) -> None: + """Update User Provided Table Based on `rcut`. + + This function checks the upper boundary provided in the table against rcut. + If the table upper boundary values decay to zero before rcut, padding zeros will + be added to the table to cover rcut; if the table upper boundary values do not decay to zero + before ruct, extrapolation will be performed till rcut. + + Examples + -------- + table = [[0.005 1. 2. 3. ] + [0.01 0.8 1.6 2.4 ] + [0.015 0. 1. 1.5 ]] + + rcut = 0.022 + + new_table = [[0.005 1. 2. 3. ] + [0.01 0.8 1.6 2.4 ] + [0.015 0. 1. 1.5 ] + [0.02 0. 0. 0. ] + + ---------------------------------------------- + + table = [[0.005 1. 2. 3. ] + [0.01 0.8 1.6 2.4 ] + [0.015 0.5 1. 1.5 ] + [0.02 0.25 0.4 0.75 ] + [0.025 0. 0.1 0. ] + [0.03 0. 0. 0. ]] + + rcut = 0.031 + + new_table = [[0.005 1. 2. 3. ] + [0.01 0.8 1.6 2.4 ] + [0.015 0.5 1. 1.5 ] + [0.02 0.25 0.4 0.75 ] + [0.025 0. 0.1 0. ] + [0.03 0. 0. 0. ] + [0.035 0. 0. 0. ]] + """ + upper_val = self.vdata[-1][1:] + upper_idx = self.vdata.shape[0] - 1 + self.ncol = self.vdata.shape[1] + + # the index in table for the grid point of rcut, always give the point after rcut. + rcut_idx = int(np.ceil(self.rcut / self.hh - self.rmin / self.hh)) + if np.all(upper_val == 0): + # if table values decay to `0` after rcut + if self.rcut < self.rmax and np.any(self.vdata[rcut_idx - 1][1:] != 0): + logging.warning( + "The energy provided in the table does not decay to 0 at rcut." + ) + # if table values decay to `0` at rcut, do nothing + + # if table values decay to `0` before rcut, pad table with `0`s. + elif self.rcut > self.rmax: + pad_zero = np.zeros((rcut_idx - upper_idx, self.ncol)) + pad_zero[:, 0] = np.linspace( + self.rmax + self.hh, + self.rmax + self.hh * (rcut_idx - upper_idx), + rcut_idx - upper_idx, + ) + self.vdata = np.concatenate((self.vdata, pad_zero), axis=0) + else: + # if table values do not decay to `0` at rcut + if self.rcut <= self.rmax: + logging.warning( + "The energy provided in the table does not decay to 0 at rcut." + ) + # if rcut goes beyond table upper bond, need extrapolation, ensure values decay to `0` before rcut. + else: + logging.warning( + "The rcut goes beyond table upper boundary, performing extrapolation." + ) + pad_extrapolation = np.zeros((rcut_idx - upper_idx, self.ncol)) + + pad_extrapolation[:, 0] = np.linspace( + self.rmax + self.hh, + self.rmax + self.hh * (rcut_idx - upper_idx), + rcut_idx - upper_idx, + ) + # need to calculate table values to fill in with cubic spline + pad_extrapolation = self._extrapolate_table(pad_extrapolation) + + self.vdata = np.concatenate((self.vdata, pad_extrapolation), axis=0) + def get(self) -> Tuple[np.array, np.array]: """Get the serialized table.""" return self.tab_info, self.tab_data + def _extrapolate_table(self, pad_extrapolation: np.array) -> np.array: + """Soomth extrapolation between table upper boundary and rcut. + + This method should only be used when the table upper boundary `rmax` is smaller than `rcut`, and + the table upper boundary values are not zeros. To simplify the problem, we use a single + cubic spline between `rmax` and `rcut` for each pair of atom types. One can substitute this extrapolation + to higher order polynomials if needed. + + There are two scenarios: + 1. `ruct` - `rmax` >= hh: + Set values at the grid point right before `rcut` to 0, and perform exterapolation between + the grid point and `rmax`, this allows smooth decay to 0 at `rcut`. + 2. `rcut` - `rmax` < hh: + Set values at `rmax + hh` to 0, and perform extrapolation between `rmax` and `rmax + hh`. + + Parameters + ---------- + pad_extrapolation : np.array + The emepty grid that holds the extrapolation values. + + Returns + ------- + np.array + The cubic spline extrapolation. + """ + # in theory we should check if the table has at least two rows. + slope = self.vdata[-1, 1:] - self.vdata[-2, 1:] # shape of (ncol-1, ) + + # for extrapolation, we want values decay to `0` prior to `ruct` if possible + # here we try to find the grid point prior to `rcut` + grid_point = ( + -2 if pad_extrapolation[-1, 0] / self.hh - self.rmax / self.hh >= 2 else -1 + ) + temp_grid = np.stack((self.vdata[-1, :], pad_extrapolation[grid_point, :])) + vv = temp_grid[:, 1:] + xx = temp_grid[:, 0] + cs = CubicSpline(xx, vv, bc_type=((1, slope), (1, np.zeros_like(slope)))) + xx_grid = pad_extrapolation[:, 0] + res = cs(xx_grid) + + pad_extrapolation[:, 1:] = res + + # Note: when doing cubic spline, if we want to ensure values decay to zero prior to `rcut` + # this may cause values be positive post `rcut`, we need to overwrite those values to zero + pad_extrapolation = ( + pad_extrapolation if grid_point == -1 else pad_extrapolation[:-1, :] + ) + return pad_extrapolation + def _make_data(self): data = np.zeros([self.ntypes * self.ntypes * 4 * self.nspline]) stride = 4 * self.nspline @@ -68,7 +212,7 @@ def _make_data(self): for t0 in range(self.ntypes): for t1 in range(t0, self.ntypes): vv = self.vdata[:, 1 + idx_iter] - cs = CubicSpline(xx, vv) + cs = CubicSpline(xx, vv, bc_type="clamped") dd = cs(xx, 1) dd *= self.hh dtmp = np.zeros(stride) diff --git a/source/tests/common/test_pairtab_preprocess.py b/source/tests/common/test_pairtab_preprocess.py new file mode 100644 index 0000000000..a866c42236 --- /dev/null +++ b/source/tests/common/test_pairtab_preprocess.py @@ -0,0 +1,263 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from unittest.mock import ( + patch, +) + +import numpy as np + +from deepmd.utils.pair_tab import ( + PairTab, +) + + +class TestPairTabPreprocessExtrapolate(unittest.TestCase): + @patch("numpy.loadtxt") + def setUp(self, mock_loadtxt) -> None: + file_path = "dummy_path" + mock_loadtxt.return_value = np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + ] + ) + + self.tab1 = PairTab(filename=file_path, rcut=0.028) + self.tab2 = PairTab(filename=file_path, rcut=0.02) + self.tab3 = PairTab(filename=file_path, rcut=0.022) + self.tab4 = PairTab(filename=file_path, rcut=0.03) + self.tab5 = PairTab(filename=file_path, rcut=0.032) + + def test_preprocess(self): + np.testing.assert_allclose( + self.tab1.vdata, + np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.0, 0.0], + ] + ), + rtol=1e-04, + atol=1e-04, + ) + np.testing.assert_allclose( + self.tab2.vdata, + np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + ] + ), + rtol=1e-04, + atol=1e-04, + ) + + # for this test case, the table does not decay to zero at rcut = 0.22, + # in the cubic spline code, we use a fixed size grid, if will be a problem if we introduce variable gird size. + # we will do post process to overwrite spline coefficient `a3`,`a2`,`a1`,`a0`, to ensure energy decays to `0`. + np.testing.assert_allclose( + self.tab3.vdata, + np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.0, 0.0], + ] + ), + rtol=1e-04, + atol=1e-04, + ) + + np.testing.assert_allclose( + self.tab4.vdata, + np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.0, 0.0], + ] + ), + rtol=1e-04, + atol=1e-04, + ) + + np.testing.assert_allclose( + self.tab5.vdata, + np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.12468, 0.1992, 0.3741], + [0.03, 0.0, 0.0, 0.0], + ] + ), + rtol=1e-04, + atol=1e-04, + ) + + +class TestPairTabPreprocessZero(unittest.TestCase): + @patch("numpy.loadtxt") + def setUp(self, mock_loadtxt) -> None: + file_path = "dummy_path" + mock_loadtxt.return_value = np.array( + [ + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.0, 0.0], + ] + ) + + self.tab1 = PairTab(filename=file_path, rcut=0.023) + self.tab2 = PairTab(filename=file_path, rcut=0.025) + self.tab3 = PairTab(filename=file_path, rcut=0.028) + self.tab4 = PairTab(filename=file_path, rcut=0.033) + + def test_preprocess(self): + np.testing.assert_allclose( + self.tab1.vdata, + np.array( + [ + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.0, 0.0], + ] + ), + ) + np.testing.assert_allclose( + self.tab2.vdata, + np.array( + [ + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.0, 0.0], + ] + ), + ) + + np.testing.assert_allclose( + self.tab3.vdata, + np.array( + [ + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.0, 0.0], + [0.03, 0.0, 0.0, 0.0], + ] + ), + ) + + np.testing.assert_allclose( + self.tab4.vdata, + np.array( + [ + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.0, 0.0], + [0.03, 0.0, 0.0, 0.0], + [0.035, 0.0, 0.0, 0.0], + ] + ), + ) + + +class TestPairTabPreprocessUneven(unittest.TestCase): + @patch("numpy.loadtxt") + def setUp(self, mock_loadtxt) -> None: + file_path = "dummy_path" + mock_loadtxt.return_value = np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.1, 0.0], + ] + ) + + self.tab1 = PairTab(filename=file_path, rcut=0.025) + self.tab2 = PairTab(filename=file_path, rcut=0.028) + self.tab3 = PairTab(filename=file_path, rcut=0.03) + self.tab4 = PairTab(filename=file_path, rcut=0.037) + + def test_preprocess(self): + np.testing.assert_allclose( + self.tab1.vdata, + np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.1, 0.0], + ] + ), + rtol=1e-04, + atol=1e-04, + ) + np.testing.assert_allclose( + self.tab2.vdata, + np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.1, 0.0], + [0.03, 0.0, 0.0, 0.0], + ] + ), + rtol=1e-04, + atol=1e-04, + ) + + np.testing.assert_allclose( + self.tab3.vdata, + np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.1, 0.0], + [0.03, 0.0, 0.0, 0.0], + ] + ), + rtol=1e-04, + atol=1e-04, + ) + + np.testing.assert_allclose( + self.tab4.vdata, + np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + [0.025, 0.0, 0.1, 0.0], + [0.03, 0.0, 0.04963, 0.0], + [0.035, 0.0, 0.0, 0.0], + ] + ), + rtol=1e-03, + atol=1e-03, + ) diff --git a/source/tests/pt/test_pairtab.py b/source/tests/pt/test_pairtab.py new file mode 100644 index 0000000000..b4dbda6702 --- /dev/null +++ b/source/tests/pt/test_pairtab.py @@ -0,0 +1,190 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from unittest.mock import ( + patch, +) + +import numpy as np +import torch + +from deepmd.pt.model.model.pair_tab import ( + PairTabModel, +) + + +class TestPairTab(unittest.TestCase): + @patch("numpy.loadtxt") + def setUp(self, mock_loadtxt) -> None: + file_path = "dummy_path" + mock_loadtxt.return_value = np.array( + [ + [0.005, 1.0, 2.0, 3.0], + [0.01, 0.8, 1.6, 2.4], + [0.015, 0.5, 1.0, 1.5], + [0.02, 0.25, 0.4, 0.75], + ] + ) + + self.model = PairTabModel(tab_file=file_path, rcut=0.02, sel=2) + + self.extended_coord = torch.tensor( + [ + [ + [0.01, 0.01, 0.01], + [0.01, 0.02, 0.01], + [0.01, 0.01, 0.02], + [0.02, 0.01, 0.01], + ], + [ + [0.01, 0.01, 0.01], + [0.01, 0.02, 0.01], + [0.01, 0.01, 0.02], + [0.05, 0.01, 0.01], + ], + ] + ) + + # nframes=2, nall=4 + self.extended_atype = torch.tensor([[0, 1, 0, 1], [0, 0, 1, 1]]) + + # nframes=2, nloc=2, nnei=2 + self.nlist = torch.tensor([[[1, 2], [0, 2]], [[1, 2], [0, 3]]]) + + def test_without_mask(self): + result = self.model.forward_atomic( + self.extended_coord, self.extended_atype, self.nlist + ) + expected_result = torch.tensor([[1.2000, 1.3614], [1.2000, 0.4000]]) + + torch.testing.assert_allclose(result["energy"], expected_result, 0.0001, 0.0001) + + def test_with_mask(self): + self.nlist = torch.tensor([[[1, -1], [0, 2]], [[1, 2], [0, 3]]]) + + result = self.model.forward_atomic( + self.extended_coord, self.extended_atype, self.nlist + ) + expected_result = torch.tensor([[0.8000, 1.3614], [1.2000, 0.4000]]) + + torch.testing.assert_allclose(result["energy"], expected_result, 0.0001, 0.0001) + + def test_jit(self): + model = torch.jit.script(self.model) + + +class TestPairTabTwoAtoms(unittest.TestCase): + @patch("numpy.loadtxt") + def test_extrapolation_nonzero_rmax(self, mock_loadtxt) -> None: + """Scenarios to test. + + rcut < rmax: + rr < rcut: use table values, or interpolate. + rr == rcut: use table values, or interpolate. + rr > rcut: should be 0 + rcut == rmax: + rr < rcut: use table values, or interpolate. + rr == rcut: use table values, or interpolate. + rr > rcut: should be 0 + rcut > rmax: + rr < rmax: use table values, or interpolate. + rr == rmax: use table values, or interpolate. + rmax < rr < rcut: extrapolate + rr >= rcut: should be 0 + + """ + file_path = "dummy_path" + mock_loadtxt.return_value = np.array( + [ + [0.005, 1.0], + [0.01, 0.8], + [0.015, 0.5], + [0.02, 0.25], + ] + ) + + # nframes=1, nall=2 + extended_atype = torch.tensor([[0, 0]]) + + # nframes=1, nloc=2, nnei=1 + nlist = torch.tensor([[[1], [-1]]]) + + results = [] + + for dist, rcut in zip( + [ + 0.01, + 0.015, + 0.020, + 0.015, + 0.02, + 0.021, + 0.015, + 0.02, + 0.021, + 0.025, + 0.026, + 0.025, + 0.025, + 0.0216161, + ], + [ + 0.015, + 0.015, + 0.015, + 0.02, + 0.02, + 0.02, + 0.022, + 0.022, + 0.022, + 0.025, + 0.025, + 0.03, + 0.035, + 0.025, + ], + ): + extended_coord = torch.tensor( + [ + [ + [0.0, 0.0, 0.0], + [0.0, dist, 0.0], + ], + ] + ) + + model = PairTabModel(tab_file=file_path, rcut=rcut, sel=2) + results.append( + model.forward_atomic(extended_coord, extended_atype, nlist)["energy"] + ) + + expected_result = torch.stack( + [ + torch.tensor( + [ + [ + [0.4, 0], + [0.0, 0], + [0.0, 0], + [0.25, 0], + [0, 0], + [0, 0], + [0.25, 0], + [0.125, 0], + [0.0922, 0], + [0, 0], + [0, 0], + [0, 0], + [0.0923, 0], + [0.0713, 0], + ] + ] + ) + ] + ).reshape(14, 2) + results = torch.stack(results).reshape(14, 2) + + torch.testing.assert_allclose(results, expected_result, 0.0001, 0.0001) + + if __name__ == "__main__": + unittest.main() From 19a8dfbcb17d69110a414dee93be824a5af21e53 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 31 Jan 2024 09:27:32 -0500 Subject: [PATCH 2/3] pt: set nthreads from env (#3205) Signed-off-by: Jinzhe Zeng --- deepmd/env.py | 83 +++++++++++++++++++++++ deepmd/pt/utils/env.py | 13 ++++ deepmd/tf/env.py | 67 ++---------------- doc/troubleshooting/howtoset_num_nodes.md | 37 +++++++--- source/api_cc/include/common.h | 4 +- source/api_cc/src/common.cc | 19 +++++- source/tests/tf/test_env.py | 4 +- 7 files changed, 149 insertions(+), 78 deletions(-) diff --git a/deepmd/env.py b/deepmd/env.py index b1d4958ed8..1a8da63f8e 100644 --- a/deepmd/env.py +++ b/deepmd/env.py @@ -1,5 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import logging import os +from typing import ( + Tuple, +) import numpy as np @@ -26,3 +30,82 @@ "low. Please set precision with environmental variable " "DP_INTERFACE_PREC." % dp_float_prec ) + + +def set_env_if_empty(key: str, value: str, verbose: bool = True): + """Set environment variable only if it is empty. + + Parameters + ---------- + key : str + env variable name + value : str + env variable value + verbose : bool, optional + if True action will be logged, by default True + """ + if os.environ.get(key) is None: + os.environ[key] = value + if verbose: + logging.warning( + f"Environment variable {key} is empty. Use the default value {value}" + ) + + +def set_default_nthreads(): + """Set internal number of threads to default=automatic selection. + + Notes + ----- + `DP_INTRA_OP_PARALLELISM_THREADS` and `DP_INTER_OP_PARALLELISM_THREADS` + control configuration of multithreading. + """ + if ( + "OMP_NUM_THREADS" not in os.environ + # for backward compatibility + or ( + "DP_INTRA_OP_PARALLELISM_THREADS" not in os.environ + and "TF_INTRA_OP_PARALLELISM_THREADS" not in os.environ + ) + or ( + "DP_INTER_OP_PARALLELISM_THREADS" not in os.environ + and "TF_INTER_OP_PARALLELISM_THREADS" not in os.environ + ) + ): + logging.warning( + "To get the best performance, it is recommended to adjust " + "the number of threads by setting the environment variables " + "OMP_NUM_THREADS, DP_INTRA_OP_PARALLELISM_THREADS, and " + "DP_INTER_OP_PARALLELISM_THREADS. See " + "https://deepmd.rtfd.io/parallelism/ for more information." + ) + if "TF_INTRA_OP_PARALLELISM_THREADS" not in os.environ: + set_env_if_empty("DP_INTRA_OP_PARALLELISM_THREADS", "0", verbose=False) + if "TF_INTER_OP_PARALLELISM_THREADS" not in os.environ: + set_env_if_empty("DP_INTER_OP_PARALLELISM_THREADS", "0", verbose=False) + + +def get_default_nthreads() -> Tuple[int, int]: + """Get paralellism settings. + + The method will first read the environment variables with the prefix `DP_`. + If not found, it will read the environment variables with the prefix `TF_` + for backward compatibility. + + Returns + ------- + Tuple[int, int] + number of `DP_INTRA_OP_PARALLELISM_THREADS` and + `DP_INTER_OP_PARALLELISM_THREADS` + """ + return int( + os.environ.get( + "DP_INTRA_OP_PARALLELISM_THREADS", + os.environ.get("TF_INTRA_OP_PARALLELISM_THREADS", "0"), + ) + ), int( + os.environ.get( + "DP_INTER_OP_PARALLELISM_THREADS", + os.environ.get("TF_INTRA_OP_PARALLELISM_THREADS", "0"), + ) + ) diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index 559dba0167..b51b03fdc2 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -4,6 +4,11 @@ import numpy as np import torch +from deepmd.env import ( + get_default_nthreads, + set_default_nthreads, +) + PRECISION = os.environ.get("PRECISION", "float64") GLOBAL_NP_FLOAT_PRECISION = getattr(np, PRECISION) GLOBAL_PT_FLOAT_PRECISION = getattr(torch, PRECISION) @@ -37,3 +42,11 @@ "double": torch.float64, } DEFAULT_PRECISION = "float64" + +# throw warnings if threads not set +set_default_nthreads() +inter_nthreads, intra_nthreads = get_default_nthreads() +if inter_nthreads > 0: # the behavior of 0 is not documented + torch.set_num_interop_threads(inter_nthreads) +if intra_nthreads > 0: + torch.set_num_threads(intra_nthreads) diff --git a/deepmd/tf/env.py b/deepmd/tf/env.py index 993768c4a4..6bc89664c7 100644 --- a/deepmd/tf/env.py +++ b/deepmd/tf/env.py @@ -2,7 +2,6 @@ """Module that sets tensorflow working environment and exports inportant constants.""" import ctypes -import logging import os import platform from configparser import ( @@ -19,7 +18,6 @@ TYPE_CHECKING, Any, Dict, - Tuple, ) import numpy as np @@ -31,8 +29,15 @@ from deepmd.env import ( GLOBAL_ENER_FLOAT_PRECISION, GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.env import get_default_nthreads as get_tf_default_nthreads +from deepmd.env import ( global_float_prec, ) +from deepmd.env import set_default_nthreads as set_tf_default_nthreads +from deepmd.env import ( + set_env_if_empty, +) if TYPE_CHECKING: from types import ( @@ -216,26 +221,6 @@ def dlopen_library(module: str, filename: str): } -def set_env_if_empty(key: str, value: str, verbose: bool = True): - """Set environment variable only if it is empty. - - Parameters - ---------- - key : str - env variable name - value : str - env variable value - verbose : bool, optional - if True action will be logged, by default True - """ - if os.environ.get(key) is None: - os.environ[key] = value - if verbose: - logging.warning( - f"Environment variable {key} is empty. Use the default value {value}" - ) - - def set_mkl(): """Tuning MKL for the best performance. @@ -270,44 +255,6 @@ def set_mkl(): reload(np) -def set_tf_default_nthreads(): - """Set TF internal number of threads to default=automatic selection. - - Notes - ----- - `TF_INTRA_OP_PARALLELISM_THREADS` and `TF_INTER_OP_PARALLELISM_THREADS` - control TF configuration of multithreading. - """ - if ( - "OMP_NUM_THREADS" not in os.environ - or "TF_INTRA_OP_PARALLELISM_THREADS" not in os.environ - or "TF_INTER_OP_PARALLELISM_THREADS" not in os.environ - ): - logging.warning( - "To get the best performance, it is recommended to adjust " - "the number of threads by setting the environment variables " - "OMP_NUM_THREADS, TF_INTRA_OP_PARALLELISM_THREADS, and " - "TF_INTER_OP_PARALLELISM_THREADS. See " - "https://deepmd.rtfd.io/parallelism/ for more information." - ) - set_env_if_empty("TF_INTRA_OP_PARALLELISM_THREADS", "0", verbose=False) - set_env_if_empty("TF_INTER_OP_PARALLELISM_THREADS", "0", verbose=False) - - -def get_tf_default_nthreads() -> Tuple[int, int]: - """Get TF paralellism settings. - - Returns - ------- - Tuple[int, int] - number of `TF_INTRA_OP_PARALLELISM_THREADS` and - `TF_INTER_OP_PARALLELISM_THREADS` - """ - return int(os.environ.get("TF_INTRA_OP_PARALLELISM_THREADS", "0")), int( - os.environ.get("TF_INTER_OP_PARALLELISM_THREADS", "0") - ) - - def get_tf_session_config() -> Any: """Configure tensorflow session. diff --git a/doc/troubleshooting/howtoset_num_nodes.md b/doc/troubleshooting/howtoset_num_nodes.md index 8a9beab857..18b1a133ee 100644 --- a/doc/troubleshooting/howtoset_num_nodes.md +++ b/doc/troubleshooting/howtoset_num_nodes.md @@ -22,10 +22,10 @@ Sometimes, `$num_nodes` and the nodes information can be directly given by the H ## Parallelism between independent operators -For CPU devices, TensorFlow use multiple streams to run independent operators (OP). +For CPU devices, TensorFlow and PyTorch use multiple streams to run independent operators (OP). ```bash -export TF_INTER_OP_PARALLELISM_THREADS=3 +export DP_INTER_OP_PARALLELISM_THREADS=3 ``` However, for GPU devices, TensorFlow uses only one compute stream and multiple copy streams. @@ -33,20 +33,35 @@ Note that some of DeePMD-kit OPs do not have GPU support, so it is still encoura ## Parallelism within an individual operators -For CPU devices, `TF_INTRA_OP_PARALLELISM_THREADS` controls parallelism within TensorFlow native OPs when TensorFlow is built against Eigen. +For CPU devices, `DP_INTRA_OP_PARALLELISM_THREADS` controls parallelism within TensorFlow (when TensorFlow is built against Eigen) and PyTorch native OPs. ```bash -export TF_INTRA_OP_PARALLELISM_THREADS=2 +export DP_INTRA_OP_PARALLELISM_THREADS=2 ``` -`OMP_NUM_THREADS` is threads for OpenMP parallelism. It controls parallelism within TensorFlow native OPs when TensorFlow is built by Intel OneDNN and DeePMD-kit custom CPU OPs. -It may also control parallelsim for NumPy when NumPy is built against OpenMP, so one who uses GPUs for training should also care this environmental variable. +`OMP_NUM_THREADS` is the number of threads for OpenMP parallelism. +It controls parallelism within TensorFlow (when TensorFlow is built upon Intel OneDNN) and PyTorch (when PyTorch is built upon OpenMP) native OPs and DeePMD-kit custom CPU OPs. +It may also control parallelism for NumPy when NumPy is built against OpenMP, so one who uses GPUs for training should also care this environmental variable. ```bash export OMP_NUM_THREADS=2 ``` -There are several other environmental variables for OpenMP, such as `KMP_BLOCKTIME`. See [Intel documentation](https://www.intel.com/content/www/us/en/developer/articles/technical/maximize-tensorflow-performance-on-cpu-considerations-and-recommendations-for-inference.html) for detailed information. +There are several other environmental variables for OpenMP, such as `KMP_BLOCKTIME`. + +::::{tab-set} + +:::{tab-item} TensorFlow {{ tensorflow_icon }} + +See [Intel documentation](https://www.intel.com/content/www/us/en/developer/articles/technical/maximize-tensorflow-performance-on-cpu-considerations-and-recommendations-for-inference.html) for detailed information. + +::: +:::{tab-item} PyTorch {{ pytorch_icon }} + +See [PyTorch documentation](https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html) for detailed information. + +::: +:::: ## Tune the performance @@ -56,8 +71,8 @@ Here are some empirical examples. If you wish to use 3 cores of 2 CPUs on one node, you may set the environmental variables and run DeePMD-kit as follows: ```bash export OMP_NUM_THREADS=3 -export TF_INTRA_OP_PARALLELISM_THREADS=3 -export TF_INTER_OP_PARALLELISM_THREADS=2 +export DP_INTRA_OP_PARALLELISM_THREADS=3 +export DP_INTER_OP_PARALLELISM_THREADS=2 dp train input.json ``` @@ -65,8 +80,8 @@ For a node with 128 cores, it is recommended to start with the following variabl ```bash export OMP_NUM_THREADS=16 -export TF_INTRA_OP_PARALLELISM_THREADS=16 -export TF_INTER_OP_PARALLELISM_THREADS=8 +export DP_INTRA_OP_PARALLELISM_THREADS=16 +export DP_INTER_OP_PARALLELISM_THREADS=8 ``` Again, in general, one should make sure the product of the parallel numbers is less than or equal to the number of cores available. diff --git a/source/api_cc/include/common.h b/source/api_cc/include/common.h index 0392747979..72382169f8 100644 --- a/source/api_cc/include/common.h +++ b/source/api_cc/include/common.h @@ -144,9 +144,9 @@ void select_map_inv(typename std::vector::iterator out, * @brief Get the number of threads from the environment variable. * @details A warning will be thrown if environmental variables are not set. * @param[out] num_intra_nthreads The number of intra threads. Read from - *TF_INTRA_OP_PARALLELISM_THREADS. + *DP_INTRA_OP_PARALLELISM_THREADS. * @param[out] num_inter_nthreads The number of inter threads. Read from - *TF_INTER_OP_PARALLELISM_THREADS. + *DP_INTER_OP_PARALLELISM_THREADS. **/ void get_env_nthreads(int& num_intra_nthreads, int& num_inter_nthreads); diff --git a/source/api_cc/src/common.cc b/source/api_cc/src/common.cc index 2923534fb7..d2923c8d9e 100644 --- a/source/api_cc/src/common.cc +++ b/source/api_cc/src/common.cc @@ -330,23 +330,36 @@ void deepmd::get_env_nthreads(int& num_intra_nthreads, num_intra_nthreads = 0; num_inter_nthreads = 0; const char* env_intra_nthreads = - std::getenv("TF_INTRA_OP_PARALLELISM_THREADS"); + std::getenv("DP_INTRA_OP_PARALLELISM_THREADS"); const char* env_inter_nthreads = + std::getenv("DP_INTER_OP_PARALLELISM_THREADS"); + // backward compatibility + const char* env_intra_nthreads_tf = + std::getenv("TF_INTRA_OP_PARALLELISM_THREADS"); + const char* env_inter_nthreads_tf = std::getenv("TF_INTER_OP_PARALLELISM_THREADS"); const char* env_omp_nthreads = std::getenv("OMP_NUM_THREADS"); if (env_intra_nthreads && std::string(env_intra_nthreads) != std::string("") && atoi(env_intra_nthreads) >= 0) { num_intra_nthreads = atoi(env_intra_nthreads); + } else if (env_intra_nthreads_tf && + std::string(env_intra_nthreads_tf) != std::string("") && + atoi(env_intra_nthreads_tf) >= 0) { + num_intra_nthreads = atoi(env_intra_nthreads_tf); } else { - throw_env_not_set_warning("TF_INTRA_OP_PARALLELISM_THREADS"); + throw_env_not_set_warning("DP_INTRA_OP_PARALLELISM_THREADS"); } if (env_inter_nthreads && std::string(env_inter_nthreads) != std::string("") && atoi(env_inter_nthreads) >= 0) { num_inter_nthreads = atoi(env_inter_nthreads); + } else if (env_inter_nthreads_tf && + std::string(env_inter_nthreads_tf) != std::string("") && + atoi(env_inter_nthreads_tf) >= 0) { + num_inter_nthreads = atoi(env_inter_nthreads_tf); } else { - throw_env_not_set_warning("TF_INTER_OP_PARALLELISM_THREADS"); + throw_env_not_set_warning("DP_INTER_OP_PARALLELISM_THREADS"); } if (!(env_omp_nthreads && std::string(env_omp_nthreads) != std::string("") && atoi(env_omp_nthreads) >= 0)) { diff --git a/source/tests/tf/test_env.py b/source/tests/tf/test_env.py index eb1b40e707..cd066b06a5 100644 --- a/source/tests/tf/test_env.py +++ b/source/tests/tf/test_env.py @@ -19,8 +19,8 @@ def test_empty(self): @mock.patch.dict( "os.environ", values={ - "TF_INTRA_OP_PARALLELISM_THREADS": "5", - "TF_INTER_OP_PARALLELISM_THREADS": "3", + "DP_INTRA_OP_PARALLELISM_THREADS": "5", + "DP_INTER_OP_PARALLELISM_THREADS": "3", }, ) def test_given(self): From 032fa7d1f0f87a3b33c6913e41567a7a8a908cbf Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 31 Jan 2024 09:27:49 -0500 Subject: [PATCH 3/3] pt: add tensorboard and profiler support (#3204) Use the same arguments as TF. [PyTorch on Tensorboard](https://pytorch.org/docs/stable/tensorboard.html): ![1706608497314](https://github.com/deepmodeling/deepmd-kit/assets/9496702/9d747ee2-2e76-43d3-8252-7dbd0cea6768) [PyTorch Profiler on Tensorboard](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html): ![image](https://github.com/deepmodeling/deepmd-kit/assets/9496702/929d69b7-a696-45b1-8e9b-2b491177ad95) --------- Signed-off-by: Jinzhe Zeng --- deepmd/pt/train/training.py | 36 ++++++++++++++++++++++++++++++++++++ deepmd/utils/argcheck.py | 2 +- doc/train/tensorboard.md | 4 ++-- 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index e4c672765b..ee0e7a54cc 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -438,6 +438,12 @@ def warm_up_linear(step, warmup_steps): assert sum_prob > 0.0, "Sum of model prob must be larger than 0!" self.model_prob = self.model_prob / sum_prob + # Tensorboard + self.enable_tensorboard = training_params.get("tensorboard", False) + self.tensorboard_log_dir = training_params.get("tensorboard_log_dir", "log") + self.tensorboard_freq = training_params.get("tensorboard_freq", 1) + self.enable_profiler = training_params.get("enable_profiler", False) + def run(self): fout = ( open(self.disp_file, mode="w", buffering=1) if self.rank == 0 else None @@ -448,8 +454,27 @@ def run(self): logging.info("Start to train %d steps.", self.num_steps) if dist.is_initialized(): logging.info(f"Rank: {dist.get_rank()}/{dist.get_world_size()}") + if self.enable_tensorboard: + from torch.utils.tensorboard import ( + SummaryWriter, + ) + + writer = SummaryWriter(log_dir=self.tensorboard_log_dir) + if self.enable_profiler: + prof = torch.profiler.profile( + schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler( + self.tensorboard_log_dir + ), + record_shapes=True, + with_stack=True, + ) + prof.start() def step(_step_id, task_key="Default"): + # PyTorch Profiler + if self.enable_profiler: + prof.step() self.wrapper.train() if isinstance(self.lr_exp, dict): _lr = self.lr_exp[task_key] @@ -654,6 +679,13 @@ def log_loss_valid(_task_key="Default"): with open("checkpoint", "w") as f: f.write(str(self.latest_model)) + # tensorboard + if self.enable_tensorboard and _step_id % self.tensorboard_freq == 0: + writer.add_scalar(f"{task_key}/lr", cur_lr, _step_id) + writer.add_scalar(f"{task_key}/loss", loss, _step_id) + for item in more_loss: + writer.add_scalar(f"{task_key}/{item}", more_loss[item], _step_id) + self.t0 = time.time() for step_id in range(self.num_steps): if step_id < self.start_step: @@ -691,6 +723,10 @@ def log_loss_valid(_task_key="Default"): fout.close() if SAMPLER_RECORD: fout1.close() + if self.enable_tensorboard: + writer.close() + if self.enable_profiler: + prof.stop() def save_model(self, save_path, lr=0.0, step=0): module = self.wrapper.module if dist.is_initialized() else self.wrapper diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 31b54b4d76..dbe4881952 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1703,7 +1703,7 @@ def training_args(): # ! modified by Ziyao: data configuration isolated. doc_time_training = "Timing durining training." doc_profiling = "Profiling during training." doc_profiling_file = "Output file for profiling." - doc_enable_profiler = "Enable TensorFlow Profiler (available in TensorFlow 2.3) to analyze performance. The log will be saved to `tensorboard_log_dir`." + doc_enable_profiler = "Enable TensorFlow Profiler (available in TensorFlow 2.3) or PyTorch Profiler to analyze performance. The log will be saved to `tensorboard_log_dir`." doc_tensorboard = "Enable tensorboard" doc_tensorboard_log_dir = "The log directory of tensorboard outputs" doc_tensorboard_freq = "The frequency of writing tensorboard events." diff --git a/doc/train/tensorboard.md b/doc/train/tensorboard.md index 1d6c5f0d68..a6cfdccb68 100644 --- a/doc/train/tensorboard.md +++ b/doc/train/tensorboard.md @@ -1,7 +1,7 @@ -# TensorBoard Usage {{ tensorflow_icon }} +# TensorBoard Usage {{ tensorflow_icon }} {{ pytorch_icon }} :::{note} -**Supported backends**: TensorFlow {{ tensorflow_icon }} +**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }} ::: TensorBoard provides the visualization and tooling needed for machine learning