Skip to content

Commit

Permalink
Revert "Update deepmd_utils/model_format/output_def.py"
Browse files Browse the repository at this point in the history
This reverts commit 72f7fc6.
  • Loading branch information
Han Wang committed Jan 12, 2024
1 parent 48293f4 commit 3f3393d
Showing 1 changed file with 40 additions and 6 deletions.
46 changes: 40 additions & 6 deletions deepmd_utils/model_format/output_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,15 @@ def check_var(var, var_def):
raise ValueError(f"{var.shape[1:]} not matching def {var_def.shape}")


def check_output(cls, output_type):
def model_check_output(cls):
"""Check if the output of the Model is consistent with the definition.
Two methods are assumed to be provided by the Model:
1. Model.output_def that gives the output definition.
2. Model.forward that defines the forward path of the model.
"""

class wrapper(cls):
def __init__(
self,
Expand All @@ -38,8 +46,7 @@ def forward(
**kwargs,
):
ret = cls.forward(self, *args, **kwargs)
keys = self.md.keys_outp() if output_type == 'model' else self.md.keys()
for kk in keys:
for kk in self.md.keys_outp():
dd = self.md[kk]
check_var(ret[kk], dd)
if dd.reduciable:
Expand All @@ -50,13 +57,40 @@ def forward(
check_var(ret[dnr], self.md[dnr])
check_var(ret[dnc], self.md[dnc])
return ret

return wrapper

def model_check_output(cls):
return check_output(cls, 'model')

def fitting_check_output(cls):
return check_output(cls, 'fitting')
"""Check if the output of the Fitting is consistent with the definition.
Two methods are assumed to be provided by the Fitting:
1. Fitting.output_def that gives the output definition.
2. Fitting.forward defines the forward path of the fitting.
"""

class wrapper(cls):
def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.md = cls.output_def(self)

def forward(
self,
*args,
**kwargs,
):
ret = cls.forward(self, *args, **kwargs)
for kk in self.md.keys():
dd = self.md[kk]
check_var(ret[kk], dd)
return ret

return wrapper


class VariableDef:
Expand Down

0 comments on commit 3f3393d

Please sign in to comment.