Skip to content

Commit

Permalink
check type consistency at construction
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Jan 8, 2024
1 parent 24546c7 commit 9a21db9
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions deepmd_utils/model_format/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def __init__(
self.idt = idt.astype(prec) if idt is not None else None
self.activation_function = activation_function
self.resnet = resnet
self.check_type_consistency()

def serialize(self) -> dict:
"""Serialize the layer to a dict.
Expand Down Expand Up @@ -192,21 +193,6 @@ def deserialize(cls, data: dict) -> "NativeLayer":
The dict to deserialize from.
"""
precision = data.get("precision", DEFAULT_PRECISION)
# assertion "float64" == "double" would fail
assert (
PRECISION_DICT[data["@variables"]["w"].dtype.name]
is PRECISION_DICT[precision]
)
if data["@variables"].get("b", None) is not None:
assert (
PRECISION_DICT[data["@variables"]["b"].dtype.name]
is PRECISION_DICT[precision]
)
if data["@variables"].get("idt", None) is not None:
assert (
PRECISION_DICT[data["@variables"]["idt"].dtype.name]
is PRECISION_DICT[precision]
)
return cls(
w=data["@variables"]["w"],
b=data["@variables"].get("b", None),
Expand All @@ -216,6 +202,18 @@ def deserialize(cls, data: dict) -> "NativeLayer":
precision=precision,
)

def check_type_consistency(self):
precision = self.precision

def check_var(var):
if var is not None:
# assertion "float64" == "double" would fail
assert PRECISION_DICT[var.dtype.name] is PRECISION_DICT[precision]

check_var(self.w)
check_var(self.b)
check_var(self.idt)

def __setitem__(self, key, value):
if key in ("w", "matrix"):
self.w = value
Expand Down

0 comments on commit 9a21db9

Please sign in to comment.