From c12bc0108cef70ab9eb6b8350cc8b4b7a31ad9f8 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 9 Nov 2024 03:42:28 -0500 Subject: [PATCH] feat(pt): calculate stat during compression if `--skip-neighbor-stat` (#4330) If `--skip-neighbor-stat` is set during training, when calling `dp compress`, first calculate the neighbor stat. ## Summary by CodeRabbit - **New Features** - Enhanced `enable_compression` function to accept a `training_script` parameter for improved error handling and functionality. - Updated the `compress` command to allow specification of a training script during execution. - Introduced a new testing framework for models using the `--skip-neighbor-stat` flag, validating their functionality. - **Bug Fixes** - Improved error handling for cases where the model's minimum neighbor distance is not saved. - **Tests** - Added a new test class and methods to validate the functionality of models initialized with skip neighbor statistics. Signed-off-by: Jinzhe Zeng --- deepmd/pt/entrypoints/compress.py | 53 ++++++ deepmd/pt/entrypoints/main.py | 1 + .../tests/pt/test_model_compression_se_a.py | 159 +++++++++++++++++- 3 files changed, 212 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/entrypoints/compress.py b/deepmd/pt/entrypoints/compress.py index 1042af3335..d94a34215c 100644 --- a/deepmd/pt/entrypoints/compress.py +++ b/deepmd/pt/entrypoints/compress.py @@ -1,11 +1,32 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import json +import logging +from typing import ( + Optional, +) import torch +from deepmd.common import ( + j_loader, +) from deepmd.pt.model.model import ( get_model, ) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.update_sel import ( + UpdateSel, +) +from deepmd.utils.compat import ( + update_deepmd_input, +) +from deepmd.utils.data_system import ( + get_data, +) + +log = logging.getLogger(__name__) def enable_compression( @@ -14,12 +35,44 @@ def enable_compression( stride: float = 0.01, extrapolate: int = 5, check_frequency: int = -1, + training_script: Optional[str] = None, ): saved_model = torch.jit.load(input_file, map_location="cpu") model_def_script = json.loads(saved_model.model_def_script) model = get_model(model_def_script) model.load_state_dict(saved_model.state_dict()) + if model.get_min_nbor_dist() is None: + log.info( + "Minimal neighbor distance is not saved in the model, compute it from the training data." + ) + if training_script is None: + raise ValueError( + "The model does not have a minimum neighbor distance, " + "so the training script and data must be provided " + "(via -t,--training-script)." + ) + + jdata = j_loader(training_script) + jdata = update_deepmd_input(jdata) + + type_map = jdata["model"].get("type_map", None) + train_data = get_data( + jdata["training"]["training_data"], + 0, # not used + type_map, + None, + ) + update_sel = UpdateSel() + t_min_nbor_dist = update_sel.get_min_nbor_dist( + train_data, + ) + model.min_nbor_dist = torch.tensor( + t_min_nbor_dist, + dtype=env.GLOBAL_PT_FLOAT_PRECISION, + device=env.DEVICE, + ) + model.enable_compression( extrapolate, stride, diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 7daa29d0f9..fe85a3301c 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -565,6 +565,7 @@ def main(args: Optional[Union[list[str], argparse.Namespace]] = None): stride=FLAGS.step, extrapolate=FLAGS.extrapolate, check_frequency=FLAGS.frequency, + training_script=FLAGS.training_script, ) else: raise RuntimeError(f"Invalid command {FLAGS.command}!") diff --git a/source/tests/pt/test_model_compression_se_a.py b/source/tests/pt/test_model_compression_se_a.py index 0e7bf0b69a..3738b61c13 100644 --- a/source/tests/pt/test_model_compression_se_a.py +++ b/source/tests/pt/test_model_compression_se_a.py @@ -74,6 +74,34 @@ def _init_models_exclude_types(): return INPUT, frozen_model, compressed_model +def _init_models_skip_neighbor_stat(): + suffix = "-skip-neighbor-stat" + data_file = str(tests_path / os.path.join("model_compression", "data")) + frozen_model = str(tests_path / f"dp-original{suffix}.pth") + compressed_model = str(tests_path / f"dp-compressed{suffix}.pth") + INPUT = str(tests_path / "input.json") + jdata = j_loader(str(tests_path / os.path.join("model_compression", "input.json"))) + jdata["training"]["training_data"]["systems"] = data_file + with open(INPUT, "w") as fp: + json.dump(jdata, fp, indent=4) + + ret = run_dp("dp --pt train " + INPUT + " --skip-neighbor-stat") + np.testing.assert_equal(ret, 0, "DP train failed!") + ret = run_dp("dp --pt freeze -o " + frozen_model) + np.testing.assert_equal(ret, 0, "DP freeze failed!") + ret = run_dp( + "dp --pt compress " + + " -i " + + frozen_model + + " -o " + + compressed_model + + " -t " + + INPUT + ) + np.testing.assert_equal(ret, 0, "DP model compression failed!") + return INPUT, frozen_model, compressed_model + + def setUpModule(): global \ INPUT, \ @@ -81,8 +109,13 @@ def setUpModule(): COMPRESSED_MODEL, \ INPUT_ET, \ FROZEN_MODEL_ET, \ - COMPRESSED_MODEL_ET + COMPRESSED_MODEL_ET, \ + FROZEN_MODEL_SKIP_NEIGHBOR_STAT, \ + COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT INPUT, FROZEN_MODEL, COMPRESSED_MODEL = _init_models() + _, FROZEN_MODEL_SKIP_NEIGHBOR_STAT, COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT = ( + _init_models_skip_neighbor_stat() + ) INPUT_ET, FROZEN_MODEL_ET, COMPRESSED_MODEL_ET = _init_models_exclude_types() @@ -572,5 +605,129 @@ def test_2frame_atm(self): np.testing.assert_almost_equal(vv0, vv1, default_places) +class TestSkipNeighborStat(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.dp_original = DeepEval(FROZEN_MODEL_SKIP_NEIGHBOR_STAT) + cls.dp_compressed = DeepEval(COMPRESSED_MODEL_SKIP_NEIGHBOR_STAT) + cls.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ] + ) + cls.atype = [0, 1, 1, 0, 1, 1] + cls.box = np.array([13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0]) + + def test_attrs(self): + self.assertEqual(self.dp_original.get_ntypes(), 2) + self.assertAlmostEqual(self.dp_original.get_rcut(), 6.0, places=default_places) + self.assertEqual(self.dp_original.get_type_map(), ["O", "H"]) + self.assertEqual(self.dp_original.get_dim_fparam(), 0) + self.assertEqual(self.dp_original.get_dim_aparam(), 0) + + self.assertEqual(self.dp_compressed.get_ntypes(), 2) + self.assertAlmostEqual( + self.dp_compressed.get_rcut(), 6.0, places=default_places + ) + self.assertEqual(self.dp_compressed.get_type_map(), ["O", "H"]) + self.assertEqual(self.dp_compressed.get_dim_fparam(), 0) + self.assertEqual(self.dp_compressed.get_dim_aparam(), 0) + + def test_1frame(self): + ee0, ff0, vv0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=False + ) + ee1, ff1, vv1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=False + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_1frame_atm(self): + ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( + self.coords, self.box, self.atype, atomic=True + ) + ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( + self.coords, self.box, self.atype, atomic=True + ) + # check shape of the returns + nframes = 1 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ae0.shape, (nframes, natoms, 1)) + self.assertEqual(av0.shape, (nframes, natoms, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + self.assertEqual(ae1.shape, (nframes, natoms, 1)) + self.assertEqual(av1.shape, (nframes, natoms, 9)) + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ae0, ae1, default_places) + np.testing.assert_almost_equal(av0, av1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + def test_2frame_atm(self): + coords2 = np.concatenate((self.coords, self.coords)) + box2 = np.concatenate((self.box, self.box)) + ee0, ff0, vv0, ae0, av0 = self.dp_original.eval( + coords2, box2, self.atype, atomic=True + ) + ee1, ff1, vv1, ae1, av1 = self.dp_compressed.eval( + coords2, box2, self.atype, atomic=True + ) + # check shape of the returns + nframes = 2 + natoms = len(self.atype) + self.assertEqual(ee0.shape, (nframes, 1)) + self.assertEqual(ff0.shape, (nframes, natoms, 3)) + self.assertEqual(vv0.shape, (nframes, 9)) + self.assertEqual(ae0.shape, (nframes, natoms, 1)) + self.assertEqual(av0.shape, (nframes, natoms, 9)) + self.assertEqual(ee1.shape, (nframes, 1)) + self.assertEqual(ff1.shape, (nframes, natoms, 3)) + self.assertEqual(vv1.shape, (nframes, 9)) + self.assertEqual(ae1.shape, (nframes, natoms, 1)) + self.assertEqual(av1.shape, (nframes, natoms, 9)) + + # check values + np.testing.assert_almost_equal(ff0, ff1, default_places) + np.testing.assert_almost_equal(ae0, ae1, default_places) + np.testing.assert_almost_equal(av0, av1, default_places) + np.testing.assert_almost_equal(ee0, ee1, default_places) + np.testing.assert_almost_equal(vv0, vv1, default_places) + + if __name__ == "__main__": unittest.main()