Skip to content

Commit

Permalink
fix native layer concat bug. (#3112)
Browse files Browse the repository at this point in the history
add UT testing all input and output shapes

---------

Co-authored-by: Han Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 5, 2024
1 parent c4b7baa commit 7b3c3c0
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
2 changes: 1 addition & 1 deletion deepmd_utils/model_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def fn(x):
if self.resnet and self.w.shape[1] == self.w.shape[0]:
y += x
elif self.resnet and self.w.shape[1] == 2 * self.w.shape[0]:
y += np.concatenate([x, x], axis=1)
y += np.concatenate([x, x], axis=-1)
return y


Expand Down
22 changes: 14 additions & 8 deletions source/tests/test_model_format_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,24 @@


class TestNativeLayer(unittest.TestCase):
def setUp(self) -> None:
self.w = np.full((2, 3), 3.0)
self.b = np.full((3,), 4.0)
self.idt = np.full((3,), 5.0)

def test_serialize_deserize(self):
for ww, bb, idt, activation_function, resnet in itertools.product(
[self.w], [self.b, None], [self.idt, None], ["tanh", "none"], [True, False]
for (ni, no), bias, ut, activation_function, resnet, ashp in itertools.product(
[(5, 5), (5, 10), (5, 9), (9, 5)],
[True, False],
[True, False],
["tanh", "none"],
[True, False],
[None, [4], [3, 2]],
):
ww = np.full((ni, no), 3.0)
bb = np.full((no,), 4.0) if bias else None
idt = np.full((no,), 5.0) if ut else None
nl0 = NativeLayer(ww, bb, idt, activation_function, resnet)
nl1 = NativeLayer.deserialize(nl0.serialize())
inp = np.arange(self.w.shape[0])
inp_shap = [ww.shape[0]]
if ashp is not None:
inp_shap = ashp + inp_shap
inp = np.arange(np.prod(inp_shap)).reshape(inp_shap)
np.testing.assert_allclose(nl0.call(inp), nl1.call(inp))


Expand Down

0 comments on commit 7b3c3c0

Please sign in to comment.