[Paddle Backend] Add mixed precision training for se_e2_a_mixed_prec(revert code format) (deepmodeling#3096)

Code formatting for deepmodeling#3030 

---------

Signed-off-by: HydrogenSulfate <[email protected]>
Co-authored-by: zhouwei25 <[email protected]>
Co-authored-by: JiabinYang <[email protected]>
Co-authored-by: Han Wang <[email protected]>
Co-authored-by: Zhanlue Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
6 people authored Dec 30, 2023
1 parent bb28e11 commit 132ba0e
Showing 7 changed files with 134 additions and 37 deletions.
36 changes: 20 additions & 16 deletions deepmd/common.py
@@ -22,14 +22,10 @@
import numpy as np
import tensorflow
import yaml
from tensorflow.python.framework import (
tensor_util,
)

from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
GLOBAL_PD_FLOAT_PRECISION,
GLOBAL_TF_FLOAT_PRECISION,
op_module,
paddle,
tf,
@@ -317,7 +313,7 @@ def get_activation_func(
return ACTIVATION_FN_DICT[activation_fn]


def get_precision(precision: "_PRECISION") -> Any:
def get_precision(precision: "_PRECISION") -> paddle.dtype:
"""Convert str to TF DType constant.
Parameters
Expand Down Expand Up @@ -392,8 +388,8 @@ def get_np_precision(precision: "_PRECISION") -> np.dtype:


def safe_cast_tensor(
input: tf.Tensor, from_precision: tf.DType, to_precision: tf.DType
) -> tf.Tensor:
input: paddle.Tensor, from_precision: paddle.dtype, to_precision: paddle.dtype
) -> paddle.Tensor:
"""Convert a Tensor from a precision to another precision.
If input is not a Tensor or without the specific precision, the method will not
@@ -413,8 +409,16 @@ def safe_cast_tensor(
tf.Tensor
casted Tensor
"""
if tensor_util.is_tensor(input) and input.dtype == from_precision:
return tf.cast(input, to_precision)
assert isinstance(
from_precision, paddle.dtype
), f"type of from_precision is {type(from_precision)}"
assert isinstance(
to_precision, paddle.dtype
), f"type of from_precision is {type(to_precision)}"
if paddle.is_tensor(input):
if input.dtype == from_precision and input.dtype != to_precision:
return paddle.cast(input, to_precision)
return input
return input
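For reference, a minimal sketch of how the reworked, Paddle-based safe_cast_tensor behaves (a hypothetical usage, assuming Paddle is installed and deepmd.common is importable; the tensors and dtypes are illustrative):

```python
import paddle

from deepmd.common import safe_cast_tensor  # module changed in this diff

# A tensor whose dtype matches `from_precision` is cast to `to_precision`;
# anything else (wrong dtype, or not a tensor) is returned unchanged.
x = paddle.ones([2, 3], dtype="float32")
y = safe_cast_tensor(x, paddle.float32, paddle.float16)
assert y.dtype == paddle.float16

idx = paddle.to_tensor([0, 1], dtype="int32")
assert safe_cast_tensor(idx, paddle.float32, paddle.float16) is idx
```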


@@ -425,13 +429,13 @@ def cast_precision(func: Callable) -> Callable:
The decorator should be used in a classmethod.
The decorator will do the following thing:
(1) It casts input Tensors from `GLOBAL_TF_FLOAT_PRECISION`
(1) It casts input Tensors from `GLOBAL_PD_FLOAT_PRECISION`
to precision defined by property `precision`.
(2) It casts output Tensors from `precision` to
`GLOBAL_TF_FLOAT_PRECISION`.
`GLOBAL_PD_FLOAT_PRECISION`.
(3) It checks inputs and outputs and only casts when
input or output is a Tensor and its dtype matches
`GLOBAL_TF_FLOAT_PRECISION` and `precision`, respectively.
`GLOBAL_PD_FLOAT_PRECISION` and `precision`, respectively.
If it does not match (e.g. it is an integer), the decorator
will do nothing on it.
@@ -459,22 +463,22 @@ def wrapper(self, *args, **kwargs):
returned_tensor = func(
self,
*[
safe_cast_tensor(vv, GLOBAL_TF_FLOAT_PRECISION, self.precision)
safe_cast_tensor(vv, GLOBAL_PD_FLOAT_PRECISION, self.precision)
for vv in args
],
**{
kk: safe_cast_tensor(vv, GLOBAL_TF_FLOAT_PRECISION, self.precision)
kk: safe_cast_tensor(vv, GLOBAL_PD_FLOAT_PRECISION, self.precision)
for kk, vv in kwargs.items()
},
)
if isinstance(returned_tensor, tuple):
return tuple(
safe_cast_tensor(vv, self.precision, GLOBAL_TF_FLOAT_PRECISION)
safe_cast_tensor(vv, self.precision, GLOBAL_PD_FLOAT_PRECISION)
for vv in returned_tensor
)
else:
return safe_cast_tensor(
returned_tensor, self.precision, GLOBAL_TF_FLOAT_PRECISION
returned_tensor, self.precision, GLOBAL_PD_FLOAT_PRECISION
)

return wrapper
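A minimal sketch of how the Paddle-flavoured cast_precision decorator is meant to be used. The ToyLayer class below is hypothetical; only the decorator and GLOBAL_PD_FLOAT_PRECISION come from this diff, and the global precision is assumed to be a paddle dtype as required by safe_cast_tensor:

```python
import paddle

from deepmd.common import cast_precision
from deepmd.env import GLOBAL_PD_FLOAT_PRECISION


class ToyLayer(paddle.nn.Layer):
    """Hypothetical layer whose internals run in a reduced precision."""

    def __init__(self):
        super().__init__()
        # `precision` is the attribute the decorator reads on `self`.
        self.precision = paddle.float32

    @cast_precision
    def forward(self, x):
        # x arrives cast from GLOBAL_PD_FLOAT_PRECISION to self.precision;
        # the return value is cast back to GLOBAL_PD_FLOAT_PRECISION.
        return x * 2


out = ToyLayer()(paddle.ones([4], dtype=GLOBAL_PD_FLOAT_PRECISION))
assert out.dtype == GLOBAL_PD_FLOAT_PRECISION
```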
14 changes: 11 additions & 3 deletions deepmd/descriptor/se_a.py
@@ -140,6 +140,7 @@ def __init__(
uniform_seed: bool = False,
multi_task: bool = False,
spin: Optional[Spin] = None,
mixed_prec: Optional[dict] = None,
) -> None:
"""Constructor."""
super().__init__()
@@ -162,7 +163,9 @@ def __init__(
self.compress_activation_fn = get_activation_func(activation_function)
self.filter_activation_fn = get_activation_func(activation_function)
self.filter_precision = get_precision(precision)
self.exclude_types = set()
if mixed_prec is not None:
self.filter_precision = get_precision(mixed_prec["output_prec"])
self.exclude_types = set() # empty
for tt in exclude_types:
assert len(tt) == 2
self.exclude_types.add((tt[0], tt[1]))
@@ -205,7 +208,7 @@ def __init__(
self.davg = None
# self.compress = False
# self.embedding_net_variables = None
# self.mixed_prec = None
self.mixed_prec = mixed_prec
# self.place_holders = {}
# self.nei_type = np.repeat(np.arange(self.ntypes), self.sel_a)
self.avg_zero = paddle.zeros(
@@ -227,6 +230,7 @@ def __init__(
self.seed,
self.trainable,
name="filter_type_" + str(type_input) + str(type_i),
mixed_prec=self.mixed_prec,
)
)
nets.append(paddle.nn.LayerList(layer))
@@ -235,7 +239,6 @@ def __init__(

self.compress = False
self.embedding_net_variables = None
self.mixed_prec = None
self.nei_type = np.repeat(np.arange(self.ntypes), self.sel_a) # like a mask

self.original_sel = None
@@ -1194,3 +1197,8 @@ def init_variables(
self.dstd = new_dstd
if self.original_sel is None:
self.original_sel = sel

@property
def precision(self) -> paddle.dtype:
"""Precision of filter network."""
return self.filter_precision
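In short, the descriptor's filter precision now tracks mixed_prec["output_prec"] when mixed precision is requested, and the new `precision` property exposes it. A minimal sketch of that override logic (the dict values are illustrative; "float64" stands in for the default `precision` argument):

```python
import paddle

from deepmd.common import get_precision

mixed_prec = {"output_prec": "float32", "compute_prec": "float16"}  # illustrative

filter_precision = get_precision("float64")  # from the `precision` argument
if mixed_prec is not None:
    # same override DescrptSeA now applies in its constructor
    filter_precision = get_precision(mixed_prec["output_prec"])
assert filter_precision == paddle.float32
```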
2 changes: 1 addition & 1 deletion deepmd/entrypoints/train.py
@@ -464,7 +464,7 @@ def update_one_sel(jdata, descriptor):
return descriptor
rcut = descriptor["rcut"]
tmp_sel = get_sel(jdata, rcut, one_type=descriptor["type"] in ("se_atten",))
sel = descriptor["sel"] # [46, 92]
sel = descriptor["sel"]
if isinstance(sel, int):
# convert to list and finally convert back to int
sel = [sel]
15 changes: 14 additions & 1 deletion deepmd/fit/ener.py
@@ -131,6 +131,7 @@ def __init__(
layer_name: Optional[List[Optional[str]]] = None,
use_aparam_as_mask: bool = False,
spin: Optional[Spin] = None,
mixed_prec: Optional[dict] = None,
) -> None:
super().__init__(name_scope="EnerFitting")
"""Constructor."""
@@ -160,11 +161,14 @@ def __init__(
self.seed = seed
self.uniform_seed = uniform_seed
self.spin = spin
self.mixed_prec = mixed_prec
self.ntypes_spin = self.spin.get_ntypes_spin() if self.spin is not None else 0
self.seed_shift = one_layer_rand_seed_shift()
self.tot_ener_zero = tot_ener_zero
self.fitting_activation_fn = get_activation_func(activation_function)
self.fitting_precision = get_precision(precision)
if mixed_prec is not None:
self.fitting_precision = get_precision(mixed_prec["output_prec"])
self.trainable = trainable
if self.trainable is None:
self.trainable = [True for ii in range(len(self.n_neuron) + 1)]
@@ -205,7 +209,7 @@ def __init__(
self.aparam_inv_std = None

self.fitting_net_variables = None
self.mixed_prec = None
self.mixed_prec = mixed_prec
self.layer_name = layer_name
if self.layer_name is not None:
assert isinstance(self.layer_name, list), "layer_name should be a list"
@@ -237,6 +241,7 @@ def __init__(
seed=self.seed,
use_timestep=self.resnet_dt,
trainable=self.trainable[ii],
mixed_prec=self.mixed_prec,
)
)
else:
@@ -249,6 +254,7 @@ def __init__(
name=layer_suffix,
seed=self.seed,
trainable=self.trainable[ii],
mixed_prec=self.mixed_prec,
)
)
if (not self.uniform_seed) and (self.seed is not None):
@@ -264,6 +270,8 @@ def __init__(
name=layer_suffix,
seed=self.seed,
trainable=self.trainable[-1],
mixed_prec=self.mixed_prec,
final_layer=True,
)
)

@@ -874,3 +882,8 @@ def enable_mixed_precision(self, mixed_prec: Optional[dict] = None) -> None:
"""
self.mixed_prec = mixed_prec
self.fitting_precision = get_precision(mixed_prec["output_prec"])

@property
def precision(self) -> paddle.dtype:
"""Precision of filter network."""
return self.fitting_precision
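As a consolidated view, a hedged sketch of the mixed-precision hooks the fitting net now carries: the new `mixed_prec` constructor keyword, the existing enable_mixed_precision method, and the new `precision` property. The ToyFitting class is hypothetical and only mirrors the pattern; it is not the EnerFitting API:

```python
import paddle

from deepmd.common import get_precision


class ToyFitting:
    """Hypothetical stand-in mirroring EnerFitting's mixed-precision hooks."""

    def __init__(self, precision="default", mixed_prec=None):
        self.fitting_precision = get_precision(precision)
        self.mixed_prec = mixed_prec
        if mixed_prec is not None:
            # override applied at construction time in this commit
            self.fitting_precision = get_precision(mixed_prec["output_prec"])

    def enable_mixed_precision(self, mixed_prec=None):
        # same effect as the method shown above
        self.mixed_prec = mixed_prec
        self.fitting_precision = get_precision(mixed_prec["output_prec"])

    @property
    def precision(self):
        """Precision of the fitting network."""
        return self.fitting_precision


fit = ToyFitting(mixed_prec={"output_prec": "float32", "compute_prec": "float16"})
assert fit.precision == paddle.float32
```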
2 changes: 1 addition & 1 deletion deepmd/model/ener.py
@@ -252,7 +252,7 @@ def forward(
# atom_ener = tf.reshape(inv_sw_lambda, [-1]) * atom_ener
# energy_raw = tab_atom_ener + atom_ener
else:
energy_raw = atom_ener
energy_raw = atom_ener # [1, all_atoms]

nloc_atom = (
natoms[0]
13 changes: 11 additions & 2 deletions deepmd/train/trainer.py
@@ -95,6 +95,10 @@ def __init__(self, jdata, run_opt, is_compress=False):
self.is_compress = is_compress

def _init_param(self, jdata):
tr_data = jdata["training"]
self.mixed_prec = tr_data.get("mixed_precision", None)
if self.mixed_prec is not None:
log.info("mixed precision is enabled")
# model config
model_param = j_must_have(jdata, "model")
self.multi_task_mode = "fitting_net_dict" in model_param
@@ -148,6 +152,9 @@ def _init_param(self, jdata):
if descrpt_param["type"] in ["se_e2_a", "se_a", "se_e2_r", "se_r", "hybrid"]:
descrpt_param["spin"] = self.spin
descrpt_param.pop("type")
descrpt_param["mixed_prec"] = self.mixed_prec
if descrpt_param["mixed_prec"] is not None:
descrpt_param["precision"]: str = self.mixed_prec["output_prec"]
self.descrpt = deepmd.descriptor.se_a.DescrptSeA(**descrpt_param)

# fitting net
@@ -158,8 +165,12 @@ def _init_param(self, jdata):
if fitting_type == "ener":
fitting_param["spin"] = self.spin
fitting_param.pop("type")
fitting_param["mixed_prec"] = self.mixed_prec
if fitting_param["mixed_prec"] is not None:
fitting_param["precision"]: str = self.mixed_prec["output_prec"]
self.fitting = ener.EnerFitting(**fitting_param)
else:
raise NotImplementedError("multi-task mode is not supported")
self.fitting_dict = {}
self.fitting_type_dict = {}
self.nfitting = len(fitting_param)
@@ -380,7 +391,6 @@ def loss_init(_loss_param, _fitting_type, _fitting, _lr) -> EnerStdLoss:
)

# training
tr_data = jdata["training"]
self.fitting_weight = tr_data.get("fitting_weight", None)
if self.multi_task_mode:
self.fitting_key_list = []
@@ -401,7 +411,6 @@ def loss_init(_loss_param, _fitting_type, _fitting, _lr) -> EnerStdLoss:
self.tensorboard = self.run_opt.is_chief and tr_data.get("tensorboard", False)
self.tensorboard_log_dir = tr_data.get("tensorboard_log_dir", "log")
self.tensorboard_freq = tr_data.get("tensorboard_freq", 1)
self.mixed_prec = tr_data.get("mixed_precision", None)
if self.mixed_prec is not None:
if (
self.mixed_prec["compute_prec"] not in ("float16", "bfloat16")
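For context, a hedged sketch of the training-section input that turns this code path on. The key names match what _init_param reads above; the precision values are illustrative, with "compute_prec" restricted to "float16" or "bfloat16" by the check shown:

```python
# Fragment of the jdata dict handed to the trainer (normally parsed from the
# training JSON input file). Everything except "mixed_precision" is elided.
jdata_fragment = {
    "training": {
        "mixed_precision": {
            "output_prec": "float32",   # forwarded to descriptor/fitting "precision"
            "compute_prec": "float16",  # low-precision compute dtype
        },
    },
}
```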