Skip to content

Commit

Permalink
refactorize NativeLayer, interface does not rely on the platform (dee…
Browse files Browse the repository at this point in the history
…pmodeling#3138)

- add parameter shape consistency check for layer 
- add input-output shape consistency check for net

Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
wanghan-iapcm and Han Wang authored Jan 13, 2024
1 parent 308f97e commit 15117a0
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 53 deletions.
94 changes: 68 additions & 26 deletions deepmd_utils/model_format/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
See issue #2982 for more information.
"""
import copy
import itertools
import json
from typing import (
Expand Down Expand Up @@ -150,23 +151,26 @@ class NativeLayer(NativeOP):

def __init__(
self,
w: Optional[np.ndarray] = None,
b: Optional[np.ndarray] = None,
idt: Optional[np.ndarray] = None,
num_in,
num_out,
bias: bool = True,
use_timestep: bool = False,
activation_function: Optional[str] = None,
resnet: bool = False,
precision: str = DEFAULT_PRECISION,
) -> None:
prec = PRECISION_DICT[precision.lower()]
self.precision = precision
self.w = w.astype(prec) if w is not None else None
self.b = b.astype(prec) if b is not None else None
self.idt = idt.astype(prec) if idt is not None else None
rng = np.random.default_rng()
self.w = rng.normal(size=(num_in, num_out)).astype(prec)
self.b = rng.normal(size=(num_out,)).astype(prec) if bias else None
self.idt = rng.normal(size=(num_out,)).astype(prec) if use_timestep else None
self.activation_function = (
activation_function if activation_function is not None else "none"
)
self.resnet = resnet
self.check_type_consistency()
self.check_shape_consistency()

def serialize(self) -> dict:
"""Serialize the layer to a dict.
Expand All @@ -179,10 +183,11 @@ def serialize(self) -> dict:
data = {
"w": self.w,
"b": self.b,
"idt": self.idt,
}
if self.idt is not None:
data["idt"] = self.idt
return {
"bias": self.b is not None,
"use_timestep": self.idt is not None,
"activation_function": self.activation_function,
"resnet": self.resnet,
"precision": self.precision,
Expand All @@ -198,15 +203,34 @@ def deserialize(cls, data: dict) -> "NativeLayer":
data : dict
The dict to deserialize from.
"""
precision = data.get("precision", DEFAULT_PRECISION)
return cls(
w=data["@variables"]["w"],
b=data["@variables"].get("b", None),
idt=data["@variables"].get("idt", None),
activation_function=data["activation_function"],
resnet=data.get("resnet", False),
precision=precision,
data = copy.deepcopy(data)
variables = data.pop("@variables")
assert variables["w"] is not None and len(variables["w"].shape) == 2
num_in, num_out = variables["w"].shape
obj = cls(
num_in,
num_out,
**data,
)
obj.w, obj.b, obj.idt = (
variables["w"],
variables.get("b", None),
variables.get("idt", None),
)
obj.check_shape_consistency()
return obj

def check_shape_consistency(self):
if self.b is not None and self.w.shape[1] != self.b.shape[0]:
raise ValueError(
f"dim 1 of w {self.w.shape[1]} is not equal to shape "
f"of b {self.b.shape[0]}",
)
if self.idt is not None and self.w.shape[1] != self.idt.shape[0]:
raise ValueError(
f"dim 1 of w {self.w.shape[1]} is not equal to shape "
f"of idt {self.idt.shape[0]}",
)

def check_type_consistency(self):
precision = self.precision
Expand Down Expand Up @@ -252,6 +276,14 @@ def __getitem__(self, key):
else:
raise KeyError(key)

@property
def dim_in(self) -> int:
return self.w.shape[0]

@property
def dim_out(self) -> int:
return self.w.shape[1]

def call(self, x: np.ndarray) -> np.ndarray:
"""Forward pass.
Expand Down Expand Up @@ -303,6 +335,7 @@ def __init__(self, layers: Optional[List[dict]] = None) -> None:
if layers is None:
layers = []
self.layers = [NativeLayer.deserialize(layer) for layer in layers]
self.check_shape_consistency()

def serialize(self) -> dict:
"""Serialize the network to a dict.
Expand All @@ -327,16 +360,21 @@ def deserialize(cls, data: dict) -> "NativeNet":

def __getitem__(self, key):
assert isinstance(key, int)
if len(self.layers) <= key:
self.layers.extend([NativeLayer()] * (key - len(self.layers) + 1))
return self.layers[key]

def __setitem__(self, key, value):
assert isinstance(key, int)
if len(self.layers) <= key:
self.layers.extend([NativeLayer()] * (key - len(self.layers) + 1))
self.layers[key] = value

def check_shape_consistency(self):
for ii in range(len(self.layers) - 1):
if self.layers[ii].dim_out != self.layers[ii + 1].dim_in:
raise ValueError(
f"the dim of layer {ii} output {self.layers[ii].dim_out} ",
f"does not match the dim of layer {ii+1} ",
f"output {self.layers[ii].dim_out}",
)

def call(self, x: np.ndarray) -> np.ndarray:
"""Forward pass.
Expand Down Expand Up @@ -389,9 +427,10 @@ def __init__(
i_ot = ii
layers.append(
NativeLayer(
rng.normal(size=(i_in, i_ot)),
b=rng.normal(size=(i_ot)),
idt=rng.normal(size=(i_ot)) if resnet_dt else None,
i_in,
i_ot,
bias=True,
use_timestep=resnet_dt,
activation_function=activation_function,
resnet=True,
precision=precision,
Expand Down Expand Up @@ -431,6 +470,7 @@ def deserialize(cls, data: dict) -> "EmbeddingNet":
data : dict
The dict to deserialize from.
"""
data = copy.deepcopy(data)
layers = data.pop("layers")
obj = cls(**data)
super(EmbeddingNet, obj).__init__(layers)
Expand Down Expand Up @@ -481,9 +521,10 @@ def __init__(
i_in, i_ot = neuron[-1], out_dim
self.layers.append(
NativeLayer(
rng.normal(size=(i_in, i_ot)),
b=rng.normal(size=(i_ot)) if bias_out else None,
idt=None,
i_in,
i_ot,
bias=bias_out,
use_timestep=False,
activation_function=None,
resnet=False,
precision=precision,
Expand Down Expand Up @@ -520,6 +561,7 @@ def deserialize(cls, data: dict) -> "FittingNet":
data : dict
The dict to deserialize from.
"""
data = copy.deepcopy(data)
layers = data.pop("layers")
obj = cls(**data)
NativeNet.__init__(obj, layers)
Expand Down
2 changes: 2 additions & 0 deletions deepmd_utils/model_format/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
except ImportError:
__version__ = "unknown"

import copy
from typing import (
Any,
List,
Expand Down Expand Up @@ -270,6 +271,7 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data: dict) -> "DescrptSeA":
data = copy.deepcopy(data)
variables = data.pop("@variables")
embeddings = data.pop("embeddings")
env_mat = data.pop("env_mat")
Expand Down
111 changes: 84 additions & 27 deletions source/tests/test_model_format_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,39 +35,74 @@ def test_serialize_deserize(self):
[None, [4], [3, 2]],
["float32", "float64", "single", "double"],
):
ww = np.full((ni, no), 3.0)
bb = np.full((no,), 4.0) if bias else None
idt = np.full((no,), 5.0) if ut else None
nl0 = NativeLayer(ww, bb, idt, activation_function, resnet, prec)
nl0 = NativeLayer(
ni,
no,
bias=bias,
use_timestep=ut,
activation_function=activation_function,
resnet=resnet,
precision=prec,
)
nl1 = NativeLayer.deserialize(nl0.serialize())
inp_shap = [ww.shape[0]]
inp_shap = [ni]
if ashp is not None:
inp_shap = ashp + inp_shap
inp = np.arange(np.prod(inp_shap)).reshape(inp_shap)
np.testing.assert_allclose(nl0.call(inp), nl1.call(inp))

def test_shape_error(self):
self.w0 = np.full((2, 3), 3.0)
self.b0 = np.full((2,), 4.0)
self.b1 = np.full((3,), 4.0)
self.idt0 = np.full((2,), 4.0)
with self.assertRaises(ValueError) as context:
network = NativeLayer.deserialize(
{
"activation_function": "tanh",
"resnet": True,
"@variables": {"w": self.w0, "b": self.b0},
}
)
assert "not equalt to shape of b" in context.exception
with self.assertRaises(ValueError) as context:
network = NativeLayer.deserialize(
{
"activation_function": "tanh",
"resnet": True,
"@variables": {"w": self.w0, "b": self.b1, "idt": self.idt0},
}
)
assert "not equalt to shape of idt" in context.exception


class TestNativeNet(unittest.TestCase):
def setUp(self) -> None:
self.w = np.full((2, 3), 3.0)
self.b = np.full((3,), 4.0)
self.idt = np.full((3,), 5.0)
self.w0 = np.full((2, 3), 3.0)
self.b0 = np.full((3,), 4.0)
self.w1 = np.full((3, 4), 3.0)
self.b1 = np.full((4,), 4.0)

def test_serialize(self):
network = NativeNet()
network[1]["w"] = self.w
network[1]["b"] = self.b
network[0]["w"] = self.w
network[0]["b"] = self.b
network = NativeNet(
[
NativeLayer(2, 3).serialize(),
NativeLayer(3, 4).serialize(),
]
)
network[1]["w"] = self.w1
network[1]["b"] = self.b1
network[0]["w"] = self.w0
network[0]["b"] = self.b0
network[1]["activation_function"] = "tanh"
network[0]["activation_function"] = "tanh"
network[1]["resnet"] = True
network[0]["resnet"] = True
jdata = network.serialize()
np.testing.assert_array_equal(jdata["layers"][0]["@variables"]["w"], self.w)
np.testing.assert_array_equal(jdata["layers"][0]["@variables"]["b"], self.b)
np.testing.assert_array_equal(jdata["layers"][1]["@variables"]["w"], self.w)
np.testing.assert_array_equal(jdata["layers"][1]["@variables"]["b"], self.b)
np.testing.assert_array_equal(jdata["layers"][0]["@variables"]["w"], self.w0)
np.testing.assert_array_equal(jdata["layers"][0]["@variables"]["b"], self.b0)
np.testing.assert_array_equal(jdata["layers"][1]["@variables"]["w"], self.w1)
np.testing.assert_array_equal(jdata["layers"][1]["@variables"]["b"], self.b1)
np.testing.assert_array_equal(jdata["layers"][0]["activation_function"], "tanh")
np.testing.assert_array_equal(jdata["layers"][1]["activation_function"], "tanh")
np.testing.assert_array_equal(jdata["layers"][0]["resnet"], True)
Expand All @@ -80,25 +115,45 @@ def test_deserialize(self):
{
"activation_function": "tanh",
"resnet": True,
"@variables": {"w": self.w, "b": self.b},
"@variables": {"w": self.w0, "b": self.b0},
},
{
"activation_function": "tanh",
"resnet": True,
"@variables": {"w": self.w, "b": self.b},
"@variables": {"w": self.w1, "b": self.b1},
},
],
}
)
np.testing.assert_array_equal(network[0]["w"], self.w)
np.testing.assert_array_equal(network[0]["b"], self.b)
np.testing.assert_array_equal(network[1]["w"], self.w)
np.testing.assert_array_equal(network[1]["b"], self.b)
np.testing.assert_array_equal(network[0]["w"], self.w0)
np.testing.assert_array_equal(network[0]["b"], self.b0)
np.testing.assert_array_equal(network[1]["w"], self.w1)
np.testing.assert_array_equal(network[1]["b"], self.b1)
np.testing.assert_array_equal(network[0]["activation_function"], "tanh")
np.testing.assert_array_equal(network[1]["activation_function"], "tanh")
np.testing.assert_array_equal(network[0]["resnet"], True)
np.testing.assert_array_equal(network[1]["resnet"], True)

def test_shape_error(self):
with self.assertRaises(ValueError) as context:
network = NativeNet.deserialize(
{
"layers": [
{
"activation_function": "tanh",
"resnet": True,
"@variables": {"w": self.w0, "b": self.b0},
},
{
"activation_function": "tanh",
"resnet": True,
"@variables": {"w": self.w0, "b": self.b0},
},
],
}
)
assert "does not match the dim of layer" in context.exception


class TestEmbeddingNet(unittest.TestCase):
def test_embedding_net(self):
Expand Down Expand Up @@ -146,19 +201,21 @@ def test_fitting_net(self):

class TestNetworkCollection(unittest.TestCase):
def setUp(self) -> None:
w = np.full((2, 3), 3.0)
b = np.full((3,), 4.0)
w0 = np.full((2, 3), 3.0)
b0 = np.full((3,), 4.0)
w1 = np.full((3, 4), 3.0)
b1 = np.full((4,), 4.0)
self.network = {
"layers": [
{
"activation_function": "tanh",
"resnet": True,
"@variables": {"w": w, "b": b},
"@variables": {"w": w0, "b": b0},
},
{
"activation_function": "tanh",
"resnet": True,
"@variables": {"w": w, "b": b},
"@variables": {"w": w1, "b": b1},
},
],
}
Expand Down

0 comments on commit 15117a0

Please sign in to comment.