Skip to content

Commit

Permalink
initial commit with se_e2_a supported only
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 11, 2023
1 parent 43fc073 commit 27045f7
Show file tree
Hide file tree
Showing 10 changed files with 575 additions and 24 deletions.
28 changes: 28 additions & 0 deletions deepmd/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,3 +509,31 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict):
# call subprocess
cls = cls.get_class_by_input(local_jdata)
return cls.update_sel(global_jdata, local_jdata)

@classmethod
def deserialize(cls, data: dict):
"""Deserialize the model.
Parameters
----------
data : dict
The serialized data
Returns
-------
Model
The deserialized model
"""
if cls is Descriptor:
return Descriptor.get_class_by_input(data).deserialize(data)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

Check warning on line 529 in deepmd/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/descriptor/descriptor.py#L527-L529

Added lines #L527 - L529 were not covered by tests

def serialize(self) -> dict:
"""Serialize the model.
Returns
-------
dict
The serialized data
"""
raise NotImplementedError("Not implemented in class %s" % self.__name__)

Check warning on line 539 in deepmd/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/descriptor/descriptor.py#L539

Added line #L539 was not covered by tests
55 changes: 55 additions & 0 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def __init__(
self.seed_shift = embedding_net_rand_seed_shift(self.filter_neuron)
self.trainable = trainable
self.compress_activation_fn = get_activation_func(activation_function)
self.activation_function_name = activation_function
self.filter_activation_fn = get_activation_func(activation_function)
self.filter_precision = get_precision(precision)
self.filter_np_precision = get_np_precision(precision)
Expand Down Expand Up @@ -1334,3 +1335,57 @@ def explicit_ntypes(self) -> bool:
if self.stripped_type_embedding:
return True
return False

@classmethod
def deserialize(cls, data: dict):
"""Deserialize the model.
Parameters
----------
data : dict
The serialized data
Returns
-------
Model
The deserialized model
"""
descriptor = cls(**data)
descriptor.davg = data["@variables"]["davg"]
descriptor.dstd = data["@variables"]["dstd"]
descriptor.embedding_net_variables = data["@variables"]
descriptor.original_sel = data["@variables"]["original_sel"]
return descriptor

Check warning on line 1358 in deepmd/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/descriptor/se_a.py#L1353-L1358

Added lines #L1353 - L1358 were not covered by tests

def serialize(self) -> dict:
"""Serialize the model.
Returns
-------
dict
The serialized data
"""
return {

Check warning on line 1368 in deepmd/descriptor/se_a.py

View check run for this annotation

Codecov / codecov/patch

deepmd/descriptor/se_a.py#L1368

Added line #L1368 was not covered by tests
"type": "se_e2_a",
"rcut": self.rcut_r,
"rcut_smth": self.rcut_r_smth,
"sel": self.sel_a,
"neuron": self.filter_neuron,
"axis_neuron": self.n_axis_neuron,
"resnet_dt": self.filter_resnet_dt,
"trainable": self.trainable,
"seed": self.seed,
"type_one_side": self.type_one_side,
"exclude_types": list(self.exclude_types),
"set_davg_zero": self.set_davg_zero,
"activation_function": self.activation_function_name,
"precision": self.filter_precision.name,
"uniform_seed": self.uniform_seed,
"stripped_type_embedding": self.stripped_type_embedding,
"@variables": {
**self.embedding_net_variables,
"davg": self.davg,
"dstd": self.dstd,
"original_sel": self.original_sel,
},
}
11 changes: 11 additions & 0 deletions deepmd/entrypoints/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
convert_pbtxt_to_pb,
convert_to_21,
)
from deepmd.utils.convert_dp import (
convert_dp_to_pb,
convert_pb_to_dp,
)


