Skip to content

Commit

Permalink
Merge pull request #2 from xusuyong/add_water_tensor_dipole
Browse files Browse the repository at this point in the history
[Paddle Backend] Add water tensor dipole example
  • Loading branch information
HydrogenSulfate authored Dec 14, 2023
2 parents 17223e7 + 0cb4fde commit 382503b
Show file tree
Hide file tree
Showing 6 changed files with 586 additions and 232 deletions.
216 changes: 138 additions & 78 deletions deepmd/fit/dipole.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,36 @@
import logging
from typing import List
from typing import Optional

import numpy as np
from paddle import nn

from deepmd.common import add_data_requirement
from deepmd.common import cast_precision
from deepmd.common import get_activation_func
from deepmd.common import get_precision
from deepmd.env import GLOBAL_PD_FLOAT_PRECISION
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION
from deepmd.env import global_cvt_2_pd_float
from deepmd.env import global_cvt_2_tf_float
from deepmd.env import paddle
from deepmd.env import tf
from deepmd.fit.fitting import Fitting

# from deepmd.infer import DeepPotential
from deepmd.nvnmd.fit.ener import one_layer_nvnmd
from deepmd.nvnmd.utils.config import nvnmd_cfg
from deepmd.utils.errors import GraphWithoutTensorError
from deepmd.utils.graph import get_fitting_net_variables_from_graph_def
from deepmd.utils.graph import get_tensor_by_name_from_graph
from deepmd.utils.network import OneLayer as OneLayer_deepmd
from deepmd.utils.network import one_layer
from deepmd.utils.network import one_layer as one_layer_deepmd
from deepmd.utils.network import one_layer_rand_seed_shift
from deepmd.utils.spin import Spin


