Skip to content

Commit

Permalink
fix comments
Browse files (browse the repository at this point in the history)
  • Loading branch information
Han Wang committed Jan 8, 2024
1 parent cd38bcf commit 24546c7
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
5 changes: 4 additions & 1 deletion deepmd_utils/model_format/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
"float16": np.float16,
"float32": np.float32,
"float64": np.float64,
"default": np.float64,
"half": np.float16,
"single": np.float32,
"double": np.float64,
}
DEFAULT_PRECISION = "float64"


class NativeOP(ABC):
Expand Down
23 changes: 20 additions & 3 deletions deepmd_utils/model_format/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
__version__ = "unknown"

from .common import (
DEFAULT_PRECISION,
PRECISION_DICT,
NativeOP,
)
Expand Down Expand Up @@ -150,7 +151,7 @@ def __init__(
idt: Optional[np.ndarray] = None,
activation_function: Optional[str] = None,
resnet: bool = False,
precision: str = "default",
precision: str = DEFAULT_PRECISION,
) -> None:
prec = PRECISION_DICT[precision.lower()]
self.precision = precision
Expand Down Expand Up @@ -190,13 +191,29 @@ def deserialize(cls, data: dict) -> "NativeLayer":
data : dict
The dict to deserialize from.
"""
precision = data.get("precision", DEFAULT_PRECISION)
# assertion "float64" == "double" would fail
assert (
PRECISION_DICT[data["@variables"]["w"].dtype.name]
is PRECISION_DICT[precision]
)
if data["@variables"].get("b", None) is not None:
assert (
PRECISION_DICT[data["@variables"]["b"].dtype.name]
is PRECISION_DICT[precision]
)
if data["@variables"].get("idt", None) is not None:
assert (
PRECISION_DICT[data["@variables"]["idt"].dtype.name]
is PRECISION_DICT[precision]
)
return cls(
w=data["@variables"]["w"],
b=data["@variables"].get("b", None),
idt=data["@variables"].get("idt", None),
activation_function=data["activation_function"],
resnet=data.get("resnet", False),
precision=data.get("precision", "default"),
precision=precision,
)

def __setitem__(self, key, value):
Expand Down Expand Up @@ -341,7 +358,7 @@ def __init__(
neuron: List[int] = [24, 48, 96],
activation_function: str = "tanh",
resnet_dt: bool = False,
precision: str = "default",
precision: str = DEFAULT_PRECISION,
):
layers = []
i_in = in_dim
Expand Down

0 comments on commit 24546c7

Please sign in to comment.