diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index ae557326ff..2a68b793d4 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -78,12 +78,13 @@ def __init__( activation_function: Optional[str] = None, resnet: bool = False, precision: str = DEFAULT_PRECISION, + seed: Optional[int] = None, ) -> None: prec = PRECISION_DICT[precision.lower()] self.precision = precision # only use_timestep when skip connection is established. use_timestep = use_timestep and (num_out == num_in or num_out == num_in * 2) - rng = np.random.default_rng() + rng = np.random.default_rng(seed) self.w = rng.normal(size=(num_in, num_out)).astype(prec) self.b = rng.normal(size=(num_out,)).astype(prec) if bias else None self.idt = rng.normal(size=(num_out,)).astype(prec) if use_timestep else None @@ -313,6 +314,7 @@ def __init__( uni_init: bool = True, trainable: bool = True, precision: str = DEFAULT_PRECISION, + seed: Optional[int] = None, ) -> None: self.eps = eps self.uni_init = uni_init @@ -325,6 +327,7 @@ def __init__( activation_function=None, resnet=False, precision=precision, + seed=seed, ) self.w = self.w.squeeze(0) # keep the weight shape to be [num_in] if self.uni_init: @@ -569,6 +572,7 @@ def __init__( activation_function: str = "tanh", resnet_dt: bool = False, precision: str = DEFAULT_PRECISION, + seed: Optional[int] = None, ): layers = [] i_in = in_dim @@ -583,6 +587,7 @@ def __init__( activation_function=activation_function, resnet=True, precision=precision, + seed=seed, ).serialize() ) i_in = i_ot @@ -669,6 +674,7 @@ def __init__( resnet_dt: bool = False, precision: str = DEFAULT_PRECISION, bias_out: bool = True, + seed: Optional[int] = None, ): super().__init__( in_dim, @@ -688,6 +694,7 @@ def __init__( activation_function=None, resnet=False, precision=precision, + seed=seed, ) ) self.out_dim = out_dim diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index b80d2d4c38..4ab39465dc 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -172,6 +172,8 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module): Setting this parameter to `True` is equivalent to setting `tebd_input_mode` to 'strip'. Setting it to `False` is equivalent to setting `tebd_input_mode` to 'concat'. The default value is `None`, which means the `tebd_input_mode` setting will be used instead. + seed: int, Optional + Random seed for parameter initialization. use_econf_tebd: bool, Optional Whether to use electronic configuration type embedding. 
type_map: List[str], Optional @@ -225,12 +227,12 @@ def __init__( smooth_type_embedding: bool = True, type_one_side: bool = False, stripped_type_embedding: Optional[bool] = None, + seed: Optional[int] = None, use_econf_tebd: bool = False, type_map: Optional[List[str]] = None, # not implemented spin=None, type: Optional[str] = None, - seed: Optional[int] = None, old_impl: bool = False, ): super().__init__() @@ -275,6 +277,7 @@ def __init__( env_protection=env_protection, trainable_ln=trainable_ln, ln_eps=ln_eps, + seed=seed, old_impl=old_impl, ) self.use_econf_tebd = use_econf_tebd @@ -283,6 +286,7 @@ def __init__( ntypes, tebd_dim, precision=precision, + seed=seed, use_econf_tebd=use_econf_tebd, type_map=type_map, ) diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index 600930bb7a..678b797e6c 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -103,7 +103,7 @@ def __init__( trainable : bool, optional If the parameters are trainable. seed : int, optional - (Unused yet) Random seed for parameter initialization. + Random seed for parameter initialization. add_tebd_to_repinit_out : bool, optional Whether to add type embedding to the output representation from repinit before inputting it into repformer. use_econf_tebd : bool, Optional @@ -160,6 +160,7 @@ def init_subclass_params(sub_data, sub_class): resnet_dt=self.repinit_args.resnet_dt, smooth=smooth, type_one_side=self.repinit_args.type_one_side, + seed=seed, ) self.repformers = DescrptBlockRepformers( self.repformer_args.rcut, @@ -194,6 +195,7 @@ def init_subclass_params(sub_data, sub_class): precision=precision, trainable_ln=self.repformer_args.trainable_ln, ln_eps=self.repformer_args.ln_eps, + seed=seed, old_impl=old_impl, ) self.use_econf_tebd = use_econf_tebd @@ -202,6 +204,7 @@ def init_subclass_params(sub_data, sub_class): ntypes, self.repinit_args.tebd_dim, precision=precision, + seed=seed, use_econf_tebd=self.use_econf_tebd, type_map=type_map, ) @@ -222,6 +225,7 @@ def init_subclass_params(sub_data, sub_class): bias=False, precision=precision, init="glorot", + seed=seed, ) self.tebd_transform = None if self.add_tebd_to_repinit_out: @@ -230,6 +234,7 @@ def init_subclass_params(sub_data, sub_class): self.repformers.dim_in, bias=False, precision=precision, + seed=seed, ) assert self.repinit.rcut > self.repformers.rcut assert self.repinit.sel[0] > self.repformers.sel[0] diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index 8397b4b421..3f377f9de5 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -7,6 +7,10 @@ import torch import torch.nn as nn +from deepmd.pt.model.network.init import ( + constant_, + normal_, +) from deepmd.pt.model.network.layernorm import ( LayerNorm, ) @@ -21,6 +25,7 @@ ) from deepmd.pt.utils.utils import ( ActivationFn, + get_generator, to_numpy_array, to_torch_tensor, ) @@ -35,6 +40,7 @@ def get_residual( _mode: str = "norm", trainable: bool = True, precision: str = "float64", + seed: Optional[int] = None, ) -> torch.Tensor: r""" Get residual tensor for one update vector. @@ -53,15 +59,18 @@ def get_residual( Whether the residual tensor is trainable. precision The precision of the residual tensor. + seed : int, optional + Random seed for parameter initialization. 
""" + random_generator = get_generator(seed) residual = nn.Parameter( data=torch.zeros(_dim, dtype=PRECISION_DICT[precision], device=env.DEVICE), requires_grad=trainable, ) if _mode == "norm": - nn.init.normal_(residual.data, std=_scale) + normal_(residual.data, std=_scale, generator=random_generator) elif _mode == "const": - nn.init.constant_(residual.data, val=_scale) + constant_(residual.data, val=_scale) else: raise RuntimeError(f"Unsupported initialization mode '{_mode}'!") return residual @@ -147,6 +156,7 @@ def __init__( smooth: bool = True, attnw_shift: float = 20.0, precision: str = "float64", + seed: Optional[int] = None, ): """Return neighbor-wise multi-head self-attention maps, with gate mechanism.""" super().__init__() @@ -154,7 +164,11 @@ def __init__( self.hidden_dim = hidden_dim self.head_num = head_num self.mapqk = MLPLayer( - input_dim, hidden_dim * 2 * head_num, bias=False, precision=precision + input_dim, + hidden_dim * 2 * head_num, + bias=False, + precision=precision, + seed=seed, ) self.has_gate = has_gate self.smooth = smooth @@ -267,14 +281,21 @@ def __init__( input_dim: int, head_num: int, precision: str = "float64", + seed: Optional[int] = None, ): super().__init__() self.input_dim = input_dim self.head_num = head_num self.mapv = MLPLayer( - input_dim, input_dim * head_num, bias=False, precision=precision + input_dim, + input_dim * head_num, + bias=False, + precision=precision, + seed=seed, + ) + self.head_map = MLPLayer( + input_dim * head_num, input_dim, precision=precision, seed=seed ) - self.head_map = MLPLayer(input_dim * head_num, input_dim, precision=precision) self.precision = precision def forward( @@ -342,11 +363,14 @@ def __init__( input_dim: int, head_num: int, precision: str = "float64", + seed: Optional[int] = None, ): super().__init__() self.input_dim = input_dim self.head_num = head_num - self.head_map = MLPLayer(head_num, 1, bias=False, precision=precision) + self.head_map = MLPLayer( + head_num, 1, bias=False, precision=precision, seed=seed + ) self.precision = precision def forward( @@ -412,21 +436,29 @@ def __init__( smooth: bool = True, attnw_shift: float = 20.0, precision: str = "float64", + seed: Optional[int] = None, ): super().__init__() self.input_dim = input_dim self.hidden_dim = hidden_dim self.head_num = head_num self.mapq = MLPLayer( - input_dim, hidden_dim * 1 * head_num, bias=False, precision=precision + input_dim, + hidden_dim * 1 * head_num, + bias=False, + precision=precision, + seed=seed, ) self.mapkv = MLPLayer( input_dim, (hidden_dim + input_dim) * head_num, bias=False, precision=precision, + seed=seed, + ) + self.head_map = MLPLayer( + input_dim * head_num, input_dim, precision=precision, seed=seed ) - self.head_map = MLPLayer(input_dim * head_num, input_dim, precision=precision) self.smooth = smooth self.attnw_shift = attnw_shift self.precision = precision @@ -557,6 +589,7 @@ def __init__( precision: str = "float64", trainable_ln: bool = True, ln_eps: Optional[float] = 1e-5, + seed: Optional[int] = None, ): super().__init__() self.epsilon = 1e-4 # protection of 1./nnei @@ -594,6 +627,7 @@ def __init__( self.trainable_ln = trainable_ln self.ln_eps = ln_eps self.precision = precision + self.seed = seed assert update_residual_init in [ "norm", @@ -612,11 +646,12 @@ def __init__( self.update_residual, self.update_residual_init, precision=precision, + seed=seed, ) ) g1_in_dim = self.cal_1_dim(g1_dim, g2_dim, self.axis_neuron) - self.linear1 = MLPLayer(g1_in_dim, g1_dim, precision=precision) + self.linear1 = MLPLayer(g1_in_dim, 
g1_dim, precision=precision, seed=seed) self.linear2 = None self.proj_g1g2 = None self.proj_g1g1g2 = None @@ -627,7 +662,7 @@ def __init__( self.loc_attn = None if self.update_chnnl_2: - self.linear2 = MLPLayer(g2_dim, g2_dim, precision=precision) + self.linear2 = MLPLayer(g2_dim, g2_dim, precision=precision, seed=seed) if self.update_style == "res_residual": self.g2_residual.append( get_residual( @@ -635,12 +670,17 @@ def __init__( self.update_residual, self.update_residual_init, precision=precision, + seed=seed, ) ) if self.update_g1_has_conv: - self.proj_g1g2 = MLPLayer(g1_dim, g2_dim, bias=False, precision=precision) + self.proj_g1g2 = MLPLayer( + g1_dim, g2_dim, bias=False, precision=precision, seed=seed + ) if self.update_g2_has_g1g1: - self.proj_g1g1g2 = MLPLayer(g1_dim, g2_dim, bias=False, precision=precision) + self.proj_g1g1g2 = MLPLayer( + g1_dim, g2_dim, bias=False, precision=precision, seed=seed + ) if self.update_style == "res_residual": self.g2_residual.append( get_residual( @@ -648,6 +688,7 @@ def __init__( self.update_residual, self.update_residual_init, precision=precision, + seed=seed, ) ) if self.update_g2_has_attn or self.update_h2: @@ -658,13 +699,18 @@ def __init__( attn2_has_gate, self.smooth, precision=precision, + seed=seed, ) if self.update_g2_has_attn: self.attn2_mh_apply = Atten2MultiHeadApply( - g2_dim, attn2_nhead, precision=precision + g2_dim, attn2_nhead, precision=precision, seed=seed ) self.attn2_lm = LayerNorm( - g2_dim, eps=ln_eps, trainable=trainable_ln, precision=precision + g2_dim, + eps=ln_eps, + trainable=trainable_ln, + precision=precision, + seed=seed, ) if self.update_style == "res_residual": self.g2_residual.append( @@ -673,12 +719,13 @@ def __init__( self.update_residual, self.update_residual_init, precision=precision, + seed=seed, ) ) if self.update_h2: self.attn2_ev_apply = Atten2EquiVarApply( - g2_dim, attn2_nhead, precision=precision + g2_dim, attn2_nhead, precision=precision, seed=seed ) if self.update_style == "res_residual": self.h2_residual.append( @@ -687,11 +734,17 @@ def __init__( self.update_residual, self.update_residual_init, precision=precision, + seed=seed, ) ) if self.update_g1_has_attn: self.loc_attn = LocalAtten( - g1_dim, attn1_hidden, attn1_nhead, self.smooth, precision=precision + g1_dim, + attn1_hidden, + attn1_nhead, + self.smooth, + precision=precision, + seed=seed, ) if self.update_style == "res_residual": self.g1_residual.append( @@ -700,6 +753,7 @@ def __init__( self.update_residual, self.update_residual_init, precision=precision, + seed=seed, ) ) diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index b8adc0d71e..a66693653e 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -101,6 +101,7 @@ def __init__( precision: str = "float64", trainable_ln: bool = True, ln_eps: Optional[float] = 1e-5, + seed: Optional[int] = None, old_impl: bool = False, ): r""" @@ -180,6 +181,8 @@ def __init__( Whether to use trainable shift and scale weights in layer normalization. ln_eps : float, optional The epsilon value for layer normalization. + seed : int, optional + Random seed for parameter initialization. 
""" super().__init__() self.rcut = rcut @@ -223,9 +226,10 @@ def __init__( self.trainable_ln = trainable_ln self.ln_eps = ln_eps self.epsilon = 1e-4 + self.seed = seed self.old_impl = old_impl - self.g2_embd = MLPLayer(1, self.g2_dim, precision=precision) + self.g2_embd = MLPLayer(1, self.g2_dim, precision=precision, seed=seed) layers = [] for ii in range(nlayers): if self.old_impl: @@ -287,6 +291,7 @@ def __init__( trainable_ln=self.trainable_ln, ln_eps=self.ln_eps, precision=precision, + seed=seed, ) ) self.layers = torch.nn.ModuleList(layers) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 50393e8a03..0035eddba6 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -82,6 +82,7 @@ def __init__( env_protection: float = 0.0, old_impl: bool = False, type_one_side: bool = True, + seed: Optional[int] = None, **kwargs, ): super().__init__() @@ -99,6 +100,7 @@ def __init__( env_protection=env_protection, old_impl=old_impl, type_one_side=type_one_side, + seed=seed, **kwargs, ) @@ -328,6 +330,7 @@ def __init__( old_impl: bool = False, type_one_side: bool = True, trainable: bool = True, + seed: Optional[int] = None, **kwargs, ): """Construct an embedding net of type `se_a`. @@ -354,6 +357,7 @@ def __init__( self.env_protection = env_protection self.ntypes = len(sel) self.type_one_side = type_one_side + self.seed = seed # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) @@ -397,6 +401,7 @@ def __init__( activation_function=self.activation_function, precision=self.precision, resnet_dt=self.resnet_dt, + seed=self.seed, ) self.filter_layers = filter_layers self.stats = None diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index d87ad76e3c..2ffcb62ff9 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -83,6 +83,7 @@ def __init__( env_protection: float = 0.0, trainable_ln: bool = True, ln_eps: Optional[float] = 1e-5, + seed: Optional[int] = None, type: Optional[str] = None, old_impl: bool = False, ): @@ -148,6 +149,8 @@ def __init__( Whether to normalize the hidden vectors in attention weights calculation. temperature : float If not None, the scaling of attention weights is `temperature` itself. + seed : int, Optional + Random seed for parameter initialization. 
""" super().__init__() del type @@ -174,6 +177,7 @@ def __init__( self.type_one_side = type_one_side self.env_protection = env_protection self.trainable_ln = trainable_ln + self.seed = seed # to keep consistent with default value in this backends if ln_eps is None: ln_eps = 1e-5 @@ -223,6 +227,7 @@ def __init__( ln_eps=self.ln_eps, smooth=self.smooth, precision=self.precision, + seed=self.seed, ) wanted_shape = (self.ntypes, self.nnei, 4) @@ -266,6 +271,7 @@ def __init__( activation_function=self.activation_function, precision=self.precision, resnet_dt=self.resnet_dt, + seed=self.seed, ) self.filter_layers = filter_layers if self.tebd_input_mode in ["strip"]: @@ -278,6 +284,7 @@ def __init__( activation_function=self.activation_function, precision=self.precision, resnet_dt=self.resnet_dt, + seed=self.seed, ) self.filter_layers_strip = filter_layers_strip self.stats = None @@ -591,6 +598,7 @@ def __init__( ln_eps: float = 1e-5, smooth: bool = True, precision: str = DEFAULT_PRECISION, + seed: Optional[int] = None, ): """Construct a neighbor-wise attention net.""" super().__init__() @@ -607,6 +615,7 @@ def __init__( self.ln_eps = ln_eps self.smooth = smooth self.precision = precision + self.seed = seed self.network_type = NeighborGatedAttentionLayer attention_layers = [] for i in range(self.layer_num): @@ -624,6 +633,7 @@ def __init__( ln_eps=ln_eps, smooth=smooth, precision=precision, + seed=seed, ) ) self.attention_layers = nn.ModuleList(attention_layers) @@ -731,6 +741,7 @@ def __init__( trainable_ln: bool = True, ln_eps: float = 1e-5, precision: str = DEFAULT_PRECISION, + seed: Optional[int] = None, ): """Construct a neighbor-wise attention layer.""" super().__init__() @@ -745,6 +756,7 @@ def __init__( self.precision = precision self.trainable_ln = trainable_ln self.ln_eps = ln_eps + self.seed = seed self.attention_layer = GatedAttentionLayer( nnei, embed_dim, @@ -756,9 +768,14 @@ def __init__( temperature=temperature, smooth=smooth, precision=precision, + seed=seed, ) self.attn_layer_norm = LayerNorm( - self.embed_dim, eps=ln_eps, trainable=trainable_ln, precision=precision + self.embed_dim, + eps=ln_eps, + trainable=trainable_ln, + precision=precision, + seed=seed, ) def forward( @@ -831,6 +848,7 @@ def __init__( bias: bool = True, smooth: bool = True, precision: str = DEFAULT_PRECISION, + seed: Optional[int] = None, ): """Construct a multi-head neighbor-wise attention net.""" super().__init__() @@ -847,6 +865,7 @@ def __init__( self.scaling_factor = scaling_factor self.temperature = temperature self.precision = precision + self.seed = seed self.scaling = ( (self.head_dim * scaling_factor) ** -0.5 if temperature is None @@ -861,6 +880,7 @@ def __init__( bavg=0.0, stddev=1.0, precision=precision, + seed=seed, ) self.out_proj = MLPLayer( hidden_dim, @@ -870,6 +890,7 @@ def __init__( bavg=0.0, stddev=1.0, precision=precision, + seed=seed, ) def forward( diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py index 4ac510e2a7..340fd0c02a 100644 --- a/deepmd/pt/model/descriptor/se_r.py +++ b/deepmd/pt/model/descriptor/se_r.py @@ -67,6 +67,7 @@ def __init__( env_protection: float = 0.0, old_impl: bool = False, trainable: bool = True, + seed: Optional[int] = None, **kwargs, ): super().__init__() @@ -82,6 +83,7 @@ def __init__( self.old_impl = False # this does not support old implementation. 
self.exclude_types = exclude_types self.ntypes = len(sel) + self.seed = seed # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) self.env_protection = env_protection @@ -113,6 +115,7 @@ def __init__( activation_function=self.activation_function, precision=self.precision, resnet_dt=self.resnet_dt, + seed=self.seed, ) self.filter_layers = filter_layers self.stats = None diff --git a/deepmd/pt/model/network/init.py b/deepmd/pt/model/network/init.py new file mode 100644 index 0000000000..0bab6b66bd --- /dev/null +++ b/deepmd/pt/model/network/init.py @@ -0,0 +1,454 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import math +import warnings +from typing import Optional as _Optional + +import torch +from torch import ( + Tensor, +) + +# Copyright (c) 2024 The PyTorch Authors. All rights reserved. +# +# This file includes source code from PyTorch of version v2.3.0, which is released under the BSD-3-Clause license. +# For more information about PyTorch, visit https://pytorch.org/. + + +# These no_grad_* functions are necessary as wrappers around the parts of these +# functions that use `with torch.no_grad()`. The JIT doesn't support context +# managers, so these need to be implemented as builtins. Using these wrappers +# lets us keep those builtins small and re-usable. +def _no_grad_uniform_(tensor, a, b, generator=None): + with torch.no_grad(): + return tensor.uniform_(a, b, generator=generator) + + +def _no_grad_normal_(tensor, mean, std, generator=None): + with torch.no_grad(): + return tensor.normal_(mean, std, generator=generator) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None): + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1, generator=generator) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def _no_grad_zero_(tensor): + with torch.no_grad(): + return tensor.zero_() + + +def _no_grad_fill_(tensor, val): + with torch.no_grad(): + return tensor.fill_(val) + + +def calculate_gain(nonlinearity, param=None): + r"""Return the recommended gain value for the given nonlinearity function. 
+ + The values are as follows: + + ================= ==================================================== + nonlinearity gain + ================= ==================================================== + Linear / Identity :math:`1` + Conv{1,2,3}D :math:`1` + Sigmoid :math:`1` + Tanh :math:`\frac{5}{3}` + ReLU :math:`\sqrt{2}` + Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` + SELU :math:`\frac{3}{4}` + ================= ==================================================== + + .. warning:: + In order to implement `Self-Normalizing Neural Networks`_ , + you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``. + This gives the initial weights a variance of ``1 / N``, + which is necessary to induce a stable fixed point in the forward pass. + In contrast, the default gain for ``SELU`` sacrifices the normalization + effect for more stable gradient flow in rectangular layers. + + Args: + nonlinearity: the non-linear function (`nn.functional` name) + param: optional parameter for the non-linear function + + Examples + -------- + >>> gain = nn.init.calculate_gain( + ... "leaky_relu", 0.2 + ... ) # leaky_relu with negative_slope=0.2 + + .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html + """ + linear_fns = [ + "linear", + "conv1d", + "conv2d", + "conv3d", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + ] + if nonlinearity in linear_fns or nonlinearity == "sigmoid": + return 1 + elif nonlinearity == "tanh": + return 5.0 / 3 + elif nonlinearity == "relu": + return math.sqrt(2.0) + elif nonlinearity == "leaky_relu": + if param is None: + negative_slope = 0.01 + elif ( + not isinstance(param, bool) + and isinstance(param, int) + or isinstance(param, float) + ): + # True/False are instances of int, hence check above + negative_slope = param + else: + raise ValueError(f"negative_slope {param} not a valid number") + return math.sqrt(2.0 / (1 + negative_slope**2)) + elif nonlinearity == "selu": + return ( + 3.0 / 4 + ) # Value found empirically (https://github.com/pytorch/pytorch/pull/50664) + else: + raise ValueError(f"Unsupported nonlinearity {nonlinearity}") + + +def _calculate_fan_in_and_fan_out(tensor): + dimensions = tensor.dim() + if dimensions < 2: + raise ValueError( + "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" + ) + + num_input_fmaps = tensor.size(1) + num_output_fmaps = tensor.size(0) + receptive_field_size = 1 + if tensor.dim() > 2: + # math.prod is not always available, accumulate the product manually + # we could use functools.reduce but that is not supported by TorchScript + for s in tensor.shape[2:]: + receptive_field_size *= s + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + + return fan_in, fan_out + + +def _calculate_correct_fan(tensor, mode): + mode = mode.lower() + valid_modes = ["fan_in", "fan_out"] + if mode not in valid_modes: + raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}") + + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + return fan_in if mode == "fan_in" else fan_out + + +def zeros_(tensor: Tensor) -> Tensor: + r"""Fill the input Tensor with the scalar value `0`. 
+ + Args: + tensor: an n-dimensional `torch.Tensor` + + Examples + -------- + >>> w = torch.empty(3, 5) + >>> nn.init.zeros_(w) + """ + return _no_grad_zero_(tensor) + + +def ones_(tensor: Tensor) -> Tensor: + r"""Fill the input Tensor with the scalar value `1`. + + Args: + tensor: an n-dimensional `torch.Tensor` + + Examples + -------- + >>> w = torch.empty(3, 5) + >>> nn.init.ones_(w) + """ + return _no_grad_fill_(tensor, 1.0) + + +def constant_(tensor: Tensor, val: float) -> Tensor: + r"""Fill the input Tensor with the value :math:`\text{val}`. + + Args: + tensor: an n-dimensional `torch.Tensor` + val: the value to fill the tensor with + + Examples + -------- + >>> w = torch.empty(3, 5) + >>> nn.init.constant_(w, 0.3) + """ + if torch.overrides.has_torch_function_variadic(tensor): + return torch.overrides.handle_torch_function( + constant_, (tensor,), tensor=tensor, val=val + ) + return _no_grad_fill_(tensor, val) + + +def normal_( + tensor: Tensor, + mean: float = 0.0, + std: float = 1.0, + generator: _Optional[torch.Generator] = None, +) -> Tensor: + r"""Fill the input Tensor with values drawn from the normal distribution. + + :math:`\mathcal{N}(\text{mean}, \text{std}^2)`. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + generator: the torch Generator to sample from (default: None) + + Examples + -------- + >>> w = torch.empty(3, 5) + >>> nn.init.normal_(w) + """ + if torch.overrides.has_torch_function_variadic(tensor): + return torch.overrides.handle_torch_function( + normal_, (tensor,), tensor=tensor, mean=mean, std=std, generator=generator + ) + return _no_grad_normal_(tensor, mean, std, generator) + + +def trunc_normal_( + tensor: Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, + generator: _Optional[torch.Generator] = None, +) -> Tensor: + r"""Fill the input Tensor with values drawn from a truncated normal distribution. + + The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + generator: the torch Generator to sample from (default: None) + + Examples + -------- + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator) + + +def kaiming_uniform_( + tensor: Tensor, + a: float = 0, + mode: str = "fan_in", + nonlinearity: str = "leaky_relu", + generator: _Optional[torch.Generator] = None, +): + r"""Fill the input `Tensor` with values using a Kaiming uniform distribution. + + The method is described in `Delving deep into rectifiers: Surpassing + human-level performance on ImageNet classification` - He, K. et al. (2015). + The resulting tensor will have values sampled from + :math:`\mathcal{U}(-\text{bound}, \text{bound})` where + + .. math:: + \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} + + Also known as He initialization. 
+ + Args: + tensor: an n-dimensional `torch.Tensor` + a: the negative slope of the rectifier used after this layer (only + used with ``'leaky_relu'``) + mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity: the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). + generator: the torch Generator to sample from (default: None) + + Examples + -------- + >>> w = torch.empty(3, 5) + >>> nn.init.kaiming_uniform_(w, mode="fan_in", nonlinearity="relu") + """ + if torch.overrides.has_torch_function_variadic(tensor): + return torch.overrides.handle_torch_function( + kaiming_uniform_, + (tensor,), + tensor=tensor, + a=a, + mode=mode, + nonlinearity=nonlinearity, + generator=generator, + ) + + if 0 in tensor.shape: + warnings.warn("Initializing zero-element tensors is a no-op") + return tensor + fan = _calculate_correct_fan(tensor, mode) + gain = calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + with torch.no_grad(): + return tensor.uniform_(-bound, bound, generator=generator) + + +def kaiming_normal_( + tensor: Tensor, + a: float = 0, + mode: str = "fan_in", + nonlinearity: str = "leaky_relu", + generator: _Optional[torch.Generator] = None, +): + r"""Fill the input `Tensor` with values using a Kaiming normal distribution. + + The method is described in `Delving deep into rectifiers: Surpassing + human-level performance on ImageNet classification` - He, K. et al. (2015). + The resulting tensor will have values sampled from + :math:`\mathcal{N}(0, \text{std}^2)` where + + .. math:: + \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}} + + Also known as He initialization. + + Args: + tensor: an n-dimensional `torch.Tensor` + a: the negative slope of the rectifier used after this layer (only + used with ``'leaky_relu'``) + mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity: the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). + generator: the torch Generator to sample from (default: None) + + Examples + -------- + >>> w = torch.empty(3, 5) + >>> nn.init.kaiming_normal_(w, mode="fan_out", nonlinearity="relu") + """ + if 0 in tensor.shape: + warnings.warn("Initializing zero-element tensors is a no-op") + return tensor + fan = _calculate_correct_fan(tensor, mode) + gain = calculate_gain(nonlinearity, a) + std = gain / math.sqrt(fan) + with torch.no_grad(): + return tensor.normal_(0, std, generator=generator) + + +def xavier_uniform_( + tensor: Tensor, gain: float = 1.0, generator: _Optional[torch.Generator] = None +) -> Tensor: + r"""Fill the input `Tensor` with values using a Xavier uniform distribution. + + The method is described in `Understanding the difficulty of training + deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010). + The resulting tensor will have values sampled from + :math:`\mathcal{U}(-a, a)` where + + .. math:: + a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}} + + Also known as Glorot initialization. 
+ + Args: + tensor: an n-dimensional `torch.Tensor` + gain: an optional scaling factor + generator: the torch Generator to sample from (default: None) + + Examples + -------- + >>> w = torch.empty(3, 5) + >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain("relu")) + """ + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + + return _no_grad_uniform_(tensor, -a, a, generator) + + +def xavier_normal_( + tensor: Tensor, + gain: float = 1.0, + generator: _Optional[torch.Generator] = None, +) -> Tensor: + r"""Fill the input `Tensor` with values using a Xavier normal distribution. + + The method is described in `Understanding the difficulty of training deep feedforward + neural networks` - Glorot, X. & Bengio, Y. (2010). The resulting tensor + will have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where + + .. math:: + \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}} + + Also known as Glorot initialization. + + Args: + tensor: an n-dimensional `torch.Tensor` + gain: an optional scaling factor + generator: the torch Generator to sample from (default: None) + + Examples + -------- + >>> w = torch.empty(3, 5) + >>> nn.init.xavier_normal_(w) + """ + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + std = gain * math.sqrt(2.0 / float(fan_in + fan_out)) + + return _no_grad_normal_(tensor, 0.0, std, generator) diff --git a/deepmd/pt/model/network/layernorm.py b/deepmd/pt/model/network/layernorm.py index 7c58e248ba..385bbaf270 100644 --- a/deepmd/pt/model/network/layernorm.py +++ b/deepmd/pt/model/network/layernorm.py @@ -1,9 +1,18 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Optional, +) + import numpy as np import torch import torch.nn as nn from deepmd.dpmodel.utils.network import LayerNorm as DPLayerNorm +from deepmd.pt.model.network.init import ( + normal_, + ones_, + zeros_, +) from deepmd.pt.utils import ( env, ) @@ -12,6 +21,7 @@ PRECISION_DICT, ) from deepmd.pt.utils.utils import ( + get_generator, to_numpy_array, to_torch_tensor, ) @@ -33,6 +43,7 @@ def __init__( stddev: float = 1.0, precision: str = DEFAULT_PRECISION, trainable: bool = True, + seed: Optional[int] = None, ): super().__init__() self.eps = eps @@ -44,12 +55,17 @@ def __init__( self.bias = nn.Parameter( data=empty_t([num_in], self.prec), ) + random_generator = get_generator(seed) if self.uni_init: - nn.init.ones_(self.matrix.data) - nn.init.zeros_(self.bias.data) + ones_(self.matrix.data) + zeros_(self.bias.data) else: - nn.init.normal_(self.bias.data, mean=bavg, std=stddev) - nn.init.normal_(self.matrix.data, std=stddev / np.sqrt(self.num_in)) + normal_(self.bias.data, mean=bavg, std=stddev, generator=random_generator) + normal_( + self.matrix.data, + std=stddev / np.sqrt(self.num_in), + generator=random_generator, + ) self.trainable = trainable if not self.trainable: self.matrix.requires_grad = False diff --git a/deepmd/pt/model/network/mlp.py b/deepmd/pt/model/network/mlp.py index 5bd1fb0484..e5ea339fb7 100644 --- a/deepmd/pt/model/network/mlp.py +++ b/deepmd/pt/model/network/mlp.py @@ -24,12 +24,19 @@ make_fitting_network, make_multilayer_network, ) +from deepmd.pt.model.network.init import ( + kaiming_normal_, + normal_, + trunc_normal_, + xavier_uniform_, +) from deepmd.pt.utils.env import ( DEFAULT_PRECISION, PRECISION_DICT, ) from deepmd.pt.utils.utils import ( ActivationFn, + get_generator, 
to_numpy_array, to_torch_tensor, ) @@ -79,6 +86,7 @@ def __init__( stddev: float = 1.0, precision: str = DEFAULT_PRECISION, init: str = "default", + seed: Optional[int] = None, ): super().__init__() # only use_timestep when skip connection is established. @@ -92,6 +100,7 @@ def __init__( self.precision = precision self.prec = PRECISION_DICT[self.precision] self.matrix = nn.Parameter(data=empty_t((num_in, num_out), self.prec)) + random_generator = get_generator(seed) if bias: self.bias = nn.Parameter( data=empty_t([num_out], self.prec), @@ -104,17 +113,19 @@ def __init__( self.idt = None self.resnet = resnet if init == "default": - self._default_normal_init(bavg=bavg, stddev=stddev) + self._default_normal_init( + bavg=bavg, stddev=stddev, generator=random_generator + ) elif init == "trunc_normal": - self._trunc_normal_init(1.0) + self._trunc_normal_init(1.0, generator=random_generator) elif init == "relu": - self._trunc_normal_init(2.0) + self._trunc_normal_init(2.0, generator=random_generator) elif init == "glorot": - self._glorot_uniform_init() + self._glorot_uniform_init(generator=random_generator) elif init == "gating": self._zero_init(self.use_bias) elif init == "kaiming_normal": - self._normal_init() + self._normal_init(generator=random_generator) elif init == "final": self._zero_init(False) else: @@ -138,25 +149,34 @@ def dim_in(self) -> int: def dim_out(self) -> int: return self.matrix.shape[1] - def _default_normal_init(self, bavg: float = 0.0, stddev: float = 1.0): - nn.init.normal_( - self.matrix.data, std=stddev / np.sqrt(self.num_out + self.num_in) + def _default_normal_init( + self, + bavg: float = 0.0, + stddev: float = 1.0, + generator: Optional[torch.Generator] = None, + ): + normal_( + self.matrix.data, + std=stddev / np.sqrt(self.num_out + self.num_in), + generator=generator, ) if self.bias is not None: - nn.init.normal_(self.bias.data, mean=bavg, std=stddev) + normal_(self.bias.data, mean=bavg, std=stddev, generator=generator) if self.idt is not None: - nn.init.normal_(self.idt.data, mean=0.1, std=0.001) + normal_(self.idt.data, mean=0.1, std=0.001, generator=generator) - def _trunc_normal_init(self, scale=1.0): + def _trunc_normal_init( + self, scale=1.0, generator: Optional[torch.Generator] = None + ): # Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) 
TRUNCATED_NORMAL_STDDEV_FACTOR = 0.87962566103423978 _, fan_in = self.matrix.shape scale = scale / max(1, fan_in) std = (scale**0.5) / TRUNCATED_NORMAL_STDDEV_FACTOR - nn.init.trunc_normal_(self.matrix, mean=0.0, std=std) + trunc_normal_(self.matrix, mean=0.0, std=std, generator=generator) - def _glorot_uniform_init(self): - nn.init.xavier_uniform_(self.matrix, gain=1) + def _glorot_uniform_init(self, generator: Optional[torch.Generator] = None): + xavier_uniform_(self.matrix, gain=1, generator=generator) def _zero_init(self, use_bias=True): with torch.no_grad(): @@ -165,8 +185,8 @@ def _zero_init(self, use_bias=True): with torch.no_grad(): self.bias.fill_(1.0) - def _normal_init(self): - nn.init.kaiming_normal_(self.matrix, nonlinearity="linear") + def _normal_init(self, generator: Optional[torch.Generator] = None): + kaiming_normal_(self.matrix, nonlinearity="linear", generator=generator) def forward( self, diff --git a/deepmd/pt/model/network/network.py b/deepmd/pt/model/network/network.py index e5f76368bc..c2a719c2b0 100644 --- a/deepmd/pt/model/network/network.py +++ b/deepmd/pt/model/network/network.py @@ -567,6 +567,7 @@ def __init__( bavg=0.0, stddev=1.0, precision="default", + seed: Optional[int] = None, use_econf_tebd=False, type_map=None, ): @@ -586,6 +587,7 @@ def __init__( use_econf_tebd=use_econf_tebd, type_map=type_map, precision=precision, + seed=seed, ) # nn.init.normal_(self.embedding.weight[:-1], mean=bavg, std=stddev) @@ -699,13 +701,13 @@ def __init__( ) ) embed_input_dim = ECONF_DIM - # no way to pass seed? self.embedding_net = EmbeddingNet( embed_input_dim, self.neuron, self.activation_function, self.resnet_dt, self.precision, + self.seed, ) for param in self.parameters(): param.requires_grad = trainable diff --git a/deepmd/pt/model/task/ener.py b/deepmd/pt/model/task/ener.py index 12c0917dd2..ea9e21b1ae 100644 --- a/deepmd/pt/model/task/ener.py +++ b/deepmd/pt/model/task/ener.py @@ -55,6 +55,7 @@ def __init__( activation_function: str = "tanh", precision: str = DEFAULT_PRECISION, mixed_types: bool = True, + seed: Optional[int] = None, **kwargs, ): super().__init__( @@ -70,6 +71,7 @@ def __init__( activation_function=activation_function, precision=precision, mixed_types=mixed_types, + seed=seed, **kwargs, ) @@ -153,9 +155,6 @@ def __init__( filter_layers.append(one) self.filter_layers = torch.nn.ModuleList(filter_layers) - if "seed" in kwargs: - torch.manual_seed(kwargs["seed"]) - def output_def(self): return FittingOutputDef( [ diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index 00579b957f..73390aebc9 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -156,6 +156,7 @@ def __init__( self.precision = precision self.prec = PRECISION_DICT[self.precision] self.rcond = rcond + self.seed = seed # order matters, should be place after the assignment of ntypes self.reinit_exclude(exclude_types) self.trainable = trainable @@ -229,14 +230,12 @@ def __init__( self.resnet_dt, self.precision, bias_out=True, + seed=seed, ) for ii in range(self.ntypes if not self.mixed_types else 1) ], ) self.filter_layers_old = None - - if seed is not None: - torch.manual_seed(seed) # set trainable for param in self.parameters(): param.requires_grad = self.trainable diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 4056b30d87..cceadb38d2 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -333,6 +333,8 @@ def get_loss(loss_params, start_lr, _ntypes, _model): # Model 
         dp_random.seed(training_params["seed"])
+        if training_params["seed"] is not None:
+            torch.manual_seed(training_params["seed"])
         if not self.multi_task:
             self.model = get_single_model(
                 model_params,
diff --git a/deepmd/pt/utils/utils.py b/deepmd/pt/utils/utils.py
index d1ef089e49..6b4377038f 100644
--- a/deepmd/pt/utils/utils.py
+++ b/deepmd/pt/utils/utils.py
@@ -109,3 +109,12 @@ def dict_to_device(sample_dict):
         else:
             if sample_dict[key] is not None:
                 sample_dict[key] = sample_dict[key].to(DEVICE)
+
+
+def get_generator(seed: Optional[int] = None) -> Optional[torch.Generator]:
+    if seed is not None:
+        generator = torch.Generator(device=DEVICE)
+        generator.manual_seed(seed)
+        return generator
+    else:
+        return None
diff --git a/source/tests/pt/model/test_forward_lower.py b/source/tests/pt/model/test_forward_lower.py
index 5d2bfe599e..f700f5ad35 100644
--- a/source/tests/pt/model/test_forward_lower.py
+++ b/source/tests/pt/model/test_forward_lower.py
@@ -167,6 +167,8 @@ def setUp(self):
         self.prec = 1e-10
         model_params = copy.deepcopy(model_spin)
         model_params["descriptor"] = copy.deepcopy(model_dpa1)["descriptor"]
+        # double sel for virtual atoms to avoid large error
+        model_params["descriptor"]["sel"] *= 2
         self.test_spin = True
         self.model = get_model(model_params).to(env.DEVICE)
@@ -176,6 +178,9 @@ def setUp(self):
         self.prec = 1e-10
         model_params = copy.deepcopy(model_spin)
         model_params["descriptor"] = copy.deepcopy(model_dpa2)["descriptor"]
+        # double sel for virtual atoms to avoid large error
+        model_params["descriptor"]["repinit"]["nsel"] *= 2
+        model_params["descriptor"]["repformer"]["nsel"] *= 2
         self.test_spin = True
         self.model = get_model(model_params).to(env.DEVICE)
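
For reference, below is a minimal, self-contained sketch of the seeding pattern this patch threads through the PyTorch modules: an optional integer seed is turned into a torch.Generator (compare get_generator in deepmd/pt/utils/utils.py above) and handed to the in-place init functions vendored in deepmd/pt/model/network/init.py, rather than calling the global torch.manual_seed inside each module. The get_generator and normal_ helpers here are simplified stand-ins for the ones added by the patch, not the patch code itself.

    # Illustrative sketch only; mirrors, but does not reproduce, the patch's helpers.
    from typing import Optional

    import torch


    def get_generator(seed: Optional[int] = None) -> Optional[torch.Generator]:
        # Return a seeded generator, or None so init falls back to the global RNG.
        if seed is not None:
            generator = torch.Generator(device="cpu")
            generator.manual_seed(seed)
            return generator
        return None


    def normal_(tensor, mean=0.0, std=1.0, generator=None):
        # Same idea as the vendored init.normal_: fill in place from N(mean, std^2)
        # using the provided generator.
        with torch.no_grad():
            return tensor.normal_(mean, std, generator=generator)


    w1 = normal_(torch.empty(3, 5), std=0.1, generator=get_generator(42))
    w2 = normal_(torch.empty(3, 5), std=0.1, generator=get_generator(42))
    assert torch.equal(w1, w2)  # same seed -> identical initial parameters

With per-layer generators in place, the torch.manual_seed calls previously made inside the fitting classes are removed (see the deepmd/pt/model/task/ener.py and fitting.py hunks), so constructing a seeded network no longer perturbs the global RNG state; the only global seeding left is the guarded torch.manual_seed(training_params["seed"]) added in training.py. The NumPy backend follows the same idea by passing the seed directly to np.random.default_rng(seed) in deepmd/dpmodel/utils/network.py.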