diff --git a/deepmd/pt/model/model/dp_model.py b/deepmd/pt/model/model/dp_model.py index 5410f518d1..79c129334a 100644 --- a/deepmd/pt/model/model/dp_model.py +++ b/deepmd/pt/model/model/dp_model.py @@ -10,6 +10,7 @@ ) from deepmd.pt.model.task.ener import ( EnergyFittingNet, + EnergyFittingNetDirect, ) from deepmd.pt.model.task.polarizability import ( PolarFittingNet, @@ -36,7 +37,9 @@ def __new__(cls, descriptor, fitting, *args, **kwargs): # according to the fitting network to decide the type of the model if cls is DPModel: # map fitting to model - if isinstance(fitting, EnergyFittingNet): + if isinstance(fitting, EnergyFittingNet) or isinstance( + fitting, EnergyFittingNetDirect + ): cls = EnergyModel elif isinstance(fitting, DipoleFittingNet): cls = DipoleModel