Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor DeepEval #3213

Merged
merged 25 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e713910
refactor DeepEval
njzjz Feb 1, 2024
b428fc2
fix cycle import
njzjz Feb 1, 2024
eb0b7d2
fix errors
njzjz Feb 1, 2024
02b3a6d
fix PT tests
njzjz Feb 1, 2024
4fa8754
fix FrozenModel
njzjz Feb 1, 2024
6026ab7
add docs
njzjz Feb 1, 2024
28c5c92
Merge branch 'devel' into deepeval-refactor
njzjz Feb 2, 2024
c262ef6
rename .model_format to .dpmodel
njzjz Feb 2, 2024
4fe9511
rename DeepEvalBase to DeepEvalBackend as the meaning of base is unclear
njzjz Feb 2, 2024
e53b476
Merge remote-tracking branch 'origin/devel' into deepeval-refactor
njzjz Feb 2, 2024
5b16abf
ckean OutputVariableDef
njzjz Feb 2, 2024
51b5e7b
make atom_types to be int32
njzjz Feb 3, 2024
3516792
add docs
njzjz Feb 3, 2024
3bc7cb5
pass list[OutputVariableDef] instead of bool
njzjz Feb 4, 2024
acfd367
fix py38 compatibility
njzjz Feb 4, 2024
db6ac0d
Merge branch 'devel' into deepeval-refactor
njzjz Feb 5, 2024
be4e42f
use consistent output name for different backends
njzjz Feb 5, 2024
199147c
remove the commented out code
njzjz Feb 5, 2024
355c671
Merge branch 'devel' into deepeval-refactor
njzjz Feb 6, 2024
8686fa1
use OutputVariableCategory to check odef
njzjz Feb 6, 2024
c0a4251
fix typo
njzjz Feb 6, 2024
d7bf113
improve docs and type hints
njzjz Feb 8, 2024
acc85fb
Merge branch 'devel' into deepeval-refactor
njzjz Feb 8, 2024
295f4d7
change the OutputVariableDef argument from differentiable to r_differ…
njzjz Feb 8, 2024
6b104e3
fix the virial shape since DERV_C changed from (3,3) to (9,)
njzjz Feb 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions deepmd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,31 @@
__version__,
)


def DeepPotential(*args, **kwargs):
"""Factory function that forwards to DeepEval (for compatbility
and performance).

Parameters
----------
*args
positional arguments
**kwargs
keyword arguments

Returns
-------
DeepEval
potentials
"""
from deepmd.infer import (

Check warning on line 34 in deepmd/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/__init__.py#L34

Added line #L34 was not covered by tests
DeepPotential,
)

return DeepPotential(*args, **kwargs)

Check warning on line 38 in deepmd/__init__.py

View check run for this annotation

Codecov / codecov/patch

deepmd/__init__.py#L38

Added line #L38 was not covered by tests


__all__ = [
"__version__",
"DeepPotential",
]
1 change: 1 addition & 0 deletions deepmd/infer/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def detect_backend(filename: str) -> DPBackend:
filename : str
The model file name
"""
filename = str(filename).lower()
if filename.endswith(".pb"):
return DPBackend.TensorFlow
elif filename.endswith(".pth") or filename.endswith(".pt"):
Expand Down
28 changes: 28 additions & 0 deletions deepmd/infer/deep_dipole.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.infer.deep_tensor import (
DeepTensor,
)
Fixed Show fixed Hide fixed
Comment on lines +2 to +4

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.infer.deep_tensor
begins an import cycle.


class DeepDipole(DeepTensor):
"""Deep dipole model.

