Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 5, 2024
1 parent 5fbbd61 commit 0d20b8d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
8 changes: 6 additions & 2 deletions deepmd_utils/model_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,16 @@ def call(self, x: np.ndarray) -> np.ndarray:
if self.activation_function == "tanh":
fn = np.tanh
elif self.activation_function.lower() == "none":

def fn(x):
return x
else:
raise NotImplementedError(self.activation_function)
y = np.matmul(x, self.w) + self.b \
if self.b is not None else np.matmul(x, self.w)
y = (
np.matmul(x, self.w) + self.b
if self.b is not None
else np.matmul(x, self.w)
)
y = fn(y)
if self.idt is not None:
y *= self.idt
Expand Down
20 changes: 10 additions & 10 deletions source/tests/test_model_format_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import itertools
import os
import unittest
from copy import (
deepcopy,
)
import itertools

import numpy as np

from deepmd_utils.model_format import (
Expand All @@ -14,21 +15,21 @@
save_dp_model,
)


class TestNativeLayer(unittest.TestCase):
def setUp(self) -> None:
self.w = np.full((2, 3), 3.0)
self.b = np.full((3,), 4.0)
self.idt = np.full((3,), 5.0)

def test_serialize_deserize(self):
for ww,bb,idt,activation_function,resnet in \
itertools.product(
[self.w], [self.b, None], [self.idt, None],
["tanh", "none"], [True, False] ):
nl0 = NativeLayer(ww, bb, idt, activation_function, resnet)
nl1 = NativeLayer.deserialize(nl0.serialize())
inp = np.arange(self.w.shape[0])
np.testing.assert_allclose(nl0.call(inp), nl1.call(inp))
for ww, bb, idt, activation_function, resnet in itertools.product(
[self.w], [self.b, None], [self.idt, None], ["tanh", "none"], [True, False]
):
nl0 = NativeLayer(ww, bb, idt, activation_function, resnet)
nl1 = NativeLayer.deserialize(nl0.serialize())
inp = np.arange(self.w.shape[0])
np.testing.assert_allclose(nl0.call(inp), nl1.call(inp))


class TestNativeNet(unittest.TestCase):
Expand Down Expand Up @@ -84,7 +85,6 @@ def test_deserialize(self):
np.testing.assert_array_equal(network[1]["resnet"], True)



class TestDPModel(unittest.TestCase):
def setUp(self) -> None:
self.w = np.full((3, 2), 3.0)
Expand Down

0 comments on commit 0d20b8d

Please sign in to comment.