Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: model check assumes __call__ as the forward method #3136

Merged
merged 1 commit into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions deepmd_utils/model_format/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .common import (
DEFAULT_PRECISION,
PRECISION_DICT,
NativeOP,
)
from .env_mat import (
EnvMat,
Expand Down Expand Up @@ -34,6 +35,7 @@
"NativeLayer",
"NativeNet",
"NetworkCollection",
"NativeOP",
"load_dp_model",
"save_dp_model",
"traverse_model_dict",
Expand Down
4 changes: 4 additions & 0 deletions deepmd_utils/model_format/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ class NativeOP(ABC):
def call(self, *args, **kwargs):
"""Forward pass in NumPy implementation."""
raise NotImplementedError

def __call__(self, *args, **kwargs):
"""Forward pass in NumPy implementation."""
return self.call(*args, **kwargs)
12 changes: 6 additions & 6 deletions deepmd_utils/model_format/output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def model_check_output(cls):

Two methods are assumed to be provided by the Model:
1. Model.output_def that gives the output definition.
2. Model.forward that defines the forward path of the model.
2. Model.__call__ that defines the forward path of the model.

"""

Expand All @@ -40,12 +40,12 @@ def __init__(
super().__init__(*args, **kwargs)
self.md = cls.output_def(self)

def forward(
def __call__(
self,
*args,
**kwargs,
):
ret = cls.forward(self, *args, **kwargs)
ret = cls.__call__(self, *args, **kwargs)
for kk in self.md.keys_outp():
dd = self.md[kk]
check_var(ret[kk], dd)
Expand All @@ -66,7 +66,7 @@ def fitting_check_output(cls):

Two methods are assumed to be provided by the Fitting:
1. Fitting.output_def that gives the output definition.
2. Fitting.forward defines the forward path of the fitting.
2. Fitting.__call__ defines the forward path of the fitting.

"""

Expand All @@ -79,12 +79,12 @@ def __init__(
super().__init__(*args, **kwargs)
self.md = cls.output_def(self)

def forward(
def __call__(
self,
*args,
**kwargs,
):
ret = cls.forward(self, *args, **kwargs)
ret = cls.__call__(self, *args, **kwargs)
for kk in self.md.keys():
dd = self.md[kk]
check_var(ret[kk], dd)
Expand Down
48 changes: 26 additions & 22 deletions source/tests/test_output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from deepmd_utils.model_format import (
FittingOutputDef,
ModelOutputDef,
NativeOP,
OutputVariableDef,
fitting_check_output,
model_check_output,
Expand Down Expand Up @@ -91,14 +92,14 @@ def test_model_decorator(self):
nloc = 3

@model_check_output
class Foo:
class Foo(NativeOP):
def output_def(self):
defs = [
OutputVariableDef("energy", [1], True, True),
]
return ModelOutputDef(FittingOutputDef(defs))

def forward(self):
def call(self):
return {
"energy": np.zeros([nf, nloc, 1]),
"energy_redu": np.zeros([nf, 1]),
Expand All @@ -107,21 +108,24 @@ def forward(self):
}

ff = Foo()
ff.forward()
ff()

def test_model_decorator_keyerror(self):
nf = 2
nloc = 3

@model_check_output
class Foo:
class Foo(NativeOP):
def __init__(self):
super().__init__()

def output_def(self):
defs = [
OutputVariableDef("energy", [1], True, True),
]
return ModelOutputDef(FittingOutputDef(defs))

def forward(self):
def call(self):
return {
"energy": np.zeros([nf, nloc, 1]),
"energy_redu": np.zeros([nf, 1]),
Expand All @@ -130,15 +134,15 @@ def forward(self):

ff = Foo()
with self.assertRaises(KeyError) as context:
ff.forward()
ff()
self.assertIn("energy_derv_r", context.exception)

def test_model_decorator_shapeerror(self):
nf = 2
nloc = 3

@model_check_output
class Foo:
class Foo(NativeOP):
def __init__(
self,
shape_rd=[nf, 1],
Expand All @@ -152,7 +156,7 @@ def output_def(self):
]
return ModelOutputDef(FittingOutputDef(defs))

def forward(self):
def call(self):
return {
"energy": np.zeros([nf, nloc, 1]),
"energy_redu": np.zeros(self.shape_rd),
Expand All @@ -161,56 +165,56 @@ def forward(self):
}

ff = Foo()
ff.forward()
ff()
# shape of reduced energy
with self.assertRaises(ValueError) as context:
ff = Foo(shape_rd=[nf, nloc, 1])
ff.forward()
ff()
self.assertIn("not matching", context.exception)
with self.assertRaises(ValueError) as context:
ff = Foo(shape_rd=[nf, 2])
ff.forward()
ff()
self.assertIn("not matching", context.exception)
# shape of dr
with self.assertRaises(ValueError) as context:
ff = Foo(shape_dr=[nf, nloc, 1])
ff.forward()
ff()
self.assertIn("not matching", context.exception)
with self.assertRaises(ValueError) as context:
ff = Foo(shape_dr=[nf, nloc, 1, 3, 3])
ff.forward()
ff()
self.assertIn("not matching", context.exception)
with self.assertRaises(ValueError) as context:
ff = Foo(shape_dr=[nf, nloc, 1, 4])
ff.forward()
ff()
self.assertIn("not matching", context.exception)

def test_fitting_decorator(self):
nf = 2
nloc = 3

@fitting_check_output
class Foo:
class Foo(NativeOP):
def output_def(self):
defs = [
OutputVariableDef("energy", [1], True, True),
]
return FittingOutputDef(defs)

def forward(self):
def call(self):
return {
"energy": np.zeros([nf, nloc, 1]),
}

ff = Foo()
ff.forward()
ff()

def test_fitting_decorator_shapeerror(self):
nf = 2
nloc = 3

@fitting_check_output
class Foo:
class Foo(NativeOP):
def __init__(
self,
shape=[nf, nloc, 1],
Expand All @@ -223,19 +227,19 @@ def output_def(self):
]
return FittingOutputDef(defs)

def forward(self):
def call(self):
return {
"energy": np.zeros(self.shape),
}

ff = Foo()
ff.forward()
ff()
# shape of reduced energy
with self.assertRaises(ValueError) as context:
ff = Foo(shape=[nf, 1])
ff.forward()
ff()
self.assertIn("not matching", context.exception)
with self.assertRaises(ValueError) as context:
ff = Foo(shape=[nf, nloc, 2])
ff.forward()
ff()
self.assertIn("not matching", context.exception)
Loading