diff --git a/deepmd_utils/model_format/network.py b/deepmd_utils/model_format/network.py index b7e5c7c288..98c35636fa 100644 --- a/deepmd_utils/model_format/network.py +++ b/deepmd_utils/model_format/network.py @@ -160,6 +160,7 @@ def __init__( self.idt = idt.astype(prec) if idt is not None else None self.activation_function = activation_function self.resnet = resnet + self.check_type_consistency() def serialize(self) -> dict: """Serialize the layer to a dict. @@ -192,21 +193,6 @@ def deserialize(cls, data: dict) -> "NativeLayer": The dict to deserialize from. """ precision = data.get("precision", DEFAULT_PRECISION) - # assertion "float64" == "double" would fail - assert ( - PRECISION_DICT[data["@variables"]["w"].dtype.name] - is PRECISION_DICT[precision] - ) - if data["@variables"].get("b", None) is not None: - assert ( - PRECISION_DICT[data["@variables"]["b"].dtype.name] - is PRECISION_DICT[precision] - ) - if data["@variables"].get("idt", None) is not None: - assert ( - PRECISION_DICT[data["@variables"]["idt"].dtype.name] - is PRECISION_DICT[precision] - ) return cls( w=data["@variables"]["w"], b=data["@variables"].get("b", None), @@ -216,6 +202,18 @@ def deserialize(cls, data: dict) -> "NativeLayer": precision=precision, ) + def check_type_consistency(self): + precision = self.precision + + def check_var(var): + if var is not None: + # assertion "float64" == "double" would fail + assert PRECISION_DICT[var.dtype.name] is PRECISION_DICT[precision] + + check_var(self.w) + check_var(self.b) + check_var(self.idt) + def __setitem__(self, key, value): if key in ("w", "matrix"): self.w = value