diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index a539adb292..fbf2c6e21f 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -108,7 +108,9 @@ def model_call_from_call_lower( nloc, rcut, sel, - distinguish_types=not mixed_types, + # types will be distinguished in the lower interface, + # so it doesn't need to be distinguished here + distinguish_types=False, ) extended_coord = extended_coord.reshape(nframes, -1, 3) model_predict_lower = call_lower( diff --git a/deepmd/jax/jax2tf/make_model.py b/deepmd/jax/jax2tf/make_model.py index d21fc998b5..29ed131f8e 100644 --- a/deepmd/jax/jax2tf/make_model.py +++ b/deepmd/jax/jax2tf/make_model.py @@ -90,7 +90,9 @@ def model_call_from_call_lower( nloc, rcut, sel, - distinguish_types=not mixed_types, + # types will be distinguished in the lower interface, + # so it doesn't need to be distinguished here + distinguish_types=False, ) extended_coord = extended_coord.reshape(nframes, -1, 3) model_predict_lower = call_lower( diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 6bb5f6b8e9..83abf9ee4a 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -175,7 +175,9 @@ def forward_common( atype, self.get_rcut(), self.get_sel(), - mixed_types=self.mixed_types(), + # types will be distinguished in the lower interface, + # so it doesn't need to be distinguished here + mixed_types=True, box=bb, ) model_predict_lower = self.forward_common_lower(