add (manually) mixed precision training mode (tested with se_e2_a_mixed_prec)
HydrogenSulfate committed Dec 4, 2023
1 parent 17223e7 commit f745cb0
Showing 5 changed files with 125 additions and 28 deletions.
26 changes: 17 additions & 9 deletions deepmd/common.py
@@ -305,7 +305,7 @@ def get_activation_func(
return ACTIVATION_FN_DICT[activation_fn]


def get_precision(precision: "_PRECISION") -> Any:
def get_precision(precision: "_PRECISION") -> paddle.dtype:
"""Convert str to TF DType constant.
Parameters
@@ -380,8 +380,8 @@ def get_np_precision(precision: "_PRECISION") -> np.dtype:


def safe_cast_tensor(
input: tf.Tensor, from_precision: tf.DType, to_precision: tf.DType
) -> tf.Tensor:
input: paddle.Tensor, from_precision: paddle.dtype, to_precision: paddle.dtype
) -> paddle.Tensor:
"""Convert a Tensor from a precision to another precision.
If input is not a Tensor or without the specific precision, the method will not
@@ -401,8 +401,16 @@ def safe_cast_tensor(
tf.Tensor
casted Tensor
"""
if tensor_util.is_tensor(input) and input.dtype == from_precision:
return tf.cast(input, to_precision)
assert isinstance(
from_precision, paddle.dtype
), f"type of from_precision is {type(from_precision)}"
assert isinstance(
to_precision, paddle.dtype
), f"type of to_precision is {type(to_precision)}"
if paddle.is_tensor(input):
if input.dtype == from_precision and input.dtype != to_precision:
return paddle.cast(input, to_precision)
return input
return input
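
For reference, safe_cast_tensor casts only when the input is a Tensor whose dtype equals from_precision and differs from to_precision; everything else passes through untouched. A minimal sketch of that contract (dtype values chosen arbitrarily; the import path follows this diff):

import paddle

from deepmd.common import safe_cast_tensor

t = paddle.ones([2], dtype=paddle.float32)
# dtype matches from_precision -> cast to to_precision
assert safe_cast_tensor(t, paddle.float32, paddle.float16).dtype == paddle.float16
# dtype does not match from_precision -> returned unchanged
assert safe_cast_tensor(t, paddle.float64, paddle.float16).dtype == paddle.float32
# non-tensor inputs are returned unchanged as well
assert safe_cast_tensor(3.14, paddle.float64, paddle.float16) == 3.14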


@@ -447,22 +455,22 @@ def wrapper(self, *args, **kwargs):
returned_tensor = func(
self,
*[
safe_cast_tensor(vv, GLOBAL_TF_FLOAT_PRECISION, self.precision)
safe_cast_tensor(vv, GLOBAL_PD_FLOAT_PRECISION, self.precision)
for vv in args
],
**{
kk: safe_cast_tensor(vv, GLOBAL_TF_FLOAT_PRECISION, self.precision)
kk: safe_cast_tensor(vv, GLOBAL_PD_FLOAT_PRECISION, self.precision)
for kk, vv in kwargs.items()
},
)
if isinstance(returned_tensor, tuple):
return tuple(
safe_cast_tensor(vv, self.precision, GLOBAL_TF_FLOAT_PRECISION)
safe_cast_tensor(vv, self.precision, GLOBAL_PD_FLOAT_PRECISION)
for vv in returned_tensor
)
else:
return safe_cast_tensor(
returned_tensor, self.precision, GLOBAL_TF_FLOAT_PRECISION
returned_tensor, self.precision, GLOBAL_PD_FLOAT_PRECISION
)

return wrapper
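Taken together with safe_cast_tensor, the cast_precision decorator (its inner wrapper is shown above) down-casts every GLOBAL_PD_FLOAT_PRECISION argument to self.precision before the wrapped method runs and up-casts the result on the way out. A toy layer illustrating the round trip (ToyLayer is invented for illustration, and the deepmd.env import path for GLOBAL_PD_FLOAT_PRECISION is an assumption):

import paddle

from deepmd.common import cast_precision
from deepmd.env import GLOBAL_PD_FLOAT_PRECISION  # assumed import path

class ToyLayer(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        # the wrapper reads the compute dtype from self.precision
        self.precision = paddle.float16

    @cast_precision
    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        # x arrives here as float16; the wrapper casts the return
        # value back to GLOBAL_PD_FLOAT_PRECISION
        return 2 * x

layer = ToyLayer()
x = paddle.ones([4], dtype=GLOBAL_PD_FLOAT_PRECISION)
assert layer(x).dtype == GLOBAL_PD_FLOAT_PRECISION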
15 changes: 12 additions & 3 deletions deepmd/descriptor/se_a.py
@@ -138,6 +138,7 @@ def __init__(
uniform_seed: bool = False,
multi_task: bool = False,
spin: Optional[Spin] = None,
mixed_prec: Optional[dict] = None,
) -> None:
"""Constructor."""
super().__init__()
@@ -160,6 +161,8 @@ def __init__(
self.compress_activation_fn = get_activation_func(activation_function)
self.filter_activation_fn = get_activation_func(activation_function)
self.filter_precision = get_precision(precision)
if mixed_prec is not None:
self.filter_precision = get_precision(mixed_prec["output_prec"])
self.exclude_types = set() # empty
for tt in exclude_types:
assert len(tt) == 2
@@ -199,7 +202,7 @@ def __init__(
self.davg = None
# self.compress = False
# self.embedding_net_variables = None
# self.mixed_prec = None
self.mixed_prec = mixed_prec
# self.place_holders = {}
# self.nei_type = np.repeat(np.arange(self.ntypes), self.sel_a)
self.avg_zero = paddle.zeros(
@@ -222,6 +225,7 @@ def __init__(
self.seed,
self.trainable,
name="filter_type_" + str(type_input) + str(type_i),
mixed_prec=self.mixed_prec,
)
)
nets.append(paddle.nn.LayerList(layer))
@@ -230,7 +234,7 @@ def __init__(

self.compress = False
self.embedding_net_variables = None
self.mixed_prec = None
# self.mixed_prec = None
# self.place_holders = {}
self.nei_type = np.repeat(np.arange(self.ntypes), self.sel_a) # like a mask

@@ -1074,7 +1078,7 @@ def _filter_lower(
transpose_x=True,
) # yields (R_i).T * g_i, i.e. the right half of the expression for D_i

# @cast_precision
@cast_precision
def _filter(
self,
inputs: paddle.Tensor,  # [1, natoms (64 or 128), 552 (nnei*4)]
@@ -1313,3 +1317,8 @@ def init_variables(
self.dstd = new_dstd
if self.original_sel is None:
self.original_sel = sel

@property
def precision(self) -> paddle.dtype:
"""Precision of filter network."""
return self.filter_precision
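
With the new keyword threaded through the constructor, a descriptor can now be built directly in mixed precision. A sketch using the usual se_e2_a hyper-parameters (the rcut/sel/neuron values are arbitrary; only mixed_prec is new in this commit):

from deepmd.descriptor.se_a import DescrptSeA

descrpt = DescrptSeA(
    rcut=6.0,
    rcut_smth=0.5,
    sel=[46, 92],
    neuron=[25, 50, 100],
    axis_neuron=16,
    mixed_prec={"output_prec": "float32", "compute_prec": "float16"},
)
# the filter networks report the mixed-precision output dtype
print(descrpt.precision)  # paddle.float32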
17 changes: 15 additions & 2 deletions deepmd/fit/ener.py
@@ -120,6 +120,7 @@ def __init__(
layer_name: Optional[List[Optional[str]]] = None,
use_aparam_as_mask: bool = False,
spin: Optional[Spin] = None,
mixed_prec: Optional[dict] = None,
) -> None:
super().__init__(name_scope="EnerFitting")
"""Constructor."""
@@ -149,11 +150,14 @@ def __init__(
self.seed = seed
self.uniform_seed = uniform_seed
self.spin = spin
self.mixed_prec = mixed_prec
self.ntypes_spin = self.spin.get_ntypes_spin() if self.spin is not None else 0
self.seed_shift = one_layer_rand_seed_shift()
self.tot_ener_zero = tot_ener_zero
self.fitting_activation_fn = get_activation_func(activation_function)
self.fitting_precision = get_precision(precision)
if mixed_prec is not None:
self.fitting_precision = get_precision(mixed_prec["output_prec"])
self.trainable = trainable
if self.trainable is None:
self.trainable = [True for ii in range(len(self.n_neuron) + 1)]
@@ -194,7 +198,7 @@ def __init__(
self.aparam_inv_std = None

self.fitting_net_variables = None
self.mixed_prec = None
self.mixed_prec = mixed_prec
self.layer_name = layer_name
if self.layer_name is not None:
assert isinstance(self.layer_name, list), "layer_name should be a list"
@@ -226,6 +230,7 @@ def __init__(
seed=self.seed,
use_timestep=self.resnet_dt,
trainable=self.trainable[ii],
mixed_prec=self.mixed_prec,
)
)
else:
@@ -238,6 +243,7 @@ def __init__(
name=layer_suffix,
seed=self.seed,
trainable=self.trainable[ii],
mixed_prec=self.mixed_prec,
)
)
if (not self.uniform_seed) and (self.seed is not None):
@@ -253,6 +259,8 @@ def __init__(
name=layer_suffix,
seed=self.seed,
trainable=self.trainable[-1],
mixed_prec=self.mixed_prec,
final_layer=True,
)
)

@@ -431,7 +439,7 @@ def compute_input_stats(self, all_stat: dict, protection: float = 1e-2) -> None:
def _compute_std(self, sumv2, sumv, sumn):
return np.sqrt(sumv2 / sumn - np.multiply(sumv / sumn, sumv / sumn))

# @cast_precision
@cast_precision
def _build_lower(
self,
start_index,
@@ -911,3 +919,8 @@ def enable_mixed_precision(self, mixed_prec: Optional[dict] = None) -> None:
"""
self.mixed_prec = mixed_prec
self.fitting_precision = get_precision(mixed_prec["output_prec"])

@property
def precision(self) -> paddle.dtype:
"""Precision of fitting network."""
return self.fitting_precision
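
The fitting net gets the same two entry points: the mixed_prec constructor keyword added above, or the pre-existing enable_mixed_precision switch after construction. A sketch (the constructor arguments are abbreviated and partly guessed; descrpt is a descriptor built as in the previous example):

import paddle

from deepmd.fit.ener import EnerFitting

fitting = EnerFitting(
    descrpt,
    neuron=[240, 240, 240],
    resnet_dt=True,
    mixed_prec={"output_prec": "float32", "compute_prec": "float16"},
)
# equivalent post-construction route:
fitting.enable_mixed_precision({"output_prec": "float32", "compute_prec": "float16"})
assert fitting.precision == paddle.float32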
15 changes: 13 additions & 2 deletions deepmd/train/trainer.py
@@ -73,6 +73,10 @@ def __init__(self, jdata, run_opt, is_compress=False):
self.is_compress = is_compress

def _init_param(self, jdata):
tr_data = jdata["training"]
self.mixed_prec = tr_data.get("mixed_precision", None)
if self.mixed_prec is not None:
log.info("mixed precision is enabled")
# model config
model_param = j_must_have(jdata, "model")
self.multi_task_mode = "fitting_net_dict" in model_param
@@ -126,6 +130,9 @@ def _init_param(self, jdata):
if descrpt_param["type"] in ["se_e2_a", "se_a", "se_e2_r", "se_r", "hybrid"]:
descrpt_param["spin"] = self.spin
descrpt_param.pop("type")
descrpt_param["mixed_prec"] = self.mixed_prec
if descrpt_param["mixed_prec"] is not None:
descrpt_param["precision"]: str = self.mixed_prec["output_prec"]
self.descrpt = deepmd.descriptor.se_a.DescrptSeA(**descrpt_param)

# fitting net
@@ -136,8 +143,12 @@ def _init_param(self, jdata):
if fitting_type == "ener":
fitting_param["spin"] = self.spin
fitting_param.pop("type")
fitting_param["mixed_prec"] = self.mixed_prec
if fitting_param["mixed_prec"] is not None:
fitting_param["precision"]: str = self.mixed_prec["output_prec"]
self.fitting = ener.EnerFitting(**fitting_param)
else:
raise NotImplementedError("multi-task mode is not supported")
self.fitting_dict = {}
self.fitting_type_dict = {}
self.nfitting = len(fitting_param)
@@ -358,7 +369,7 @@ def loss_init(_loss_param, _fitting_type, _fitting, _lr) -> EnerStdLoss:
)

# training
tr_data = jdata["training"]
# tr_data = jdata["training"]
self.fitting_weight = tr_data.get("fitting_weight", None)
if self.multi_task_mode:
self.fitting_key_list = []
@@ -379,7 +390,7 @@ def loss_init(_loss_param, _fitting_type, _fitting, _lr) -> EnerStdLoss:
self.tensorboard = self.run_opt.is_chief and tr_data.get("tensorboard", False)
self.tensorboard_log_dir = tr_data.get("tensorboard_log_dir", "log")
self.tensorboard_freq = tr_data.get("tensorboard_freq", 1)
self.mixed_prec = tr_data.get("mixed_precision", None)
# self.mixed_prec = tr_data.get("mixed_precision", None)
if self.mixed_prec is not None:
if (
self.mixed_prec["compute_prec"] not in ("float16", "bfloat16")
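On the input side, the trainer reads the new block from the "training" section of the JSON input; the only keys the code above dictates are output_prec and compute_prec, with compute_prec restricted to float16 or bfloat16. A sketch of the relevant fragment of jdata (all other keys are the usual ones and are elided):

jdata = {
    "model": {
        # ... "descriptor" (se_e2_a) and "fitting_net" sections as usual ...
    },
    "training": {
        "mixed_precision": {
            "output_prec": "float32",   # dtype of layer outputs (the filter/fitting precision above)
            "compute_prec": "float16",  # dtype for internal computation; must be float16 or bfloat16
        },
        # ... remaining training keys unchanged ...
    },
}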