From 24546c76c84a3afae16ed9a0b6d5309dfdaf5d70 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 8 Jan 2024 22:55:22 +0800 Subject: [PATCH] fix comments --- deepmd_utils/model_format/common.py | 5 ++++- deepmd_utils/model_format/network.py | 23 ++++++++++++++++++++--- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/deepmd_utils/model_format/common.py b/deepmd_utils/model_format/common.py index 2ca39321eb..82beb969c2 100644 --- a/deepmd_utils/model_format/common.py +++ b/deepmd_utils/model_format/common.py @@ -9,8 +9,11 @@ "float16": np.float16, "float32": np.float32, "float64": np.float64, - "default": np.float64, + "half": np.float16, + "single": np.float32, + "double": np.float64, } +DEFAULT_PRECISION = "float64" class NativeOP(ABC): diff --git a/deepmd_utils/model_format/network.py b/deepmd_utils/model_format/network.py index 59d0bc1fdd..b7e5c7c288 100644 --- a/deepmd_utils/model_format/network.py +++ b/deepmd_utils/model_format/network.py @@ -18,6 +18,7 @@ __version__ = "unknown" from .common import ( + DEFAULT_PRECISION, PRECISION_DICT, NativeOP, ) @@ -150,7 +151,7 @@ def __init__( idt: Optional[np.ndarray] = None, activation_function: Optional[str] = None, resnet: bool = False, - precision: str = "default", + precision: str = DEFAULT_PRECISION, ) -> None: prec = PRECISION_DICT[precision.lower()] self.precision = precision @@ -190,13 +191,29 @@ def deserialize(cls, data: dict) -> "NativeLayer": data : dict The dict to deserialize from. """ + precision = data.get("precision", DEFAULT_PRECISION) + # assertion "float64" == "double" would fail + assert ( + PRECISION_DICT[data["@variables"]["w"].dtype.name] + is PRECISION_DICT[precision] + ) + if data["@variables"].get("b", None) is not None: + assert ( + PRECISION_DICT[data["@variables"]["b"].dtype.name] + is PRECISION_DICT[precision] + ) + if data["@variables"].get("idt", None) is not None: + assert ( + PRECISION_DICT[data["@variables"]["idt"].dtype.name] + is PRECISION_DICT[precision] + ) return cls( w=data["@variables"]["w"], b=data["@variables"].get("b", None), idt=data["@variables"].get("idt", None), activation_function=data["activation_function"], resnet=data.get("resnet", False), - precision=data.get("precision", "default"), + precision=precision, ) def __setitem__(self, key, value): @@ -341,7 +358,7 @@ def __init__( neuron: List[int] = [24, 48, 96], activation_function: str = "tanh", resnet_dt: bool = False, - precision: str = "default", + precision: str = DEFAULT_PRECISION, ): layers = [] i_in = in_dim