Skip to content

Commit

Permalink
Improve the model inherit
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 20, 2024
1 parent cda2e60 commit b740b2d
Show file tree
Hide file tree
Showing 8 changed files with 110 additions and 228 deletions.
6 changes: 0 additions & 6 deletions deepmd/pt/model/atomic_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,11 @@
from .pairtab_atomic_model import (
PairTabAtomicModel,
)
from .wrapper_atomic_model import (
DPSpinWrapperAtomicModel,
WrapperAtomicModel,
)

__all__ = [
"BaseAtomicModel",
"DPAtomicModel",
"PairTabAtomicModel",
"LinearAtomicModel",
"DPZBLLinearAtomicModel",
"WrapperAtomicModel",
"DPSpinWrapperAtomicModel",
]
180 changes: 0 additions & 180 deletions deepmd/pt/model/atomic_model/wrapper_atomic_model.py

This file was deleted.

11 changes: 6 additions & 5 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
DPModel,
)
from .dp_spin_model import (
SpinEnergyModel,
SpinModel,
)
from .dp_zbl_model import (
Expand Down Expand Up @@ -136,18 +137,17 @@ def get_spin_model(model_params):
fitting_net = model_params.get("fitting_net", None)
fitting_net["type"] = fitting_net.get("type", "ener")
fitting_net["ntypes"] = descriptor.get_ntypes()
fitting_net["distinguish_types"] = descriptor.distinguish_types()
fitting_net["mixed_types"] = descriptor.mixed_types()
fitting_net["embedding_width"] = descriptor.get_dim_out()
fitting_net["dim_descrpt"] = descriptor.get_dim_out()
grad_force = "direct" not in fitting_net["type"]
if not grad_force:
fitting_net["out_dim"] = descriptor.get_dim_emb()
if "ener" in fitting_net["type"]:
fitting_net["return_energy"] = True
fitting = Fitting(**fitting_net)
backbone_model = DPAtomicModel(
descriptor, fitting, type_map=model_params["type_map"]
)
return SpinModel(backbone_model=backbone_model, spin=spin)
backbone_model = DPModel(descriptor, fitting, type_map=model_params["type_map"])
return SpinEnergyModel(backbone_model=backbone_model, spin=spin)


def get_model(model_params):
Expand All @@ -166,6 +166,7 @@ def get_model(model_params):
"DPModel",
"EnergyModel",
"SpinModel",
"SpinEnergyModel",
"DPZBLModel",
"make_model",
"make_hessian_model",
Expand Down
Loading

0 comments on commit b740b2d

Please sign in to comment.