Skip to content

Commit

Permalink
fix typo
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Oct 31, 2024
1 parent 4cb1cbc commit ba7147a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions deepmd/jax/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def deserialize_to_file(model_file: str, data: dict) -> None:
nf, nloc, nghost = jax_export.symbolic_shape("nf, nloc, nghost")

def exported_whether_do_atomic_virial(do_atomic_virial):
def call_lower_with_fixed_do_atmic_virial(
def call_lower_with_fixed_do_atomic_virial(
coord, atype, nlist, nlist_start, fparam, aparam
):
return call_lower(
Expand All @@ -67,7 +67,7 @@ def call_lower_with_fixed_do_atmic_virial(
do_atomic_virial=do_atomic_virial,
)

return jax_export.export(jax.jit(call_lower_with_fixed_do_atmic_virial))(
return jax_export.export(jax.jit(call_lower_with_fixed_do_atomic_virial))(
jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64),
jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32),
jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64),
Expand Down

0 comments on commit ba7147a

Please sign in to comment.