From 94d20543b5c3a1d83eb4636524a9125de343f2f8 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 6 Nov 2024 17:04:46 -0500 Subject: [PATCH] nopbc Signed-off-by: Jinzhe Zeng --- deepmd/jax/infer/deep_eval.py | 6 +++++ deepmd/jax/model/hlo.py | 18 +++++++++++--- deepmd/jax/utils/serialization.py | 35 ++++++++++++++++++++++----- source/tests/consistent/io/test_io.py | 27 +++++++++++++++++++++ 4 files changed, 77 insertions(+), 9 deletions(-) diff --git a/deepmd/jax/infer/deep_eval.py b/deepmd/jax/infer/deep_eval.py index fc526a502e..b9d1974c27 100644 --- a/deepmd/jax/infer/deep_eval.py +++ b/deepmd/jax/infer/deep_eval.py @@ -97,6 +97,12 @@ def __init__( stablehlo_atomic_virial=model_data["@variables"][ "stablehlo_atomic_virial" ].tobytes(), + stablehlo_no_ghost=model_data["@variables"][ + "stablehlo_no_ghost" + ].tobytes(), + stablehlo_atomic_virial_no_ghost=model_data["@variables"][ + "stablehlo_atomic_virial_no_ghost" + ].tobytes(), model_def_script=model_data["model_def_script"], **model_data["constants"], ) diff --git a/deepmd/jax/model/hlo.py b/deepmd/jax/model/hlo.py index 2946f8bec7..4d59957456 100644 --- a/deepmd/jax/model/hlo.py +++ b/deepmd/jax/model/hlo.py @@ -46,6 +46,8 @@ def __init__( self, stablehlo, stablehlo_atomic_virial, + stablehlo_no_ghost, + stablehlo_atomic_virial_no_ghost, model_def_script, type_map, rcut, @@ -62,6 +64,10 @@ def __init__( self._call_lower_atomic_virial = jax_export.deserialize( stablehlo_atomic_virial ).call + self._call_lower_no_ghost = jax_export.deserialize(stablehlo_no_ghost).call + self._call_lower_atomic_virial_no_ghost = jax_export.deserialize( + stablehlo_atomic_virial_no_ghost + ).call self.stablehlo = stablehlo self.type_map = type_map self.rcut = rcut @@ -174,10 +180,16 @@ def call_lower( aparam: Optional[jnp.ndarray] = None, do_atomic_virial: bool = False, ): - if do_atomic_virial: - call_lower = self._call_lower_atomic_virial + if extended_coord.shape[1] > nlist.shape[1]: + if do_atomic_virial: + call_lower = self._call_lower_atomic_virial + else: + call_lower = self._call_lower else: - call_lower = self._call_lower + if do_atomic_virial: + call_lower = self._call_lower_atomic_virial_no_ghost + else: + call_lower = self._call_lower_no_ghost return call_lower( extended_coord, extended_atype, diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index 6ab99a81f0..1ed26f2d40 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -53,7 +53,9 @@ def deserialize_to_file(model_file: str, data: dict) -> None: nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost") - def exported_whether_do_atomic_virial(do_atomic_virial): + def exported_whether_do_atomic_virial( + do_atomic_virial: bool, has_ghost_atoms: bool + ): def call_lower_with_fixed_do_atomic_virial( coord, atype, nlist, mapping, fparam, aparam ): @@ -67,13 +69,18 @@ def call_lower_with_fixed_do_atomic_virial( do_atomic_virial=do_atomic_virial, ) + if has_ghost_atoms: + nghost_ = nghost + else: + nghost_ = 0 + return jax_export.export(jax.jit(call_lower_with_fixed_do_atomic_virial))( jax.ShapeDtypeStruct( - (nf, nloc + nghost, 3), jnp.float64 + (nf, nloc + nghost_, 3), jnp.float64 ), # extended_coord - jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype + jax.ShapeDtypeStruct((nf, nloc + nghost_), jnp.int32), # extended_atype jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist - jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping + jax.ShapeDtypeStruct((nf, nloc + nghost_), jnp.int64), # mapping jax.ShapeDtypeStruct((nf, model.get_dim_fparam()), jnp.float64) if model.get_dim_fparam() else None, # fparam @@ -82,18 +89,34 @@ def call_lower_with_fixed_do_atomic_virial( else None, # aparam ) - exported = exported_whether_do_atomic_virial(do_atomic_virial=False) + exported = exported_whether_do_atomic_virial( + do_atomic_virial=False, has_ghost_atoms=True + ) exported_atomic_virial = exported_whether_do_atomic_virial( - do_atomic_virial=True + do_atomic_virial=True, has_ghost_atoms=True ) serialized: bytearray = exported.serialize() serialized_atomic_virial = exported_atomic_virial.serialize() + + exported_no_ghost = exported_whether_do_atomic_virial( + do_atomic_virial=False, has_ghost_atoms=False + ) + exported_atomic_virial_no_ghost = exported_whether_do_atomic_virial( + do_atomic_virial=True, has_ghost_atoms=False + ) + serialized_no_ghost: bytearray = exported_no_ghost.serialize() + serialized_atomic_virial_no_ghost = exported_atomic_virial_no_ghost.serialize() + data = data.copy() data.setdefault("@variables", {}) data["@variables"]["stablehlo"] = np.void(serialized) data["@variables"]["stablehlo_atomic_virial"] = np.void( serialized_atomic_virial ) + data["@variables"]["stablehlo_no_ghost"] = np.void(serialized_no_ghost) + data["@variables"]["stablehlo_atomic_virial_no_ghost"] = np.void( + serialized_atomic_virial_no_ghost + ) data["constants"] = { "type_map": model.get_type_map(), "rcut": model.get_rcut(), diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index ca213da13c..8eb26e7ac3 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -142,6 +142,7 @@ def test_deep_eval(self): nframes = self.atype.shape[0] prefix = "test_consistent_io_" + self.__class__.__name__.lower() rets = [] + rets_nopbc = [] for backend_name, suffix_idx in ( # unfortunately, jax2tf cannot work with tf v1 behaviors ("jax", 2) if DP_TEST_TF2_ONLY else ("tensorflow", 0), @@ -182,6 +183,23 @@ def test_deep_eval(self): atomic=True, ) rets.append(ret) + ret = deep_eval.eval( + self.coords, + None, + self.atype, + fparam=fparam, + aparam=aparam, + ) + rets_nopbc.append(ret) + ret = deep_eval.eval( + self.coords, + None, + self.atype, + fparam=fparam, + aparam=aparam, + atomic=True, + ) + rets_nopbc.append(ret) for ret in rets[1:]: for vv1, vv2 in zip(rets[0], ret): if np.isnan(vv2).all(): @@ -189,6 +207,15 @@ def test_deep_eval(self): continue np.testing.assert_allclose(vv1, vv2, rtol=1e-12, atol=1e-12) + for idx, ret in enumerate(rets_nopbc[1:]): + for vv1, vv2 in zip(rets_nopbc[0], ret): + if np.isnan(vv2).all(): + # expect all nan if not supported + continue + np.testing.assert_allclose( + vv1, vv2, rtol=1e-12, atol=1e-12, err_msg=f"backend {idx+1}" + ) + class TestDeepPot(unittest.TestCase, IOTest): def setUp(self):