@Fitting.register("dipole")
class DipoleFittingSeA(Fitting):
# @Fitting.register("dipole")
class DipoleFittingSeA(nn.Layer):
r"""Fit the atomic dipole with descriptor se_a.
Parameters
Expand All @@ -40,7 +56,7 @@ class DipoleFittingSeA(Fitting):

def __init__(
self,
descrpt: tf.Tensor,
descrpt: paddle.Tensor,
neuron: List[int] = [120, 120, 120],
resnet_dt: bool = True,
sel_type: Optional[List[int]] = None,
Expand All @@ -49,6 +65,7 @@ def __init__(
precision: str = "default",
uniform_seed: bool = False,
) -> None:
super().__init__(name_scope="DipoleFittingSeA")
"""Constructor."""
self.ntypes = descrpt.get_ntypes()
self.dim_descrpt = descrpt.get_dim_out()
Expand All @@ -62,6 +79,7 @@ def __init__(
)
self.seed = seed
self.uniform_seed = uniform_seed
self.ntypes_spin = 0
self.seed_shift = one_layer_rand_seed_shift()
self.fitting_activation_fn = get_activation_func(activation_function)
self.fitting_precision = get_precision(precision)
Expand All @@ -71,6 +89,55 @@ def __init__(
self.fitting_net_variables = None
self.mixed_prec = None

type_suffix = ""
suffix = ""
self.one_layers = nn.LayerList()
self.final_layers = nn.LayerList()
ntypes_atom = self.ntypes - self.ntypes_spin
for type_i in range(0, ntypes_atom):
type_i_layers = nn.LayerList()
for ii in range(0, len(self.n_neuron)):

layer_suffix = "layer_" + str(ii) + type_suffix + suffix

if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii - 1]:
type_i_layers.append(
OneLayer_deepmd(
self.n_neuron[ii - 1],
self.n_neuron[ii],
activation_fn=self.fitting_activation_fn,
precision=self.fitting_precision,
name=layer_suffix,
seed=self.seed,
use_timestep=self.resnet_dt,
)
)
else:
type_i_layers.append(
OneLayer_deepmd(
self.dim_descrpt,
self.n_neuron[ii],
activation_fn=self.fitting_activation_fn,
precision=self.fitting_precision,
name=layer_suffix,
seed=self.seed,
)
)
if (not self.uniform_seed) and (self.seed is not None):
self.seed += self.seed_shift

self.one_layers.append(type_i_layers)
self.final_layers.append(
OneLayer_deepmd(
self.n_neuron[-1],
self.dim_rot_mat_1,
activation_fn=None,
precision=self.fitting_precision,
name=layer_suffix,
seed=self.seed,
)
)

def get_sel_type(self) -> int:
"""Get selected type."""
return self.sel_type
Expand All @@ -79,79 +146,66 @@ def get_out_size(self) -> int:
"""Get the output size. Should be 3."""
return 3

def _build_lower(self, start_index, natoms, inputs, rot_mat, suffix="", reuse=None):
def _build_lower(
self,
start_index,
natoms,
inputs,
rot_mat,
suffix="",
reuse=None,
type_i=None,
):
# cut-out inputs
inputs_i = tf.slice(inputs, [0, start_index, 0], [-1, natoms, -1])
inputs_i = tf.reshape(inputs_i, [-1, self.dim_descrpt])
rot_mat_i = tf.slice(rot_mat, [0, start_index, 0], [-1, natoms, -1])
rot_mat_i = tf.reshape(rot_mat_i, [-1, self.dim_rot_mat_1, 3])
inputs_i = paddle.slice(
inputs,
[0, 1, 2],
[0, start_index, 0],
[inputs.shape[0], start_index + natoms, inputs.shape[2]],
)
inputs_i = paddle.reshape(inputs_i, [-1, self.dim_descrpt])
rot_mat_i = paddle.slice(
rot_mat,
[0, 1, 2],
[0, start_index, 0],
[rot_mat.shape[0], start_index + natoms, rot_mat.shape[2]],
)
rot_mat_i = paddle.reshape(rot_mat_i, [-1, self.dim_rot_mat_1, 3])
layer = inputs_i
for ii in range(0, len(self.n_neuron)):
if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii - 1]:
layer += one_layer(
layer,
self.n_neuron[ii],
name="layer_" + str(ii) + suffix,
reuse=reuse,
seed=self.seed,
use_timestep=self.resnet_dt,
activation_fn=self.fitting_activation_fn,
precision=self.fitting_precision,
uniform_seed=self.uniform_seed,
initial_variables=self.fitting_net_variables,
mixed_prec=self.mixed_prec,
)
layer += self.one_layers[type_i][ii](layer)
else:
layer = one_layer(
layer,
self.n_neuron[ii],
name="layer_" + str(ii) + suffix,
reuse=reuse,
seed=self.seed,
activation_fn=self.fitting_activation_fn,
precision=self.fitting_precision,
uniform_seed=self.uniform_seed,
initial_variables=self.fitting_net_variables,
mixed_prec=self.mixed_prec,
)
layer = self.one_layers[type_i][ii](layer)

if (not self.uniform_seed) and (self.seed is not None):
self.seed += self.seed_shift
# (nframes x natoms) x naxis
final_layer = one_layer(
final_layer = self.final_layers[type_i](
layer,
self.dim_rot_mat_1,
activation_fn=None,
name="final_layer" + suffix,
reuse=reuse,
seed=self.seed,
precision=self.fitting_precision,
uniform_seed=self.uniform_seed,
initial_variables=self.fitting_net_variables,
mixed_prec=self.mixed_prec,
final_layer=True,
)

if (not self.uniform_seed) and (self.seed is not None):
self.seed += self.seed_shift
# (nframes x natoms) x 1 * naxis
final_layer = tf.reshape(
final_layer, [tf.shape(inputs)[0] * natoms, 1, self.dim_rot_mat_1]
final_layer = paddle.reshape(
final_layer, [paddle.shape(inputs)[0] * natoms, 1, self.dim_rot_mat_1]
)
# (nframes x natoms) x 1 x 3(coord)
final_layer = tf.matmul(final_layer, rot_mat_i)
final_layer = paddle.matmul(final_layer, rot_mat_i)
# nframes x natoms x 3
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms, 3])
final_layer = paddle.reshape(final_layer, [paddle.shape(inputs)[0], natoms, 3])
return final_layer

@cast_precision
def build(
def forward(
self,
input_d: tf.Tensor,
rot_mat: tf.Tensor,
natoms: tf.Tensor,
input_d: paddle.Tensor,
rot_mat: paddle.Tensor,
natoms: paddle.Tensor,
input_dict: Optional[dict] = None,
reuse: Optional[bool] = None,
suffix: str = "",
) -> tf.Tensor:
) -> paddle.Tensor:
"""Build the computational graph for fitting net.
Parameters
Expand Down Expand Up @@ -183,22 +237,25 @@ def build(
atype = input_dict.get("atype", None)
nframes = input_dict.get("nframes")
start_index = 0
inputs = tf.reshape(input_d, [-1, natoms[0], self.dim_descrpt])
rot_mat = tf.reshape(rot_mat, [-1, natoms[0], self.dim_rot_mat])
inputs = paddle.reshape(input_d, [-1, natoms[0], self.dim_descrpt])
rot_mat = paddle.reshape(rot_mat, [-1, natoms[0], self.dim_rot_mat])

if type_embedding is not None:
nloc_mask = tf.reshape(
tf.tile(tf.repeat(self.sel_mask, natoms[2:]), [nframes]), [nframes, -1]
nloc_mask = paddle.reshape(
paddle.tile(paddle.repeat_interleave(self.sel_mask, natoms[2:]), [nframes]),
[nframes, -1],
)
atype_nall = tf.reshape(atype, [-1, natoms[1]])
atype_nall = paddle.reshape(atype, [-1, natoms[1]])
# (nframes x nloc_masked)
self.atype_nloc_masked = tf.reshape(
tf.slice(atype_nall, [0, 0], [-1, natoms[0]])[nloc_mask], [-1]
self.atype_nloc_masked = paddle.reshape(
paddle.slice(atype_nall, [0, 0], [-1, natoms[0]])[nloc_mask], [-1]
) ## lammps will make error
self.nloc_masked = tf.shape(
tf.reshape(self.atype_nloc_masked, [nframes, -1])
self.nloc_masked = paddle.shape(
paddle.reshape(self.atype_nloc_masked, [nframes, -1])
)[1]
atype_embed = tf.nn.embedding_lookup(type_embedding, self.atype_nloc_masked)
atype_embed = nn.embedding_lookup(
type_embedding, self.atype_nloc_masked
)
else:
atype_embed = None

Expand All @@ -218,40 +275,43 @@ def build(
rot_mat,
suffix="_type_" + str(type_i) + suffix,
reuse=reuse,
type_i=type_i,
)
start_index += natoms[2 + type_i]
# concat the results
outs_list.append(final_layer)
count += 1
outs = tf.concat(outs_list, axis=1)
outs = paddle.concat(outs_list, axis=1)
else:
inputs = tf.reshape(
tf.reshape(inputs, [nframes, natoms[0], self.dim_descrpt])[nloc_mask],
inputs = paddle.reshape(
paddle.reshape(inputs, [nframes, natoms[0], self.dim_descrpt])[
nloc_mask
],
[-1, self.dim_descrpt],
)
rot_mat = tf.reshape(
tf.reshape(rot_mat, [nframes, natoms[0], self.dim_rot_mat_1 * 3])[
rot_mat = paddle.reshape(
paddle.reshape(rot_mat, [nframes, natoms[0], self.dim_rot_mat_1 * 3])[
nloc_mask
],
[-1, self.dim_rot_mat_1, 3],
)
atype_embed = tf.cast(atype_embed, self.fitting_precision)
atype_embed = paddle.cast(atype_embed, self.fitting_precision)
type_shape = atype_embed.get_shape().as_list()
inputs = tf.concat([inputs, atype_embed], axis=1)
inputs = paddle.concat([inputs, atype_embed], axis=1)
self.dim_descrpt = self.dim_descrpt + type_shape[1]
inputs = tf.reshape(inputs, [nframes, self.nloc_masked, self.dim_descrpt])
rot_mat = tf.reshape(
inputs = paddle.reshape(
inputs, [nframes, self.nloc_masked, self.dim_descrpt]
)
rot_mat = paddle.reshape(
rot_mat, [nframes, self.nloc_masked, self.dim_rot_mat_1 * 3]
)
final_layer = self._build_lower(
0, self.nloc_masked, inputs, rot_mat, suffix=suffix, reuse=reuse
)
# nframes x natoms x 3
outs = tf.reshape(final_layer, [nframes, self.nloc_masked, 3])
outs = paddle.reshape(final_layer, [nframes, self.nloc_masked, 3])

tf.summary.histogram("fitting_net_output", outs)
return tf.reshape(outs, [-1])
# return tf.reshape(outs, [tf.shape(inputs)[0] * natoms[0] * 3 // 3])
return paddle.reshape(outs, [-1])

def init_variables(
self,
Expand Down
19 changes: 14 additions & 5 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from functools import lru_cache
from typing import TYPE_CHECKING
from typing import List
Expand All @@ -19,7 +20,9 @@
from deepmd.env import default_tf_session_config
from deepmd.env import paddle
from deepmd.env import tf
from deepmd.fit import dipole
from deepmd.fit import ener
from deepmd.model import DipoleModel
from deepmd.model import EnerModel
from deepmd.utils.argcheck import type_embedding_args
from deepmd.utils.batch_size import AutoBatchSize
Expand Down Expand Up @@ -77,7 +80,9 @@ def __init__(
default_tf_graph: bool = False,
auto_batch_size: Union[bool, int, AutoBatchSize] = False,
):
jdata = j_loader("input.json")
jdata = j_loader(
"input.json" if os.path.exists("input.json") else "dipole_input.json"
)
remove_comment_in_json(jdata)
model_param = j_must_have(jdata, "model")
self.multi_task_mode = "fitting_net_dict" in model_param
Expand Down Expand Up @@ -140,7 +145,12 @@ def __init__(
if fitting_type == "ener":
fitting_param["spin"] = spin
fitting_param.pop("type", None)
fitting = ener.EnerFitting(**fitting_param)
fitting = ener.EnerFitting(**fitting_param)
elif fitting_type == "dipole":
fitting_param.pop("type", None)
fitting = dipole.DipoleFittingSeA(**fitting_param)
else:
raise NotImplementedError()
else:
self.fitting_dict = {}
self.fitting_type_dict = {}
Expand Down Expand Up @@ -216,7 +226,6 @@ def __init__(
)

elif self.fitting_type == "dipole":
raise NotImplementedError()
self.model = DipoleModel(
descrpt,
fitting,
Expand Down Expand Up @@ -352,7 +361,7 @@ def __init__(
@property
@lru_cache(maxsize=None)
def model_type(self) -> str:
return "ener"
return self.model.model_type
"""Get type of model.
:type:str
Expand Down Expand Up @@ -411,7 +420,7 @@ def _graph_compatable(self) -> bool:

def _get_value(
self, tensor_name: str, attr_name: Optional[str] = None
) -> tf.Tensor:
) -> paddle.Tensor:
"""Get TF graph tensor and assign it to class namespace.
Parameters
----------
Expand Down
Loading

0 comments on commit 382503b

Please sign in to comment.