add __call__ for NativeOP. model check assumes __call__ as the forward method
Han Wang committed Jan 12, 2024
1 parent 828df66 commit 5f0bd99
Showing 4 changed files with 38 additions and 28 deletions.
2 changes: 2 additions & 0 deletions deepmd_utils/model_format/__init__.py
@@ -2,6 +2,7 @@
 from .common import (
     DEFAULT_PRECISION,
     PRECISION_DICT,
+    NativeOP,
 )
 from .env_mat import (
     EnvMat,
@@ -34,6 +35,7 @@
     "NativeLayer",
     "NativeNet",
     "NetworkCollection",
+    "NativeOP",
     "load_dp_model",
     "save_dp_model",
     "traverse_model_dict",
4 changes: 4 additions & 0 deletions deepmd_utils/model_format/common.py
@@ -22,3 +22,7 @@ class NativeOP(ABC):
     def call(self, *args, **kwargs):
         """Forward pass in NumPy implementation."""
         raise NotImplementedError
+
+    def __call__(self, *args, **kwargs):
+        """Forward pass in NumPy implementation."""
+        return self.call(*args, **kwargs)
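With this change a NativeOP subclass only implements call; instances become directly callable because __call__ delegates to it. A minimal sketch of the new behavior (the DoubleOp class below is hypothetical, written for illustration only):

import numpy as np

from deepmd_utils.model_format import NativeOP


class DoubleOp(NativeOP):
    """Hypothetical op that doubles its input."""

    def call(self, x):
        return 2.0 * x


op = DoubleOp()
x = np.ones([2, 3])
# op(x) routes through NativeOP.__call__, which forwards to call(),
# so the two invocations below are equivalent.
np.testing.assert_allclose(op(x), op.call(x))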
12 changes: 6 additions & 6 deletions deepmd_utils/model_format/output_def.py
@@ -27,7 +27,7 @@ def model_check_output(cls):
     Two methods are assumed to be provided by the Model:
     1. Model.output_def that gives the output definition.
-    2. Model.forward that defines the forward path of the model.
+    2. Model.__call__ that defines the forward path of the model.
     """

@@ -40,12 +40,12 @@ def __init__(
             super().__init__(*args, **kwargs)
             self.md = cls.output_def(self)

-        def forward(
+        def __call__(
             self,
             *args,
             **kwargs,
         ):
-            ret = cls.forward(self, *args, **kwargs)
+            ret = cls.__call__(self, *args, **kwargs)
             for kk in self.md.keys_outp():
                 dd = self.md[kk]
                 check_var(ret[kk], dd)
@@ -66,7 +66,7 @@ def fitting_check_output(cls):
     Two methods are assumed to be provided by the Fitting:
     1. Fitting.output_def that gives the output definition.
-    2. Fitting.forward defines the forward path of the fitting.
+    2. Fitting.__call__ defines the forward path of the fitting.
     """

@@ -79,12 +79,12 @@ def __init__(
             super().__init__(*args, **kwargs)
             self.md = cls.output_def(self)

-        def forward(
+        def __call__(
             self,
             *args,
             **kwargs,
         ):
-            ret = cls.forward(self, *args, **kwargs)
+            ret = cls.__call__(self, *args, **kwargs)
             for kk in self.md.keys():
                 dd = self.md[kk]
                 check_var(ret[kk], dd)
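Since the check decorators now wrap __call__ rather than forward, a decorated class is exercised by calling the instance itself. A minimal sketch, modeled on the updated tests below (the ToyFitting name and shapes are illustrative):

import numpy as np

from deepmd_utils.model_format import (
    FittingOutputDef,
    NativeOP,
    OutputVariableDef,
    fitting_check_output,
)

nf, nloc = 2, 3


@fitting_check_output
class ToyFitting(NativeOP):
    """Hypothetical fitting; only output_def and call are required."""

    def output_def(self):
        return FittingOutputDef(
            [OutputVariableDef("energy", [1], True, True)]
        )

    def call(self):
        return {"energy": np.zeros([nf, nloc, 1])}


ff = ToyFitting()
ff()  # the wrapped __call__ shape-checks every declared output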
48 changes: 26 additions & 22 deletions source/tests/test_output_def.py
@@ -6,6 +6,7 @@
 from deepmd_utils.model_format import (
     FittingOutputDef,
     ModelOutputDef,
+    NativeOP,
     OutputVariableDef,
     fitting_check_output,
     model_check_output,
@@ -91,14 +92,14 @@ def test_model_decorator(self):
         nloc = 3

         @model_check_output
-        class Foo:
+        class Foo(NativeOP):
             def output_def(self):
                 defs = [
                     OutputVariableDef("energy", [1], True, True),
                 ]
                 return ModelOutputDef(FittingOutputDef(defs))

-            def forward(self):
+            def call(self):
                 return {
                     "energy": np.zeros([nf, nloc, 1]),
                     "energy_redu": np.zeros([nf, 1]),
@@ -107,21 +108,24 @@ def forward(self):
                 }

         ff = Foo()
-        ff.forward()
+        ff()

     def test_model_decorator_keyerror(self):
         nf = 2
         nloc = 3

         @model_check_output
-        class Foo:
+        class Foo(NativeOP):
+            def __init__(self):
+                super().__init__()
+
             def output_def(self):
                 defs = [
                     OutputVariableDef("energy", [1], True, True),
                 ]
                 return ModelOutputDef(FittingOutputDef(defs))

-            def forward(self):
+            def call(self):
                 return {
                     "energy": np.zeros([nf, nloc, 1]),
                     "energy_redu": np.zeros([nf, 1]),
@@ -130,15 +134,15 @@ def forward(self):

         ff = Foo()
         with self.assertRaises(KeyError) as context:
-            ff.forward()
+            ff()
         self.assertIn("energy_derv_r", context.exception)

     def test_model_decorator_shapeerror(self):
         nf = 2
         nloc = 3

         @model_check_output
-        class Foo:
+        class Foo(NativeOP):
             def __init__(
                 self,
                 shape_rd=[nf, 1],
@@ -152,7 +156,7 @@ def output_def(self):
                 ]
                 return ModelOutputDef(FittingOutputDef(defs))

-            def forward(self):
+            def call(self):
                 return {
                     "energy": np.zeros([nf, nloc, 1]),
                     "energy_redu": np.zeros(self.shape_rd),
@@ -161,56 +165,56 @@
                 }

         ff = Foo()
-        ff.forward()
+        ff()
         # shape of reduced energy
         with self.assertRaises(ValueError) as context:
             ff = Foo(shape_rd=[nf, nloc, 1])
-            ff.forward()
+            ff()
         self.assertIn("not matching", context.exception)
         with self.assertRaises(ValueError) as context:
             ff = Foo(shape_rd=[nf, 2])
-            ff.forward()
+            ff()
         self.assertIn("not matching", context.exception)
         # shape of dr
         with self.assertRaises(ValueError) as context:
             ff = Foo(shape_dr=[nf, nloc, 1])
-            ff.forward()
+            ff()
         self.assertIn("not matching", context.exception)
         with self.assertRaises(ValueError) as context:
             ff = Foo(shape_dr=[nf, nloc, 1, 3, 3])
-            ff.forward()
+            ff()
         self.assertIn("not matching", context.exception)
         with self.assertRaises(ValueError) as context:
             ff = Foo(shape_dr=[nf, nloc, 1, 4])
-            ff.forward()
+            ff()
         self.assertIn("not matching", context.exception)

     def test_fitting_decorator(self):
         nf = 2
         nloc = 3

         @fitting_check_output
-        class Foo:
+        class Foo(NativeOP):
             def output_def(self):
                 defs = [
                     OutputVariableDef("energy", [1], True, True),
                 ]
                 return FittingOutputDef(defs)

-            def forward(self):
+            def call(self):
                 return {
                     "energy": np.zeros([nf, nloc, 1]),
                 }

         ff = Foo()
-        ff.forward()
+        ff()

     def test_fitting_decorator_shapeerror(self):
         nf = 2
         nloc = 3

         @fitting_check_output
-        class Foo:
+        class Foo(NativeOP):
             def __init__(
                 self,
                 shape=[nf, nloc, 1],
@@ -223,19 +227,19 @@ def output_def(self):
                 ]
                 return FittingOutputDef(defs)

-            def forward(self):
+            def call(self):
                 return {
                     "energy": np.zeros(self.shape),
                 }

         ff = Foo()
-        ff.forward()
+        ff()
         # shape of reduced energy
         with self.assertRaises(ValueError) as context:
             ff = Foo(shape=[nf, 1])
-            ff.forward()
+            ff()
         self.assertIn("not matching", context.exception)
         with self.assertRaises(ValueError) as context:
             ff = Foo(shape=[nf, nloc, 2])
-            ff.forward()
+            ff()
         self.assertIn("not matching", context.exception)
