diff --git a/deepmd/pt/infer/inference.py b/deepmd/pt/infer/inference.py index b3ee8ccdef..b3d120cbc4 100644 --- a/deepmd/pt/infer/inference.py +++ b/deepmd/pt/infer/inference.py @@ -62,4 +62,4 @@ def __init__( self.wrapper = ModelWrapper(self.model) # inference only if JIT: self.wrapper = torch.jit.script(self.wrapper) - self.wrapper.load_state_dict(state_dict, strict=False) + self.wrapper.load_state_dict(state_dict) diff --git a/deepmd/pt/utils/serialization.py b/deepmd/pt/utils/serialization.py index c988b99d77..5d3b02482a 100644 --- a/deepmd/pt/utils/serialization.py +++ b/deepmd/pt/utils/serialization.py @@ -34,7 +34,7 @@ def serialize_from_file(model_file: str) -> dict: saved_model = torch.jit.load(model_file, map_location="cpu") model_def_script = json.loads(saved_model.model_def_script) model = get_model(model_def_script) - model.load_state_dict(saved_model.state_dict(), strict=False) + model.load_state_dict(saved_model.state_dict()) elif model_file.endswith(".pt"): state_dict = torch.load(model_file, map_location="cpu", weights_only=True) if "model" in state_dict: