diff --git a/deepmd_utils/model_format/__init__.py b/deepmd_utils/model_format/__init__.py index 356eaaf4fa..0d6972e9cf 100644 --- a/deepmd_utils/model_format/__init__.py +++ b/deepmd_utils/model_format/__init__.py @@ -2,6 +2,7 @@ from .common import ( DEFAULT_PRECISION, PRECISION_DICT, + NativeOP, ) from .env_mat import ( EnvMat, @@ -34,6 +35,7 @@ "NativeLayer", "NativeNet", "NetworkCollection", + "NativeOP", "load_dp_model", "save_dp_model", "traverse_model_dict", diff --git a/deepmd_utils/model_format/common.py b/deepmd_utils/model_format/common.py index 82beb969c2..d032e5d5df 100644 --- a/deepmd_utils/model_format/common.py +++ b/deepmd_utils/model_format/common.py @@ -22,3 +22,7 @@ class NativeOP(ABC): def call(self, *args, **kwargs): """Forward pass in NumPy implementation.""" raise NotImplementedError + + def __call__(self, *args, **kwargs): + """Forward pass in NumPy implementation.""" + return self.call(*args, **kwargs) diff --git a/deepmd_utils/model_format/output_def.py b/deepmd_utils/model_format/output_def.py index f4fcdce3ca..7feb24a145 100644 --- a/deepmd_utils/model_format/output_def.py +++ b/deepmd_utils/model_format/output_def.py @@ -27,7 +27,7 @@ def model_check_output(cls): Two methods are assumed to be provided by the Model: 1. Model.output_def that gives the output definition. - 2. Model.forward that defines the forward path of the model. + 2. Model.__call__ that defines the forward path of the model. """ @@ -40,12 +40,12 @@ def __init__( super().__init__(*args, **kwargs) self.md = cls.output_def(self) - def forward( + def __call__( self, *args, **kwargs, ): - ret = cls.forward(self, *args, **kwargs) + ret = cls.__call__(self, *args, **kwargs) for kk in self.md.keys_outp(): dd = self.md[kk] check_var(ret[kk], dd) @@ -66,7 +66,7 @@ def fitting_check_output(cls): Two methods are assumed to be provided by the Fitting: 1. Fitting.output_def that gives the output definition. - 2. Fitting.forward defines the forward path of the fitting. + 2. Fitting.__call__ defines the forward path of the fitting. """ @@ -79,12 +79,12 @@ def __init__( super().__init__(*args, **kwargs) self.md = cls.output_def(self) - def forward( + def __call__( self, *args, **kwargs, ): - ret = cls.forward(self, *args, **kwargs) + ret = cls.__call__(self, *args, **kwargs) for kk in self.md.keys(): dd = self.md[kk] check_var(ret[kk], dd) diff --git a/source/tests/test_output_def.py b/source/tests/test_output_def.py index 7f5404ee31..e0c56784da 100644 --- a/source/tests/test_output_def.py +++ b/source/tests/test_output_def.py @@ -6,6 +6,7 @@ from deepmd_utils.model_format import ( FittingOutputDef, ModelOutputDef, + NativeOP, OutputVariableDef, fitting_check_output, model_check_output, @@ -91,14 +92,14 @@ def test_model_decorator(self): nloc = 3 @model_check_output - class Foo: + class Foo(NativeOP): def output_def(self): defs = [ OutputVariableDef("energy", [1], True, True), ] return ModelOutputDef(FittingOutputDef(defs)) - def forward(self): + def call(self): return { "energy": np.zeros([nf, nloc, 1]), "energy_redu": np.zeros([nf, 1]), @@ -107,21 +108,24 @@ def forward(self): } ff = Foo() - ff.forward() + ff() def test_model_decorator_keyerror(self): nf = 2 nloc = 3 @model_check_output - class Foo: + class Foo(NativeOP): + def __init__(self): + super().__init__() + def output_def(self): defs = [ OutputVariableDef("energy", [1], True, True), ] return ModelOutputDef(FittingOutputDef(defs)) - def forward(self): + def call(self): return { "energy": np.zeros([nf, nloc, 1]), "energy_redu": np.zeros([nf, 1]), @@ -130,7 +134,7 @@ def forward(self): ff = Foo() with self.assertRaises(KeyError) as context: - ff.forward() + ff() self.assertIn("energy_derv_r", context.exception) def test_model_decorator_shapeerror(self): @@ -138,7 +142,7 @@ def test_model_decorator_shapeerror(self): nloc = 3 @model_check_output - class Foo: + class Foo(NativeOP): def __init__( self, shape_rd=[nf, 1], @@ -152,7 +156,7 @@ def output_def(self): ] return ModelOutputDef(FittingOutputDef(defs)) - def forward(self): + def call(self): return { "energy": np.zeros([nf, nloc, 1]), "energy_redu": np.zeros(self.shape_rd), @@ -161,28 +165,28 @@ def forward(self): } ff = Foo() - ff.forward() + ff() # shape of reduced energy with self.assertRaises(ValueError) as context: ff = Foo(shape_rd=[nf, nloc, 1]) - ff.forward() + ff() self.assertIn("not matching", context.exception) with self.assertRaises(ValueError) as context: ff = Foo(shape_rd=[nf, 2]) - ff.forward() + ff() self.assertIn("not matching", context.exception) # shape of dr with self.assertRaises(ValueError) as context: ff = Foo(shape_dr=[nf, nloc, 1]) - ff.forward() + ff() self.assertIn("not matching", context.exception) with self.assertRaises(ValueError) as context: ff = Foo(shape_dr=[nf, nloc, 1, 3, 3]) - ff.forward() + ff() self.assertIn("not matching", context.exception) with self.assertRaises(ValueError) as context: ff = Foo(shape_dr=[nf, nloc, 1, 4]) - ff.forward() + ff() self.assertIn("not matching", context.exception) def test_fitting_decorator(self): @@ -190,27 +194,27 @@ def test_fitting_decorator(self): nloc = 3 @fitting_check_output - class Foo: + class Foo(NativeOP): def output_def(self): defs = [ OutputVariableDef("energy", [1], True, True), ] return FittingOutputDef(defs) - def forward(self): + def call(self): return { "energy": np.zeros([nf, nloc, 1]), } ff = Foo() - ff.forward() + ff() def test_fitting_decorator_shapeerror(self): nf = 2 nloc = 3 @fitting_check_output - class Foo: + class Foo(NativeOP): def __init__( self, shape=[nf, nloc, 1], @@ -223,19 +227,19 @@ def output_def(self): ] return FittingOutputDef(defs) - def forward(self): + def call(self): return { "energy": np.zeros(self.shape), } ff = Foo() - ff.forward() + ff() # shape of reduced energy with self.assertRaises(ValueError) as context: ff = Foo(shape=[nf, 1]) - ff.forward() + ff() self.assertIn("not matching", context.exception) with self.assertRaises(ValueError) as context: ff = Foo(shape=[nf, nloc, 2]) - ff.forward() + ff() self.assertIn("not matching", context.exception)