From 9071e73fbd5470dd58dd1e13f54bc4c00fadf3c6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Nov 2024 06:55:51 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/dpmodel/model/model.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index 43eeb13898..1d18b70e8e 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -1,4 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy + from deepmd.dpmodel.atomic_model.dp_atomic_model import ( DPAtomicModel, ) @@ -17,23 +19,30 @@ from deepmd.dpmodel.model.base_model import ( BaseModel, ) +from deepmd.dpmodel.model.dipole_model import ( + DipoleModel, +) +from deepmd.dpmodel.model.dos_model import ( + DOSModel, +) from deepmd.dpmodel.model.dp_zbl_model import ( DPZBLModel, ) from deepmd.dpmodel.model.ener_model import ( EnergyModel, ) +from deepmd.dpmodel.model.polar_model import ( + PolarModel, +) +from deepmd.dpmodel.model.property_model import ( + PropertyModel, +) from deepmd.dpmodel.model.spin_model import ( SpinModel, ) from deepmd.utils.spin import ( Spin, ) -import copy -from deepmd.dpmodel.model.dos_model import DOSModel -from deepmd.dpmodel.model.property_model import PropertyModel -from deepmd.dpmodel.model.dipole_model import DipoleModel -from deepmd.dpmodel.model.polar_model import PolarModel def _get_standard_model_components(data, ntypes): @@ -57,7 +66,8 @@ def _get_standard_model_components(data, ntypes): fitting_net["return_energy"] = True fitting = BaseFitting(**fitting_net) return descriptor, fitting, fitting_net["type"] - + + def get_standard_model(data: dict) -> EnergyModel: """Get a EnergyModel from a dictionary. @@ -72,12 +82,9 @@ def get_standard_model(data: dict) -> EnergyModel: ) data = copy.deepcopy(data) ntypes = len(data["type_map"]) - descriptor, fitting, fitting_net_type = _get_standard_model_components( - data, ntypes - ) + descriptor, fitting, fitting_net_type = _get_standard_model_components(data, ntypes) atom_exclude_types = data.get("atom_exclude_types", []) pair_exclude_types = data.get("pair_exclude_types", []) - if fitting_net_type == "dipole": modelcls = DipoleModel