Skip to content

Commit

Permalink
output_def should not be class method
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Jan 10, 2024
1 parent 60cf941 commit 9b23a57
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 11 deletions.
20 changes: 18 additions & 2 deletions deepmd_utils/model_format/output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,22 @@ def check_var(var, var_def):


def model_check_output(cls):
"""Check if the output of the Model is consistent with the definition.
Two methods are assumed to be provided by the Model:
1. Model.output_def that gives the output definition.
2. Model.forward that defines the forward path of the model.
"""

class wrapper(cls):
def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.md = cls.output_def()
self.md = cls.output_def(self)

def forward(
self,
Expand All @@ -54,14 +62,22 @@ def forward(


def fitting_check_output(cls):
"""Check if the output of the Fitting is consistent with the definition.
Two methods are assumed to be provided by the Fitting:
1. Fitting.output_def that gives the output definition.
2. Fitting.forward defines the forward path of the fitting.
"""

class wrapper(cls):
def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.md = cls.output_def()
self.md = cls.output_def(self)

def forward(
self,
Expand Down
13 changes: 4 additions & 9 deletions source/tests/test_output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ def test_model_decorator(self):

@model_check_output
class Foo:
@classmethod
def output_def(cls):
def output_def(self):
defs = [
OutputVariableDef("energy", [1], True, True),
]
Expand All @@ -116,8 +115,7 @@ def test_model_decorator_keyerror(self):

@model_check_output
class Foo:
@classmethod
def output_def(cls):
def output_def(self):
defs = [
OutputVariableDef("energy", [1], True, True),
]
Expand Down Expand Up @@ -148,7 +146,6 @@ def __init__(
):
self.shape_rd, self.shape_dr = shape_rd, shape_dr

@classmethod
def output_def(cls):
defs = [
OutputVariableDef("energy", [1], True, True),
Expand Down Expand Up @@ -194,8 +191,7 @@ def test_fitting_decorator(self):

@fitting_check_output
class Foo:
@classmethod
def output_def(cls):
def output_def(self):
defs = [
OutputVariableDef("energy", [1], True, True),
]
Expand All @@ -221,8 +217,7 @@ def __init__(
):
self.shape = shape

@classmethod
def output_def(cls):
def output_def(self):
defs = [
OutputVariableDef("energy", [1], True, True),
]
Expand Down

0 comments on commit 9b23a57

Please sign in to comment.