def convert(
Expand All @@ -23,6 +27,11 @@ def convert(
convert_pb_to_pbtxt(input_model, output_model)
else:
raise RuntimeError("input model is already pbtxt")
elif output_model.endswith(".dp"):
if input_model.endswith(".pb"):
convert_pb_to_dp(input_model, output_model)

Check warning on line 32 in deepmd/entrypoints/convert.py

View check run for this annotation

Codecov / codecov/patch

deepmd/entrypoints/convert.py#L31-L32

Added lines #L31 - L32 were not covered by tests
else:
raise RuntimeError("Unsupported format")

Check warning on line 34 in deepmd/entrypoints/convert.py

View check run for this annotation

Codecov / codecov/patch

deepmd/entrypoints/convert.py#L34

Added line #L34 was not covered by tests
else:
if FROM == "auto":
convert_to_21(input_model, output_model)
Expand All @@ -39,5 +48,7 @@ def convert(
convert_20_to_21(input_model, output_model)
elif FROM == "pbtxt":
convert_pbtxt_to_pb(input_model, output_model)
elif FROM == "dp":
convert_dp_to_pb(input_model, output_model)

Check warning on line 52 in deepmd/entrypoints/convert.py

View check run for this annotation

Codecov / codecov/patch

deepmd/entrypoints/convert.py#L51-L52

Added lines #L51 - L52 were not covered by tests
else:
raise RuntimeError("unsupported model version " + FROM)
72 changes: 71 additions & 1 deletion deepmd/fit/ener.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
TYPE_CHECKING,
List,
Optional,
)
Expand Down Expand Up @@ -53,6 +54,11 @@
Spin,
)

if TYPE_CHECKING:
from deepmd.descriptor import (

Check warning on line 58 in deepmd/fit/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/fit/ener.py#L58

Added line #L58 was not covered by tests
Descriptor,
)

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -130,7 +136,7 @@ class EnerFitting(Fitting):

def __init__(
self,
descrpt: tf.Tensor,
descrpt: "Descriptor",
neuron: List[int] = [120, 120, 120],
resnet_dt: bool = True,
numb_fparam: int = 0,
Expand Down Expand Up @@ -176,6 +182,7 @@ def __init__(
self.ntypes_spin = self.spin.get_ntypes_spin() if self.spin is not None else 0
self.seed_shift = one_layer_rand_seed_shift()
self.tot_ener_zero = tot_ener_zero
self.activation_function_name = activation_function
self.fitting_activation_fn = get_activation_func(activation_function)
self.fitting_precision = get_precision(precision)
self.trainable = trainable
Expand Down Expand Up @@ -916,3 +923,66 @@ def get_loss(self, loss: dict, lr) -> Loss:
return EnerSpinLoss(**loss, use_spin=self.spin.use_spin)
else:
raise RuntimeError("unknown loss type")

@classmethod
def deserialize(cls, data: dict):
"""Deserialize the model.
Parameters
----------
data : dict
The serialized data
Returns
-------
Model
The deserialized model
"""
fitting = cls(**data)
fitting.fitting_net_variables = data["@variables"]
fitting.bias_atom_e = fitting.fitting_net_variables.pop("bias_atom_e")
if fitting.numb_fparam > 0:
fitting.fparam_avg = fitting.fitting_net_variables.pop("fparam_avg")
fitting.fparam_inv_std = fitting.fitting_net_variables.pop("fparam_inv_std")
if fitting.numb_aparam > 0:
fitting.aparam_avg = fitting.fitting_net_variables.pop("aparam_avg")
fitting.aparam_inv_std = fitting.fitting_net_variables.pop("aparam_inv_std")
return fitting

Check warning on line 950 in deepmd/fit/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/fit/ener.py#L941-L950

Added lines #L941 - L950 were not covered by tests

def serialize(self) -> dict:
"""Serialize the model.
Returns
-------
dict
The serialized data
"""
data = {

Check warning on line 960 in deepmd/fit/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/fit/ener.py#L960

Added line #L960 was not covered by tests
"type": "ener",
"neuron": self.n_neuron,
"resnet_dt": self.resnet_dt,
"numb_fparam": self.numb_fparam,
"numb_aparam": self.numb_aparam,
"rcond": self.rcond,
"tot_ener_zero": self.tot_ener_zero,
"trainable": self.trainable,
"seed": self.seed,
"atom_ener": self.atom_ener,
"activation_function": self.activation_function_name,
"precision": self.fitting_precision.name,
"uniform_seed": self.uniform_seed,
"layer_name": self.layer_name,
"use_aparam_as_mask": self.use_aparam_as_mask,
"@variables": {
**self.fitting_net_variables,
"bias_atom_e": self.bias_atom_e,
},
}

if self.numb_fparam > 0:
data["@variables"]["fparam_avg"] = self.fparam_avg
data["@variables"]["fparam_inv_std"] = self.fparam_inv_std
if self.numb_aparam > 0:
data["@variables"]["aparam_avg"] = self.aparam_avg
data["@variables"]["aparam_inv_std"] = self.aparam_inv_std
return data

Check warning on line 988 in deepmd/fit/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/fit/ener.py#L982-L988

Added lines #L982 - L988 were not covered by tests
48 changes: 40 additions & 8 deletions deepmd/fit/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,20 @@ class SomeFitting(Fitting):
"""
return Fitting.__plugins.register(key)

@classmethod
def get_class_by_input(cls, input: dict):
try:
fitting_type = input["type"]
except KeyError:
raise KeyError("the type of fitting should be set by `type`")

Check warning on line 51 in deepmd/fit/fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/fit/fitting.py#L50-L51

Added lines #L50 - L51 were not covered by tests
if fitting_type in Fitting.__plugins.plugins:
return Fitting.__plugins.plugins[fitting_type]
else:
raise RuntimeError("Unknown descriptor type: " + fitting_type)

Check warning on line 55 in deepmd/fit/fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/fit/fitting.py#L55

Added line #L55 was not covered by tests

def __new__(cls, *args, **kwargs):
if cls is Fitting:
try:
fitting_type = kwargs["type"]
except KeyError:
raise KeyError("the type of fitting should be set by `type`")
if fitting_type in Fitting.__plugins.plugins:
cls = Fitting.__plugins.plugins[fitting_type]
else:
raise RuntimeError("Unknown descriptor type: " + fitting_type)
cls = Fitting.get_class_by_input(kwargs)
return super().__new__(cls)

@property
Expand Down Expand Up @@ -102,3 +106,31 @@ def get_loss(self, loss: dict, lr) -> Loss:
Loss
the loss function
"""

@classmethod
def deserialize(cls, data: dict):
"""Deserialize the model.
Parameters
----------
data : dict
The serialized data
Returns
-------
Model
The deserialized model
"""
if cls is Fitting:
return Fitting.get_class_by_input(data).deserialize(data)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

Check warning on line 126 in deepmd/fit/fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/fit/fitting.py#L124-L126

Added lines #L124 - L126 were not covered by tests

def serialize(self) -> dict:
"""Serialize the model.
Returns
-------
dict
The serialized data
"""
raise NotImplementedError("Not implemented in class %s" % self.__name__)

Check warning on line 136 in deepmd/fit/fitting.py

View check run for this annotation

Codecov / codecov/patch

deepmd/fit/fitting.py#L136

Added line #L136 was not covered by tests
62 changes: 62 additions & 0 deletions deepmd/model/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,18 @@

import numpy as np

from deepmd.descriptor.descriptor import (
Descriptor,
)
from deepmd.env import (
MODEL_VERSION,
global_cvt_2_ener_float,
op_module,
tf,
)
from deepmd.fit.fitting import (
Fitting,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
Expand Down Expand Up @@ -512,3 +518,59 @@ def change_energy_bias(
bias_shift,
self.data_bias_nsample,
)

@classmethod
def deserialize(cls, data: dict):
"""Deserialize the model.
Parameters
----------
data : dict
The serialized data
Returns
-------
Model
The deserialized model
"""
data = data.copy()
data["descriptor"] = Descriptor.deserialize(data["descriptor"])
fitting_data = data["fitting_net"].copy()
fitting_data["descrpt"] = data["descriptor"]
data["fitting_net"] = Fitting.deserialize(fitting_data)
if data.get("type_embedding") is not None:
data["type_embedding"] = TypeEmbedNet.deserialize(data["type_embedding"])
return cls(**data)

Check warning on line 543 in deepmd/model/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/model/ener.py#L536-L543

Added lines #L536 - L543 were not covered by tests

def serialize(self) -> dict:
"""Serialize the model.
Returns
-------
dict
The serialized data
"""
data = {

Check warning on line 553 in deepmd/model/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/model/ener.py#L553

Added line #L553 was not covered by tests
"type": "standard",
"descriptor": self.descrpt.serialize(),
"fitting_net": self.fitting.serialize(),
"type_embedding": self.typeebd.serialize()
if self.typeebd is not None
else None,
"type_map": self.get_type_map(),
"data_stat_nbatch": self.data_stat_nbatch,
"data_stat_protect": self.data_stat_protect,
"use_srtab": self.srtab_name,
"spin": self.spin.serialize() if self.spin is not None else None,
"data_bias_nsample": self.data_bias_nsample,
}
if self.srtab_name is not None:
data.merge(

Check warning on line 568 in deepmd/model/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/model/ener.py#L567-L568

Added lines #L567 - L568 were not covered by tests
{
"srtab_add_bias": self.srtab_add_bias,
"smin_alpha": self.smin_alpha,
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
}
)
return data

Check warning on line 576 in deepmd/model/ener.py

View check run for this annotation

Codecov / codecov/patch

deepmd/model/ener.py#L576

Added line #L576 was not covered by tests
Loading

0 comments on commit 27045f7

Please sign in to comment.