Skip to content

Commit

Permalink
add comments
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Oct 31, 2024
1 parent 5b6320c commit 6e29d50
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions deepmd/jax/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,16 @@ def call_lower_with_fixed_do_atomic_virial(
)

return jax_export.export(jax.jit(call_lower_with_fixed_do_atomic_virial))(
jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64),
jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32),
jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64),
jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64),
jax.ShapeDtypeStruct((nf, model.get_numb_fparam()), jnp.float64)
jax.ShapeDtypeStruct((nf, nloc + nghost, 3), jnp.float64), # extended_coord
jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int32), # extended_atype
jax.ShapeDtypeStruct((nf, nloc, model.get_nnei()), jnp.int64), # nlist
jax.ShapeDtypeStruct((nf, nloc + nghost), jnp.int64), # mapping
jax.ShapeDtypeStruct((nf, model.get_dim_fparam()), jnp.float64)
if model.get_dim_fparam()
else None,
jax.ShapeDtypeStruct((nf, nloc, model.get_numb_aparam()), jnp.float64)
else None, # fparam
jax.ShapeDtypeStruct((nf, nloc, model.get_dim_aparam()), jnp.float64)
if model.get_dim_aparam()
else None,
else None, # aparam
)

exported = exported_whether_do_atomic_virial(do_atomic_virial=False)
Expand Down

0 comments on commit 6e29d50

Please sign in to comment.