From 38815b371162a2153b0c2b24a38867825665c7a3 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 4 Nov 2024 15:15:03 -0500 Subject: [PATCH] feat(jax): export call_lower to SavedModel via jax2tf (#4254) ## Summary by CodeRabbit ## Release Notes - **New Features** - Added support for the TensorFlow SavedModel format, allowing users to handle additional model file types. - Introduced a new TensorFlow model wrapper class for enhanced integration with JAX functionalities. - **Bug Fixes** - Improved error handling for unsupported file formats during model deserialization. - **Documentation** - Updated backend documentation to reflect new file extensions and clarify backend capabilities. - **Tests** - Enhanced test structure for better clarity and maintainability regarding backend handling. - Added a new job for testing TensorFlow 2 in eager mode within the testing workflow. - Introduced a conditional skip for tests based on TensorFlow 2 compatibility. --------- Signed-off-by: Jinzhe Zeng --- .github/workflows/test_python.yml | 18 +- deepmd/backend/jax.py | 2 +- deepmd/jax/infer/deep_eval.py | 27 ++- deepmd/jax/jax2tf/__init__.py | 11 + deepmd/jax/jax2tf/serialization.py | 172 ++++++++++++++ deepmd/jax/jax2tf/tfmodel.py | 325 ++++++++++++++++++++++++++ deepmd/jax/utils/serialization.py | 12 +- doc/backend.md | 3 +- pyproject.toml | 1 + source/tests/consistent/io/test_io.py | 18 +- source/tests/utils.py | 1 + 11 files changed, 568 insertions(+), 22 deletions(-) create mode 100644 deepmd/jax/jax2tf/__init__.py create mode 100644 deepmd/jax/jax2tf/serialization.py create mode 100644 deepmd/jax/jax2tf/tfmodel.py diff --git a/.github/workflows/test_python.yml b/.github/workflows/test_python.yml index e46bddd98a..422dcb5f17 100644 --- a/.github/workflows/test_python.yml +++ b/.github/workflows/test_python.yml @@ -25,19 +25,23 @@ jobs: python-version: ${{ matrix.python }} - run: python -m pip install -U uv - run: | - source/install/uv_with_retry.sh pip install --system mpich + source/install/uv_with_retry.sh pip install --system openmpi tensorflow-cpu source/install/uv_with_retry.sh pip install --system torch -i https://download.pytorch.org/whl/cpu + export TENSORFLOW_ROOT=$(python -c 'import tensorflow;print(tensorflow.__path__[0])') export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])') - source/install/uv_with_retry.sh pip install --system --only-binary=horovod -e .[cpu,test,jax] horovod[tensorflow-cpu] mpi4py + source/install/uv_with_retry.sh pip install --system -e .[test,jax] mpi4py + source/install/uv_with_retry.sh pip install --system horovod --no-build-isolation env: # Please note that uv has some issues with finding # existing TensorFlow package. Currently, it uses # TensorFlow in the build dependency, but if it # changes, setting `TENSORFLOW_ROOT`. - TENSORFLOW_VERSION: 2.16.1 DP_ENABLE_PYTORCH: 1 DP_BUILD_TESTING: 1 - UV_EXTRA_INDEX_URL: "https://pypi.anaconda.org/njzjz/simple https://pypi.anaconda.org/mpi4py/simple" + UV_EXTRA_INDEX_URL: "https://pypi.anaconda.org/mpi4py/simple" + HOROVOD_WITH_TENSORFLOW: 1 + HOROVOD_WITHOUT_PYTORCH: 1 + HOROVOD_WITH_MPI: 1 - run: dp --version - name: Get durations from cache uses: actions/cache@v4 @@ -53,6 +57,12 @@ jobs: - run: pytest --cov=deepmd source/tests --durations=0 --splits 6 --group ${{ matrix.group }} --store-durations --durations-path=.test_durations --splitting-algorithm least_duration env: NUM_WORKERS: 0 + - name: Test TF2 eager mode + run: pytest --cov=deepmd source/tests/consistent/io/test_io.py --durations=0 + env: + NUM_WORKERS: 0 + DP_TEST_TF2_ONLY: 1 + if: matrix.group == 1 - run: mv .test_durations .test_durations_${{ matrix.group }} - name: Upload partial durations uses: actions/upload-artifact@v4 diff --git a/deepmd/backend/jax.py b/deepmd/backend/jax.py index cfb0936bda..7a714c2090 100644 --- a/deepmd/backend/jax.py +++ b/deepmd/backend/jax.py @@ -38,7 +38,7 @@ class JAXBackend(Backend): | Backend.Feature.NEIGHBOR_STAT ) """The features of the backend.""" - suffixes: ClassVar[list[str]] = [".hlo", ".jax"] + suffixes: ClassVar[list[str]] = [".hlo", ".jax", ".savedmodel"] """The suffixes of the backend.""" def is_available(self) -> bool: diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index b60076c68c..fc526a502e 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -90,15 +90,24 @@ def __init__( self.output_def = output_def self.model_path = model_file - model_data = load_dp_model(model_file) - self.dp = HLO( - stablehlo=model_data["@variables"]["stablehlo"].tobytes(), - stablehlo_atomic_virial=model_data["@variables"][ - "stablehlo_atomic_virial" - ].tobytes(), - model_def_script=model_data["model_def_script"], - **model_data["constants"], - ) + if model_file.endswith(".hlo"): + model_data = load_dp_model(model_file) + self.dp = HLO( + stablehlo=model_data["@variables"]["stablehlo"].tobytes(), + stablehlo_atomic_virial=model_data["@variables"][ + "stablehlo_atomic_virial" + ].tobytes(), + model_def_script=model_data["model_def_script"], + **model_data["constants"], + ) + elif model_file.endswith(".savedmodel"): + from deepmd.jax.jax2tf.tfmodel import ( + TFModelWrapper, + ) + + self.dp = TFModelWrapper(model_file) + else: + raise ValueError("Unsupported file extension") self.rcut = self.dp.get_rcut() self.type_map = self.dp.get_type_map() if isinstance(auto_batch_size, bool): diff --git a/deepmd/jax/jax2tf/__init__.py b/deepmd/jax/jax2tf/__init__.py new file mode 100644 index 0000000000..88a928f04d --- /dev/null +++ b/deepmd/jax/jax2tf/__init__.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tensorflow as tf + +if not tf.executing_eagerly(): + # TF disallow temporary eager execution + raise RuntimeError( + "Unfortunatly, jax2tf (requires eager execution) cannot be used with the " + "TensorFlow backend (disables eager execution). " + "If you are converting a model between different backends, " + "considering converting to the `.dp` format first." + ) diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py new file mode 100644 index 0000000000..dff43a11fc --- /dev/null +++ b/deepmd/jax/jax2tf/serialization.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json + +import tensorflow as tf +from jax.experimental import ( + jax2tf, +) + +from deepmd.jax.model.base_model import ( + BaseModel, +) + + +def deserialize_to_file(model_file: str, data: dict) -> None: + """Deserialize the dictionary to a model file. + + Parameters + ---------- + model_file : str + The model file to be saved. + data : dict + The dictionary to be deserialized. + """ + if model_file.endswith(".savedmodel"): + model = BaseModel.deserialize(data["model"]) + model_def_script = data["model_def_script"] + call_lower = model.call_lower + + tf_model = tf.Module() + + def exported_whether_do_atomic_virial(do_atomic_virial): + def call_lower_with_fixed_do_atomic_virial( + coord, atype, nlist, mapping, fparam, aparam + ): + return call_lower( + coord, + atype, + nlist, + mapping, + fparam, + aparam, + do_atomic_virial=do_atomic_virial, + ) + + return jax2tf.convert( + call_lower_with_fixed_do_atomic_virial, + polymorphic_shapes=[ + "(nf, nloc + nghost, 3)", + "(nf, nloc + nghost)", + f"(nf, nloc, {model.get_nnei()})", + "(nf, nloc + nghost)", + f"(nf, {model.get_dim_fparam()})", + f"(nf, nloc, {model.get_dim_aparam()})", + ], + with_gradient=True, + ) + + # Save a function that can take scalar inputs. + # We need to explicit set the function name, so C++ can find it. + @tf.function( + autograph=False, + input_signature=[ + tf.TensorSpec([None, None, 3], tf.float64), + tf.TensorSpec([None, None], tf.int32), + tf.TensorSpec([None, None, model.get_nnei()], tf.int64), + tf.TensorSpec([None, None], tf.int64), + tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), + tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), + ], + ) + def call_lower_without_atomic_virial( + coord, atype, nlist, mapping, fparam, aparam + ): + return exported_whether_do_atomic_virial(do_atomic_virial=False)( + coord, atype, nlist, mapping, fparam, aparam + ) + + tf_model.call_lower = call_lower_without_atomic_virial + + @tf.function( + autograph=False, + input_signature=[ + tf.TensorSpec([None, None, 3], tf.float64), + tf.TensorSpec([None, None], tf.int32), + tf.TensorSpec([None, None, model.get_nnei()], tf.int64), + tf.TensorSpec([None, None], tf.int64), + tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), + tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), + ], + ) + def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam): + return exported_whether_do_atomic_virial(do_atomic_virial=True)( + coord, atype, nlist, mapping, fparam, aparam + ) + + tf_model.call_lower_atomic_virial = call_lower_with_atomic_virial + + # set functions to export other attributes + @tf.function + def get_type_map(): + return tf.constant(model.get_type_map(), dtype=tf.string) + + tf_model.get_type_map = get_type_map + + @tf.function + def get_rcut(): + return tf.constant(model.get_rcut(), dtype=tf.double) + + tf_model.get_rcut = get_rcut + + @tf.function + def get_dim_fparam(): + return tf.constant(model.get_dim_fparam(), dtype=tf.int64) + + tf_model.get_dim_fparam = get_dim_fparam + + @tf.function + def get_dim_aparam(): + return tf.constant(model.get_dim_aparam(), dtype=tf.int64) + + tf_model.get_dim_aparam = get_dim_aparam + + @tf.function + def get_sel_type(): + return tf.constant(model.get_sel_type(), dtype=tf.int64) + + tf_model.get_sel_type = get_sel_type + + @tf.function + def is_aparam_nall(): + return tf.constant(model.is_aparam_nall(), dtype=tf.bool) + + tf_model.is_aparam_nall = is_aparam_nall + + @tf.function + def model_output_type(): + return tf.constant(model.model_output_type(), dtype=tf.string) + + tf_model.model_output_type = model_output_type + + @tf.function + def mixed_types(): + return tf.constant(model.mixed_types(), dtype=tf.bool) + + tf_model.mixed_types = mixed_types + + if model.get_min_nbor_dist() is not None: + + @tf.function + def get_min_nbor_dist(): + return tf.constant(model.get_min_nbor_dist(), dtype=tf.double) + + tf_model.get_min_nbor_dist = get_min_nbor_dist + + @tf.function + def get_sel(): + return tf.constant(model.get_sel(), dtype=tf.int64) + + tf_model.get_sel = get_sel + + @tf.function + def get_model_def_script(): + return tf.constant( + json.dumps(model_def_script, separators=(",", ":")), dtype=tf.string + ) + + tf_model.get_model_def_script = get_model_def_script + tf.saved_model.save( + tf_model, + model_file, + options=tf.saved_model.SaveOptions(experimental_custom_gradients=True), + ) diff --git a/deepmd/jax/jax2tf/tfmodel.py b/deepmd/jax/jax2tf/tfmodel.py new file mode 100644 index 0000000000..8f04014a97 --- /dev/null +++ b/deepmd/jax/jax2tf/tfmodel.py @@ -0,0 +1,325 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + Optional, +) + +import jax.experimental.jax2tf as jax2tf +import tensorflow as tf + +from deepmd.dpmodel.model.make_model import ( + model_call_from_call_lower, +) +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + ModelOutputDef, + OutputVariableDef, +) +from deepmd.jax.env import ( + jnp, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) + +OUTPUT_DEFS = { + "energy": OutputVariableDef( + "energy", + shape=[1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + "mask": OutputVariableDef( + "mask", + shape=[1], + reducible=False, + r_differentiable=False, + c_differentiable=False, + ), +} + + +def decode_list_of_bytes(list_of_bytes: list[bytes]) -> list[str]: + """Decode a list of bytes to a list of strings.""" + return [x.decode() for x in list_of_bytes] + + +class TFModelWrapper(tf.Module): + def __init__( + self, + model, + ) -> None: + self.model = tf.saved_model.load(model) + self._call_lower = jax2tf.call_tf(self.model.call_lower) + self._call_lower_atomic_virial = jax2tf.call_tf( + self.model.call_lower_atomic_virial + ) + self.type_map = decode_list_of_bytes(self.model.get_type_map().numpy().tolist()) + self.rcut = self.model.get_rcut().numpy().item() + self.dim_fparam = self.model.get_dim_fparam().numpy().item() + self.dim_aparam = self.model.get_dim_aparam().numpy().item() + self.sel_type = self.model.get_sel_type().numpy().tolist() + self._is_aparam_nall = self.model.is_aparam_nall().numpy().item() + self._model_output_type = decode_list_of_bytes( + self.model.model_output_type().numpy().tolist() + ) + self._mixed_types = self.model.mixed_types().numpy().item() + if hasattr(self.model, "get_min_nbor_dist"): + self.min_nbor_dist = self.model.get_min_nbor_dist().numpy().item() + else: + self.min_nbor_dist = None + self.sel = self.model.get_sel().numpy().tolist() + self.model_def_script = self.model.get_model_def_script().numpy().decode() + + def __call__( + self, + coord: jnp.ndarray, + atype: jnp.ndarray, + box: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + do_atomic_virial: bool = False, + ) -> Any: + """Return model prediction. + + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + box + The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + If calculate the atomic virial. + + Returns + ------- + ret_dict + The result dict of type dict[str,jnp.ndarray]. + The keys are defined by the `ModelOutputDef`. + + """ + return self.call(coord, atype, box, fparam, aparam, do_atomic_virial) + + def call( + self, + coord: jnp.ndarray, + atype: jnp.ndarray, + box: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + do_atomic_virial: bool = False, + ): + """Return model prediction. + + Parameters + ---------- + coord + The coordinates of the atoms. + shape: nf x (nloc x 3) + atype + The type of atoms. shape: nf x nloc + box + The simulation box. shape: nf x 9 + fparam + frame parameter. nf x ndf + aparam + atomic parameter. nf x nloc x nda + do_atomic_virial + If calculate the atomic virial. + + Returns + ------- + ret_dict + The result dict of type dict[str,jnp.ndarray]. + The keys are defined by the `ModelOutputDef`. + + """ + return model_call_from_call_lower( + call_lower=self.call_lower, + rcut=self.get_rcut(), + sel=self.get_sel(), + mixed_types=self.mixed_types(), + model_output_def=self.model_output_def(), + coord=coord, + atype=atype, + box=box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + def model_output_def(self): + return ModelOutputDef( + FittingOutputDef([OUTPUT_DEFS[tt] for tt in self.model_output_type()]) + ) + + def call_lower( + self, + extended_coord: jnp.ndarray, + extended_atype: jnp.ndarray, + nlist: jnp.ndarray, + mapping: Optional[jnp.ndarray] = None, + fparam: Optional[jnp.ndarray] = None, + aparam: Optional[jnp.ndarray] = None, + do_atomic_virial: bool = False, + ): + if do_atomic_virial: + call_lower = self._call_lower_atomic_virial + else: + call_lower = self._call_lower + # Attempt to convert a value (None) with an unsupported type () to a Tensor. + if fparam is None: + fparam = jnp.empty( + (extended_coord.shape[0], self.get_dim_fparam()), dtype=jnp.float64 + ) + if aparam is None: + aparam = jnp.empty( + (extended_coord.shape[0], nlist.shape[1], self.get_dim_aparam()), + dtype=jnp.float64, + ) + return call_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam, + aparam, + ) + + def get_type_map(self) -> list[str]: + """Get the type map.""" + return self.type_map + + def get_rcut(self): + """Get the cut-off radius.""" + return self.rcut + + def get_dim_fparam(self): + """Get the number (dimension) of frame parameters of this atomic model.""" + return self.dim_fparam + + def get_dim_aparam(self): + """Get the number (dimension) of atomic parameters of this atomic model.""" + return self.dim_aparam + + def get_sel_type(self) -> list[int]: + """Get the selected atom types of this model. + + Only atoms with selected atom types have atomic contribution + to the result of the model. + If returning an empty list, all atom types are selected. + """ + return self.sel_type + + def is_aparam_nall(self) -> bool: + """Check whether the shape of atomic parameters is (nframes, nall, ndim). + + If False, the shape is (nframes, nloc, ndim). + """ + return self._is_aparam_nall + + def model_output_type(self) -> list[str]: + """Get the output type for the model.""" + return self._model_output_type + + def serialize(self) -> dict: + """Serialize the model. + + Returns + ------- + dict + The serialized data + """ + raise NotImplementedError("Not implemented") + + @classmethod + def deserialize(cls, data: dict) -> "TFModelWrapper": + """Deserialize the model. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + BaseModel + The deserialized model + """ + raise NotImplementedError("Not implemented") + + def get_model_def_script(self) -> str: + """Get the model definition script.""" + return self.model_def_script + + def get_min_nbor_dist(self) -> Optional[float]: + """Get the minimum distance between two atoms.""" + return self.min_nbor_dist + + def get_nnei(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + return self.get_nsel() + + def get_sel(self) -> list[int]: + return self.sel + + def get_nsel(self) -> int: + """Returns the total number of selected neighboring atoms in the cut-off radius.""" + return sum(self.sel) + + def mixed_types(self) -> bool: + return self._mixed_types + + @classmethod + def update_sel( + cls, + train_data: DeepmdDataSystem, + type_map: Optional[list[str]], + local_jdata: dict, + ) -> tuple[dict, Optional[float]]: + """Update the selection and perform neighbor statistics. + + Parameters + ---------- + train_data : DeepmdDataSystem + data used to do neighbor statictics + type_map : list[str], optional + The name of each type of atoms + local_jdata : dict + The local data refer to the current class + + Returns + ------- + dict + The updated local data + float + The minimum distance between two atoms + """ + raise NotImplementedError("Not implemented") + + @classmethod + def get_model(cls, model_params: dict) -> "TFModelWrapper": + """Get the model by the parameters. + + By default, all the parameters are directly passed to the constructor. + If not, override this method. + + Parameters + ---------- + model_params : dict + The model parameters + + Returns + ------- + BaseBaseModel + The model + """ + raise NotImplementedError("Not implemented") diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index ec2de3060e..6ab99a81f0 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -55,13 +55,13 @@ def deserialize_to_file(model_file: str, data: dict) -> None: def exported_whether_do_atomic_virial(do_atomic_virial): def call_lower_with_fixed_do_atomic_virial( - coord, atype, nlist, nlist_start, fparam, aparam + coord, atype, nlist, mapping, fparam, aparam ): return call_lower( coord, atype, nlist, - nlist_start, + mapping, fparam, aparam, do_atomic_virial=do_atomic_virial, @@ -107,8 +107,14 @@ def call_lower_with_fixed_do_atomic_virial( "sel": model.get_sel(), } save_dp_model(filename=model_file, model_dict=data) + elif model_file.endswith(".savedmodel"): + from deepmd.jax.jax2tf.serialization import ( + deserialize_to_file as deserialize_to_savedmodel, + ) + + return deserialize_to_savedmodel(model_file, data) else: - raise ValueError("JAX backend only supports converting .jax directory") + raise ValueError("Unsupported file extension") def serialize_from_file(model_file: str) -> dict: diff --git a/doc/backend.md b/doc/backend.md index cf99eea9cb..3fb70bee90 100644 --- a/doc/backend.md +++ b/doc/backend.md @@ -25,11 +25,12 @@ While `.pth` and `.pt` are the same in the PyTorch package, they have different ### JAX {{ jax_icon }} -- Model filename extension: `.xlo` +- Model filename extension: `.xlo`, `.savedmodel` - Checkpoint filename extension: `.jax` [JAX](https://jax.readthedocs.io/) 0.4.33 (which requires Python 3.10 or above) or above is required. Both `.xlo` and `.jax` are customized format extensions defined in DeePMD-kit, since JAX has no convention for file extensions. +`.savedmodel` is the TensorFlow [SavedModel format](https://www.tensorflow.org/guide/saved_model) generated by [JAX2TF](https://www.tensorflow.org/guide/jax2tf), which needs the installation of TensorFlow. Currently, this backend is developed actively, and has no support for training and the C++ interface. ### DP {{ dpmodel_icon }} diff --git a/pyproject.toml b/pyproject.toml index 1faacb973c..802e920014 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -444,6 +444,7 @@ select = [ [tool.uv.sources] mpich = { index = "mpi4py" } +openmpi = { index = "mpi4py" } [[tool.uv.index]] name = "mpi4py" diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index 91cd391322..ca213da13c 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -23,6 +23,7 @@ from ...utils import ( CI, + DP_TEST_TF2_ONLY, TEST_DEVICE, ) @@ -72,6 +73,7 @@ def tearDown(self): shutil.rmtree(ii) @unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.") + @unittest.skipIf(DP_TEST_TF2_ONLY, "Conflict with TF2 eager mode.") def test_data_equal(self): prefix = "test_consistent_io_" + self.__class__.__name__.lower() for backend_name, suffix_idx in ( @@ -140,13 +142,21 @@ def test_deep_eval(self): nframes = self.atype.shape[0] prefix = "test_consistent_io_" + self.__class__.__name__.lower() rets = [] - for backend_name in ("tensorflow", "pytorch", "dpmodel", "jax"): + for backend_name, suffix_idx in ( + # unfortunately, jax2tf cannot work with tf v1 behaviors + ("jax", 2) if DP_TEST_TF2_ONLY else ("tensorflow", 0), + ("pytorch", 0), + ("dpmodel", 0), + ("jax", 0), + ): backend = Backend.get_backend(backend_name)() if not backend.is_available(): continue reference_data = copy.deepcopy(self.data) - self.save_data_to_model(prefix + backend.suffixes[0], reference_data) - deep_eval = DeepEval(prefix + backend.suffixes[0]) + self.save_data_to_model( + prefix + backend.suffixes[suffix_idx], reference_data + ) + deep_eval = DeepEval(prefix + backend.suffixes[suffix_idx]) if deep_eval.get_dim_fparam() > 0: fparam = np.ones((nframes, deep_eval.get_dim_fparam())) else: @@ -169,7 +179,7 @@ def test_deep_eval(self): self.atype, fparam=fparam, aparam=aparam, - do_atomic_virial=True, + atomic=True, ) rets.append(ret) for ret in rets[1:]: diff --git a/source/tests/utils.py b/source/tests/utils.py index bfb3d445af..a9bf0f11ea 100644 --- a/source/tests/utils.py +++ b/source/tests/utils.py @@ -8,3 +8,4 @@ # see https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/store-information-in-variables#default-environment-variables CI = os.environ.get("CI") == "true" +DP_TEST_TF2_ONLY = os.environ.get("DP_TEST_TF2_ONLY") == "1"