Parameters
----------
model_file : Path
The name of the frozen model file.
*args : list
Positional arguments.
auto_batch_size : bool or int or AutoBatchSize, default: True
If True, automatic batch size will be used. If int, it will be used
as the initial batch size.
neighbor_list : ase.neighborlist.NewPrimitiveNeighborList, optional
The ASE neighbor list class to produce the neighbor list. If None, the
neighbor list will be built natively in the model.
**kwargs : dict
Keyword arguments.
"""

@property
def output_tensor_name(self) -> str:
return "dipole"
142 changes: 142 additions & 0 deletions deepmd/infer/deep_dos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
Dict,
Optional,
Tuple,
)

import numpy as np

from deepmd.dpmodel.output_def import (
FittingOutputDef,
ModelOutputDef,
OutputVariableDef,
)

from .deep_eval import (
DeepEval,
)
Fixed Show fixed Hide fixed
Comment on lines +19 to +21

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.infer.deep_eval
begins an import cycle.


class DeepDOS(DeepEval):
"""Deep density of states model.

Parameters
----------
model_file : Path
The name of the frozen model file.
*args : list
Positional arguments.
auto_batch_size : bool or int or AutoBatchSize, default: True
If True, automatic batch size will be used. If int, it will be used
as the initial batch size.
neighbor_list : ase.neighborlist.NewPrimitiveNeighborList, optional
The ASE neighbor list class to produce the neighbor list. If None, the
neighbor list will be built natively in the model.
**kwargs : dict
Keyword arguments.
"""

@property
def output_def(self) -> ModelOutputDef:
"""Get the output definition of this model."""
return ModelOutputDef(
FittingOutputDef(
[
OutputVariableDef(
"dos",
shape=[-1],
reduciable=True,
atomic=True,
),
]
)
)

def eval(
self,
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
atomic: bool = False,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
mixed_type: bool = False,
**kwargs: Dict[str, Any],
) -> Tuple[np.ndarray, ...]:
Fixed Show fixed Hide fixed
"""Evaluate energy, force, and virial. If atomic is True,
also return atomic energy and atomic virial.

Parameters
----------
coords : np.ndarray
The coordinates of the atoms, in shape (nframes, natoms, 3).
cells : np.ndarray
The cell vectors of the system, in shape (nframes, 9). If the system
is not periodic, set it to None.
atom_types : List[int]
The types of the atoms. If mixed_type is False, the shape is (natoms,);
otherwise, the shape is (nframes, natoms).
atomic : bool, optional
Whether to return atomic energy and atomic virial, by default False.
fparam : np.ndarray, optional
The frame parameters, by default None.
aparam : np.ndarray, optional
The atomic parameters, by default None.
mixed_type : bool, optional
Whether the atom_types is mixed type, by default False.
**kwargs : Dict[str, Any]
Keyword arguments.

Returns
-------
energy
The energy of the system, in shape (nframes,).
force
The force of the system, in shape (nframes, natoms, 3).
virial
The virial of the system, in shape (nframes, 9).
atomic_energy
The atomic energy of the system, in shape (nframes, natoms). Only returned
when atomic is True.
atomic_virial
The atomic virial of the system, in shape (nframes, natoms, 9). Only returned
when atomic is True.
"""
(
coords,
cells,
atom_types,
fparam,
aparam,
nframes,
natoms,
) = self._standard_input(coords, cells, atom_types, fparam, aparam, mixed_type)
results = self.deep_eval.eval(
coords,
cells,
atom_types,
atomic,
fparam=fparam,
aparam=aparam,
**kwargs,
)
# energy = results["dos_redu"].reshape(nframes, self.get_numb_dos())
atomic_energy = results["dos"].reshape(nframes, natoms, self.get_numb_dos())
# not same as dos_redu... why?
energy = np.sum(atomic_energy, axis=1)

if atomic:
return (
energy,
atomic_energy,
)
else:
return (energy,)

Check warning on line 136 in deepmd/infer/deep_dos.py

View check run for this annotation

Codecov / codecov/patch

deepmd/infer/deep_dos.py#L136

Added line #L136 was not covered by tests

def get_numb_dos(self) -> int:
return self.deep_eval.get_numb_dos()


__all__ = ["DeepDOS"]
Loading
Loading