From 6e29d505f46a72bb957f9eccb5a1f2ba7c47ac8a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 31 Oct 2024 16:34:13 -0400 Subject: [PATCH] add comments Signed-off-by: Jinzhe Zeng --- deepmd/jax/utils/serialization.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index fe5fdd1d22..102b588a9e 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -68,16 +68,16 @@ def call_lower_with_fixed_do_atomic_virial( ) return jax_export.export(jax.jit(call_lower_with_fixed_do_atomic_virial))( - jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64), - jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), - jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), - jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), - jax.ShapeDtypeStruct((nf, model.get_numb_fparam()), jnp.float64) + jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64), # extended_coord + jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype + jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist + jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping + jax.ShapeDtypeStruct((nf, model.get_dim_fparam()), jnp.float64) if model.get_dim_fparam() - else None, - jax.ShapeDtypeStruct((nf, nloc, model.get_numb_aparam()), jnp.float64) + else None, # fparam + jax.ShapeDtypeStruct((nf, nloc, model.get_dim_aparam()), jnp.float64) if model.get_dim_aparam() - else None, + else None, # aparam ) exported = exported_whether_do_atomic_virial(do_atomic_virial=False)