diff --git a/deepmd_utils/model_format.py b/deepmd_utils/model_format.py index 83e6ac11fc..131be93121 100644 --- a/deepmd_utils/model_format.py +++ b/deepmd_utils/model_format.py @@ -262,7 +262,7 @@ def fn(x): if self.resnet and self.w.shape[1] == self.w.shape[0]: y += x elif self.resnet and self.w.shape[1] == 2 * self.w.shape[0]: - y += np.concatenate([x, x], axis=1) + y += np.concatenate([x, x], axis=-1) return y diff --git a/source/tests/test_model_format_utils.py b/source/tests/test_model_format_utils.py index 2fef4e1922..af8c4361c8 100644 --- a/source/tests/test_model_format_utils.py +++ b/source/tests/test_model_format_utils.py @@ -17,18 +17,24 @@ class TestNativeLayer(unittest.TestCase): - def setUp(self) -> None: - self.w = np.full((2, 3), 3.0) - self.b = np.full((3,), 4.0) - self.idt = np.full((3,), 5.0) - def test_serialize_deserize(self): - for ww, bb, idt, activation_function, resnet in itertools.product( - [self.w], [self.b, None], [self.idt, None], ["tanh", "none"], [True, False] + for (ni, no), bias, ut, activation_function, resnet, ashp in itertools.product( + [(5, 5), (5, 10), (5, 9), (9, 5)], + [True, False], + [True, False], + ["tanh", "none"], + [True, False], + [None, [4], [3, 2]], ): + ww = np.full((ni, no), 3.0) + bb = np.full((no,), 4.0) if bias else None + idt = np.full((no,), 5.0) if ut else None nl0 = NativeLayer(ww, bb, idt, activation_function, resnet) nl1 = NativeLayer.deserialize(nl0.serialize()) - inp = np.arange(self.w.shape[0]) + inp_shap = [ww.shape[0]] + if ashp is not None: + inp_shap = ashp + inp_shap + inp = np.arange(np.prod(inp_shap)).reshape(inp_shap) np.testing.assert_allclose(nl0.call(inp), nl1.call(inp))