From 8355947ef052fd3806a361e25bbcff18f3e2f627 Mon Sep 17 00:00:00 2001 From: Yan Wang <116817801+cherryWangY@users.noreply.github.com> Date: Fri, 1 Nov 2024 19:41:59 +0800 Subject: [PATCH] Add 4 pt descriptor compression (#4227) se_a, se_atten(DPA1), se_t, se_r ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced a model compression feature across multiple descriptor classes, enhancing performance and efficiency. - Added `enable_compression` methods to various classes, allowing users to enable and configure compression settings. - **Bug Fixes** - Improved error handling for unsupported compression scenarios and parameter validation. - **Tests** - Added comprehensive unit tests for new compression functionalities across multiple descriptor classes to ensure accuracy and reliability. - **Documentation** - Enhanced documentation for new methods and classes to clarify usage and parameters related to compression. --------- Signed-off-by: Jinzhe Zeng Signed-off-by: Yan Wang <116817801+cherryWangY@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jinzhe Zeng --- .../descriptor/make_base_descriptor.py | 25 + deepmd/pt/model/descriptor/dpa1.py | 87 +++ deepmd/pt/model/descriptor/se_a.py | 134 +++- deepmd/pt/model/descriptor/se_atten.py | 163 ++++- deepmd/pt/model/descriptor/se_r.py | 116 +++- deepmd/pt/model/descriptor/se_t.py | 131 +++- deepmd/pt/utils/tabulate.py | 607 ++++++++++++++++++ deepmd/tf/utils/tabulate.py | 369 +---------- deepmd/utils/tabulate.py | 458 +++++++++++++ source/op/pt/tabulate_multi_device.cc | 8 +- .../model/test_compressed_descriptor_se_a.py | 132 ++++ .../test_compressed_descriptor_se_atten.py | 142 ++++ .../model/test_compressed_descriptor_se_r.py | 129 ++++ .../model/test_compressed_descriptor_se_t.py | 129 ++++ source/tests/pt/test_tabulate.py | 135 ++++ 15 files changed, 2377 insertions(+), 388 deletions(-) create mode 100644 deepmd/pt/utils/tabulate.py create mode 100644 deepmd/utils/tabulate.py create mode 100644 source/tests/pt/model/test_compressed_descriptor_se_a.py create mode 100644 source/tests/pt/model/test_compressed_descriptor_se_atten.py create mode 100644 source/tests/pt/model/test_compressed_descriptor_se_r.py create mode 100644 source/tests/pt/model/test_compressed_descriptor_se_t.py create mode 100644 source/tests/pt/test_tabulate.py diff --git a/deepmd/dpmodel/descriptor/make_base_descriptor.py b/deepmd/dpmodel/descriptor/make_base_descriptor.py index b9c1e93387..9f2891d8c0 100644 --- a/deepmd/dpmodel/descriptor/make_base_descriptor.py +++ b/deepmd/dpmodel/descriptor/make_base_descriptor.py @@ -147,6 +147,31 @@ def compute_input_stats( """Update mean and stddev for descriptor elements.""" raise NotImplementedError + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data. 
+ + Parameters + ---------- + min_nbor_dist + The nearest distance between atoms + table_extrapolate + The scale of model extrapolation + table_stride_1 + The uniform stride of the first table + table_stride_2 + The uniform stride of the second table + check_frequency + The overflow check frequency + """ + raise NotImplementedError("This descriptor doesn't support compression!") + @abstractmethod def fwd( self, diff --git a/deepmd/pt/model/descriptor/dpa1.py b/deepmd/pt/model/descriptor/dpa1.py index d3156f7c84..76115b2810 100644 --- a/deepmd/pt/model/descriptor/dpa1.py +++ b/deepmd/pt/model/descriptor/dpa1.py @@ -24,9 +24,15 @@ from deepmd.pt.utils.env import ( RESERVED_PRECISON_DICT, ) +from deepmd.pt.utils.tabulate import ( + DPTabulate, +) from deepmd.pt.utils.update_sel import ( UpdateSel, ) +from deepmd.pt.utils.utils import ( + ActivationFn, +) from deepmd.utils.data_system import ( DeepmdDataSystem, ) @@ -261,6 +267,8 @@ def __init__( if ln_eps is None: ln_eps = 1e-5 + self.tebd_input_mode = tebd_input_mode + del type, spin, attn_mask self.se_atten = DescrptBlockSeAtten( rcut, @@ -293,6 +301,7 @@ def __init__( self.use_econf_tebd = use_econf_tebd self.use_tebd_bias = use_tebd_bias self.type_map = type_map + self.compress = False self.type_embedding = TypeEmbedNet( ntypes, tebd_dim, @@ -551,6 +560,84 @@ def t_cvt(xx): ) return obj + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data. + + Parameters + ---------- + min_nbor_dist + The nearest distance between atoms + table_extrapolate + The scale of model extrapolation + table_stride_1 + The uniform stride of the first table + table_stride_2 + The uniform stride of the second table + check_frequency + The overflow check frequency + """ + # do some checks before the model compression process + if self.compress: + raise ValueError("Compression is already enabled.") + assert ( + not self.se_atten.resnet_dt + ), "Model compression error: descriptor resnet_dt must be false!" + for tt in self.se_atten.exclude_types: + if (tt[0] not in range(self.se_atten.ntypes)) or ( + tt[1] not in range(self.se_atten.ntypes) + ): + raise RuntimeError( + "exclude types " + + str(tt) + + " must be within the number of atomic types " + + str(self.se_atten.ntypes) + + "!" + ) + if ( + self.se_atten.ntypes * self.se_atten.ntypes + - len(self.se_atten.exclude_types) + == 0 + ): + raise RuntimeError( + "Empty embedding-nets are not supported in model compression!" 
) + + if self.se_atten.attn_layer != 0: + raise RuntimeError("Cannot compress model when attention layer is not 0.") + + if self.tebd_input_mode != "strip": + raise RuntimeError("Cannot compress model when tebd_input_mode != 'strip'") + + data = self.serialize() + self.table = DPTabulate( + self, + data["neuron"], + data["type_one_side"], + data["exclude_types"], + ActivationFn(data["activation_function"]), + ) + self.table_config = [ + table_extrapolate, + table_stride_1, + table_stride_2, + check_frequency, + ] + self.lower, self.upper = self.table.build( + min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2 + ) + + self.se_atten.enable_compression( + self.table.data, self.table_config, self.lower, self.upper + ) + self.compress = True + def forward( self, extended_coord: torch.Tensor, diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 56cb1f5bc6..630b96ce9b 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -58,11 +58,34 @@ from deepmd.pt.utils.exclude_mask import ( PairExcludeMask, ) +from deepmd.pt.utils.tabulate import ( + DPTabulate, +) +from deepmd.pt.utils.utils import ( + ActivationFn, +) from .base_descriptor import ( BaseDescriptor, ) +if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_a"): + + def tabulate_fusion_se_a( + argument0, + argument1, + argument2, + argument3, + argument4, + ) -> list[torch.Tensor]: + raise NotImplementedError( + "tabulate_fusion_se_a is not available since the customized PyTorch OP library is not built when freezing the model. " + "See documentation for model compression for details." + ) + + # Note: this hack cannot actually save a model that can be run using LAMMPS. + torch.ops.deepmd.tabulate_fusion_se_a = tabulate_fusion_se_a + @BaseDescriptor.register("se_e2_a") @BaseDescriptor.register("se_a") @@ -93,6 +116,7 @@ def __init__( raise NotImplementedError("old implementation of spin is not supported.") super().__init__() self.type_map = type_map + self.compress = False self.sea = DescrptBlockSeA( rcut, rcut_smth, @@ -225,6 +249,53 @@ def reinit_exclude( """Update the type exclusions.""" self.sea.reinit_exclude(exclude_types) + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data. 
+ + Parameters + ---------- + min_nbor_dist + The nearest distance between atoms + table_extrapolate + The scale of model extrapolation + table_stride_1 + The uniform stride of the first table + table_stride_2 + The uniform stride of the second table + check_frequency + The overflow check frequency + """ + if self.compress: + raise ValueError("Compression is already enabled.") + data = self.serialize() + self.table = DPTabulate( + self, + data["neuron"], + data["type_one_side"], + data["exclude_types"], + ActivationFn(data["activation_function"]), + ) + self.table_config = [ + table_extrapolate, + table_stride_1, + table_stride_2, + check_frequency, + ] + self.lower, self.upper = self.table.build( + min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2 + ) + self.sea.enable_compression( + self.table.data, self.table_config, self.lower, self.upper + ) + self.compress = True + def forward( self, coord_ext: torch.Tensor, @@ -366,6 +437,10 @@ def update_sel( class DescrptBlockSeA(DescriptorBlock): ndescrpt: Final[int] __constants__: ClassVar[list] = ["ndescrpt"] + lower: dict[str, int] + upper: dict[str, int] + table_data: dict[str, torch.Tensor] + table_config: list[Union[int, float]] def __init__( self, @@ -425,6 +500,13 @@ def __init__( self.register_buffer("mean", mean) self.register_buffer("stddev", stddev) + # add for compression + self.compress = False + self.lower = {} + self.upper = {} + self.table_data = {} + self.table_config = [] + ndim = 1 if self.type_one_side else 2 filter_layers = NetworkCollection( ndim=ndim, ntypes=len(sel), network_type="embedding_network" @@ -443,6 +525,7 @@ def __init__( self.filter_layers = filter_layers self.stats = None # set trainable + self.trainable = trainable for param in self.parameters(): param.requires_grad = trainable @@ -470,6 +553,10 @@ def get_dim_out(self) -> int: """Returns the output dimension.""" return self.dim_out + def get_dim_rot_mat_1(self) -> int: + """Returns the first dimension of the rotation matrix. 
The rotation is of shape dim_1 x 3.""" + return self.filter_neuron[-1] + def get_dim_emb(self) -> int: """Returns the output dimension.""" return self.neuron[-1] @@ -578,6 +665,19 @@ def reinit_exclude( self.exclude_types = exclude_types self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + def enable_compression( + self, + table_data, + table_config, + lower, + upper, + ) -> None: + self.compress = True + self.table_data = table_data + self.table_config = table_config + self.lower = lower + self.upper = upper + def forward( self, nlist: torch.Tensor, @@ -627,6 +727,7 @@ def forward( for embedding_idx, ll in enumerate(self.filter_layers.networks): if self.type_one_side: ii = embedding_idx + ti = -1 # torch.jit is not happy with slice(None) # ti_mask = torch.ones(nfnl, dtype=torch.bool, device=dmatrix.device) # applying a mask seems to cause performance degradation @@ -648,10 +749,35 @@ def forward( rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :] rr = rr * mm[:, :, None] ss = rr[:, :, :1] - # nfnl x nt x ng - gg = ll.forward(ss) - # nfnl x 4 x ng - gr = torch.matmul(rr.permute(0, 2, 1), gg) + + if self.compress: + if self.type_one_side: + net = "filter_-1_net_" + str(ii) + else: + net = "filter_" + str(ti) + "_net_" + str(ii) + info = [ + self.lower[net], + self.upper[net], + self.upper[net] * self.table_config[0], + self.table_config[1], + self.table_config[2], + self.table_config[3], + ] + ss = ss.reshape(-1, 1) # xyz_scatter_tensor in tf + tensor_data = self.table_data[net].to(ss.device).to(dtype=self.prec) + gr = torch.ops.deepmd.tabulate_fusion_se_a( + tensor_data.contiguous(), + torch.tensor(info, dtype=self.prec, device="cpu").contiguous(), + ss.contiguous(), + rr.contiguous(), + self.filter_neuron[-1], + )[0] + else: + # nfnl x nt x ng + gg = ll.forward(ss) + # nfnl x 4 x ng + gr = torch.matmul(rr.permute(0, 2, 1), gg) + if ti_mask is not None: xyz_scatter[ti_mask] += gr else: diff --git a/deepmd/pt/model/descriptor/se_atten.py b/deepmd/pt/model/descriptor/se_atten.py index aab72f7e98..8c56ccf827 100644 --- a/deepmd/pt/model/descriptor/se_atten.py +++ b/deepmd/pt/model/descriptor/se_atten.py @@ -49,9 +49,33 @@ check_version_compatibility, ) +if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_atten"): + + def tabulate_fusion_se_atten( + argument0, + argument1, + argument2, + argument3, + argument4, + argument5, + argument6, + ) -> list[torch.Tensor]: + raise NotImplementedError( + "tabulate_fusion_se_atten is not available since the customized PyTorch OP library is not built when freezing the model. " + "See documentation for model compression for details." + ) + + # Note: this hack cannot actually save a model that can be run using LAMMPS. 
+ torch.ops.deepmd.tabulate_fusion_se_atten = tabulate_fusion_se_atten + @DescriptorBlock.register("se_atten") class DescrptBlockSeAtten(DescriptorBlock): + lower: dict[str, int] + upper: dict[str, int] + table_data: dict[str, torch.Tensor] + table_config: list[Union[int, float]] + def __init__( self, rcut: float, @@ -178,6 +202,14 @@ def __init__( ln_eps = 1e-5 self.ln_eps = ln_eps + # add for compression + self.compress = False + self.is_sorted = False + self.lower = {} + self.upper = {} + self.table_data = {} + self.table_config = [] + if isinstance(sel, int): sel = [sel] @@ -189,6 +221,7 @@ def __init__( self.ndescrpt = self.nnei * 4 # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) + self.dpa1_attention = NeighborGatedAttention( self.attn_layer, self.nnei, @@ -277,6 +310,10 @@ def get_dim_out(self) -> int: """Returns the output dimension.""" return self.dim_out + def get_dim_rot_mat_1(self) -> int: + """Returns the first dimension of the rotation matrix. The rotation is of shape dim_1 x 3.""" + return self.filter_neuron[-1] + def get_dim_emb(self) -> int: """Returns the output dimension of embedding.""" return self.filter_neuron[-1] @@ -384,8 +421,22 @@ def reinit_exclude( exclude_types: list[tuple[int, int]] = [], ): self.exclude_types = exclude_types + self.is_sorted = len(self.exclude_types) == 0 self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + def enable_compression( + self, + table_data, + table_config, + lower, + upper, + ) -> None: + self.compress = True + self.table_data = table_data + self.table_config = table_config + self.lower = lower + self.upper = upper + def forward( self, nlist: torch.Tensor, @@ -450,20 +501,21 @@ def forward( sw = torch.squeeze(sw, -1) # nf x nloc x nt -> nf x nloc x nnei x nt atype_tebd = extended_atype_embd[:, :nloc, :] - atype_tebd_nnei = atype_tebd.unsqueeze(2).expand(-1, -1, self.nnei, -1) + atype_tebd_nnei = atype_tebd.unsqueeze(2).expand(-1, -1, self.nnei, -1) # i # nf x nall x nt nt = extended_atype_embd.shape[-1] atype_tebd_ext = extended_atype_embd # nb x (nloc x nnei) x nt index = nlist.reshape(nb, nloc * nnei).unsqueeze(-1).expand(-1, -1, nt) # nb x (nloc x nnei) x nt - atype_tebd_nlist = torch.gather(atype_tebd_ext, dim=1, index=index) + atype_tebd_nlist = torch.gather(atype_tebd_ext, dim=1, index=index) # j # nb x nloc x nnei x nt atype_tebd_nlist = atype_tebd_nlist.view(nb, nloc, nnei, nt) # beyond the cutoff sw should be 0.0 sw = sw.masked_fill(~nlist_mask, 0.0) # (nb x nloc) x nnei exclude_mask = exclude_mask.view(nb * nloc, nnei) + # nfnl x nnei x 4 dmatrix = dmatrix.view(-1, self.nnei, 4) nfnl = dmatrix.shape[0] @@ -482,33 +534,91 @@ def forward( ss = torch.concat([ss, nlist_tebd], dim=2) # nfnl x nnei x ng gg = self.filter_layers.networks[0](ss) + input_r = torch.nn.functional.normalize( + rr.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 + ) + gg = self.dpa1_attention( + gg, nlist_mask, input_r=input_r, sw=sw + ) # shape is [nframes*nloc, self.neei, out_size] + # nfnl x 4 x ng + xyz_scatter = torch.matmul(rr.permute(0, 2, 1), gg) elif self.tebd_input_mode in ["strip"]: - # nfnl x nnei x ng - gg_s = self.filter_layers.networks[0](ss) - assert self.filter_layers_strip is not None - if not self.type_one_side: - # nfnl x nnei x (tebd_dim * 2) - tt = torch.concat([nlist_tebd, atype_tebd], dim=2) + if self.compress: + net = "filter_net" + info = [ + self.lower[net], + self.upper[net], + self.upper[net] * self.table_config[0], + self.table_config[1], + 
self.table_config[2], + self.table_config[3], + ] + ss = ss.reshape(-1, 1) + # nfnl x nnei x ng + # gg_s = self.filter_layers.networks[0](ss) + assert self.filter_layers_strip is not None + if not self.type_one_side: + # nfnl x nnei x (tebd_dim * 2) + tt = torch.concat([nlist_tebd, atype_tebd], dim=2) # dynamic, index + else: + # nfnl x nnei x tebd_dim + tt = nlist_tebd + # nfnl x nnei x ng + gg_t = self.filter_layers_strip.networks[0](tt) + if self.smooth: + gg_t = gg_t * sw.reshape(-1, self.nnei, 1) + # nfnl x nnei x ng + # gg = gg_s * gg_t + gg_s + tensor_data = self.table_data[net].to(gg_t.device).to(dtype=self.prec) + info_tensor = torch.tensor(info, dtype=self.prec, device="cpu") + gg_t = gg_t.reshape(-1, gg_t.size(-1)) + # Convert all tensors to the required precision at once + ss, rr, gg_t = (t.to(self.prec) for t in (ss, rr, gg_t)) + xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_atten( + tensor_data.contiguous(), + info_tensor.contiguous(), + ss.contiguous(), + rr.contiguous(), + gg_t.contiguous(), + self.filter_neuron[-1], + self.is_sorted, + )[0] + # to make torchscript happy + gg = torch.empty( + nframes, + nloc, + self.nnei, + self.filter_neuron[-1], + dtype=gg_t.dtype, + device=gg_t.device, + ) else: - # nfnl x nnei x tebd_dim - tt = nlist_tebd - # nfnl x nnei x ng - gg_t = self.filter_layers_strip.networks[0](tt) - if self.smooth: - gg_t = gg_t * sw.reshape(-1, self.nnei, 1) - # nfnl x nnei x ng - gg = gg_s * gg_t + gg_s + # nfnl x nnei x ng + gg_s = self.filter_layers.networks[0](ss) + assert self.filter_layers_strip is not None + if not self.type_one_side: + # nfnl x nnei x (tebd_dim * 2) + tt = torch.concat([nlist_tebd, atype_tebd], dim=2) # dynamic, index + else: + # nfnl x nnei x tebd_dim + tt = nlist_tebd + # nfnl x nnei x ng + gg_t = self.filter_layers_strip.networks[0](tt) + if self.smooth: + gg_t = gg_t * sw.reshape(-1, self.nnei, 1) + # nfnl x nnei x ng + gg = gg_s * gg_t + gg_s + input_r = torch.nn.functional.normalize( + rr.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 + ) + gg = self.dpa1_attention( + gg, nlist_mask, input_r=input_r, sw=sw + ) # shape is [nframes*nloc, self.neei, out_size] + # nfnl x 4 x ng + xyz_scatter = torch.matmul(rr.permute(0, 2, 1), gg) else: raise NotImplementedError - input_r = torch.nn.functional.normalize( - rr.reshape(-1, self.nnei, 4)[:, :, 1:4], dim=-1 - ) - gg = self.dpa1_attention( - gg, nlist_mask, input_r=input_r, sw=sw - ) # shape is [nframes*nloc, self.neei, out_size] - # nfnl x 4 x ng - xyz_scatter = torch.matmul(rr.permute(0, 2, 1), gg) xyz_scatter = xyz_scatter / self.nnei xyz_scatter_1 = xyz_scatter.permute(0, 2, 1) rot_mat = xyz_scatter_1[:, :, 1:4] @@ -516,9 +626,12 @@ def forward( result = torch.matmul( xyz_scatter_1, xyz_scatter_2 ) # shape is [nframes*nloc, self.filter_neuron[-1], self.axis_neuron] + return ( result.view(nframes, nloc, self.filter_neuron[-1] * self.axis_neuron), - gg.view(nframes, nloc, self.nnei, self.filter_neuron[-1]), + gg.view(nframes, nloc, self.nnei, self.filter_neuron[-1]) + if not self.compress + else None, dmatrix.view(nframes, nloc, self.nnei, 4)[..., 1:], rot_mat.view(nframes, nloc, self.filter_neuron[-1], 3), sw, diff --git a/deepmd/pt/model/descriptor/se_r.py b/deepmd/pt/model/descriptor/se_r.py index 0aa50c613f..4a74b7671f 100644 --- a/deepmd/pt/model/descriptor/se_r.py +++ b/deepmd/pt/model/descriptor/se_r.py @@ -32,9 +32,15 @@ from deepmd.pt.utils.exclude_mask import ( PairExcludeMask, ) +from deepmd.pt.utils.tabulate import ( + DPTabulate, +) from deepmd.pt.utils.update_sel 
import ( UpdateSel, ) +from deepmd.pt.utils.utils import ( + ActivationFn, +) from deepmd.utils.data_system import ( DeepmdDataSystem, ) @@ -52,10 +58,31 @@ BaseDescriptor, ) +if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_r"): + + def tabulate_fusion_se_r( + argument0, + argument1, + argument2, + argument3, + ) -> list[torch.Tensor]: + raise NotImplementedError( + "tabulate_fusion_se_r is not available since the customized PyTorch OP library is not built when freezing the model. " + "See documentation for model compression for details." + ) + + # Note: this hack cannot actually save a model that can be run using LAMMPS. + torch.ops.deepmd.tabulate_fusion_se_r = tabulate_fusion_se_r + @BaseDescriptor.register("se_e2_r") @BaseDescriptor.register("se_r") class DescrptSeR(BaseDescriptor, torch.nn.Module): + lower: dict[str, int] + upper: dict[str, int] + table_data: dict[str, torch.Tensor] + table_config: list[Union[int, float]] + def __init__( self, rcut, @@ -90,6 +117,12 @@ def __init__( # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) self.env_protection = env_protection + # add for compression + self.compress = False + self.lower = {} + self.upper = {} + self.table_data = {} + self.table_config = [] self.sel = sel self.sec = torch.tensor( @@ -123,6 +156,7 @@ def __init__( self.filter_layers = filter_layers self.stats = None # set trainable + self.trainable = trainable for param in self.parameters(): param.requires_grad = trainable @@ -313,6 +347,51 @@ def reinit_exclude( self.exclude_types = exclude_types self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data. + + Parameters + ---------- + min_nbor_dist + The nearest distance between atoms + table_extrapolate + The scale of model extrapolation + table_stride_1 + The uniform stride of the first table + table_stride_2 + The uniform stride of the second table + check_frequency + The overflow check frequency + """ + if self.compress: + raise ValueError("Compression is already enabled.") + data = self.serialize() + self.table = DPTabulate( + self, + data["neuron"], + data["type_one_side"], + data["exclude_types"], + ActivationFn(data["activation_function"]), + ) + self.table_config = [ + table_extrapolate, + table_stride_1, + table_stride_2, + check_frequency, + ] + self.lower, self.upper = self.table.build( + min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2 + ) + self.table_data = self.table.data + self.compress = True + def forward( self, coord_ext: torch.Tensor, @@ -353,7 +432,7 @@ def forward( The smooth switch function. 
""" - del mapping + del mapping, comm_dict nf = nlist.shape[0] nloc = nlist.shape[1] atype = atype_ext[:, :nloc] @@ -380,19 +459,44 @@ def forward( # nfnl x nnei exclude_mask = self.emask(nlist, atype_ext).view(nfnl, self.nnei) + xyz_scatter_total = [] for ii, ll in enumerate(self.filter_layers.networks): # nfnl x nt mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]] # nfnl x nt x 1 ss = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :] ss = ss * mm[:, :, None] - # nfnl x nt x ng - gg = ll.forward(ss) - gg = torch.mean(gg, dim=1).unsqueeze(1) - xyz_scatter += gg * (self.sel[ii] / self.nnei) + if self.compress: + ss = ss.squeeze(-1) + net = "filter_-1_net_" + str(ii) + info = [ + self.lower[net], + self.upper[net], + self.upper[net] * self.table_config[0], + self.table_config[1], + self.table_config[2], + self.table_config[3], + ] + tensor_data = self.table_data[net].to(ss.device).to(dtype=self.prec) + xyz_scatter = torch.ops.deepmd.tabulate_fusion_se_r( + tensor_data.contiguous(), + torch.tensor(info, dtype=self.prec, device="cpu").contiguous(), + ss, + self.filter_neuron[-1], + )[0] + xyz_scatter_total.append(xyz_scatter) + else: + # nfnl x nt x ng + gg = ll.forward(ss) + gg = torch.mean(gg, dim=1).unsqueeze(1) + xyz_scatter += gg * (self.sel[ii] / self.nnei) res_rescale = 1.0 / 5.0 - result = xyz_scatter * res_rescale + if self.compress: + xyz_scatter = torch.cat(xyz_scatter_total, dim=1) + result = torch.mean(xyz_scatter, dim=1) * res_rescale + else: + result = xyz_scatter * res_rescale result = result.view(nf, nloc, self.filter_neuron[-1]) return ( result.to(dtype=env.GLOBAL_PT_FLOAT_PRECISION), diff --git a/deepmd/pt/model/descriptor/se_t.py b/deepmd/pt/model/descriptor/se_t.py index 7b83bcbd69..5a634d7549 100644 --- a/deepmd/pt/model/descriptor/se_t.py +++ b/deepmd/pt/model/descriptor/se_t.py @@ -58,11 +58,34 @@ from deepmd.pt.utils.exclude_mask import ( PairExcludeMask, ) +from deepmd.pt.utils.tabulate import ( + DPTabulate, +) +from deepmd.pt.utils.utils import ( + ActivationFn, +) from .base_descriptor import ( BaseDescriptor, ) +if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_t"): + + def tabulate_fusion_se_t( + argument0, + argument1, + argument2, + argument3, + argument4, + ) -> list[torch.Tensor]: + raise NotImplementedError( + "tabulate_fusion_se_t is not available since customized PyTorch OP library is not built when freezing the model. " + "See documentation for model compression for details." + ) + + # Note: this hack cannot actually save a model that can be runned using LAMMPS. + torch.ops.deepmd.tabulate_fusion_se_t = tabulate_fusion_se_t + @BaseDescriptor.register("se_e3") @BaseDescriptor.register("se_at") @@ -129,6 +152,7 @@ def __init__( raise NotImplementedError("old implementation of spin is not supported.") super().__init__() self.type_map = type_map + self.compress = False self.seat = DescrptBlockSeT( rcut, rcut_smth, @@ -252,6 +276,54 @@ def compute_input_stats( """ return self.seat.compute_input_stats(merged, path) + def enable_compression( + self, + min_nbor_dist: float, + table_extrapolate: float = 5, + table_stride_1: float = 0.01, + table_stride_2: float = 0.1, + check_frequency: int = -1, + ) -> None: + """Receive the statisitcs (distance, max_nbor_size and env_mat_range) of the training data. 
+ + Parameters + ---------- + min_nbor_dist + The nearest distance between atoms + table_extrapolate + The scale of model extrapolation + table_stride_1 + The uniform stride of the first table + table_stride_2 + The uniform stride of the second table + check_frequency + The overflow check frequency + """ + if self.compress: + raise ValueError("Compression is already enabled.") + data = self.serialize() + self.table = DPTabulate( + self, + data["neuron"], + exclude_types=data["exclude_types"], + activation_fn=ActivationFn(data["activation_function"]), + ) + stride_1_scaled = table_stride_1 * 10 + stride_2_scaled = table_stride_2 * 10 + self.table_config = [ + table_extrapolate, + stride_1_scaled, + stride_2_scaled, + check_frequency, + ] + self.lower, self.upper = self.table.build( + min_nbor_dist, table_extrapolate, stride_1_scaled, stride_2_scaled + ) + self.seat.enable_compression( + self.table.data, self.table_config, self.lower, self.upper + ) + self.compress = True + def reinit_exclude( self, exclude_types: list[tuple[int, int]] = [], @@ -396,6 +468,10 @@ def update_sel( class DescrptBlockSeT(DescriptorBlock): ndescrpt: Final[int] __constants__: ClassVar[list] = ["ndescrpt"] + lower: dict[str, int] + upper: dict[str, int] + table_data: dict[str, torch.Tensor] + table_config: list[Union[int, float]] def __init__( self, @@ -467,6 +543,12 @@ def __init__( self.split_sel = self.sel self.nnei = sum(sel) self.ndescrpt = self.nnei * 4 + # add for compression + self.compress = False + self.lower = {} + self.upper = {} + self.table_data = {} + self.table_config = [] wanted_shape = (self.ntypes, self.nnei, 4) mean = torch.zeros(wanted_shape, dtype=self.prec, device=env.DEVICE) @@ -628,6 +710,19 @@ def reinit_exclude( self.exclude_types = exclude_types self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types) + def enable_compression( + self, + table_data, + table_config, + lower, + upper, + ) -> None: + self.compress = True + self.table_data = table_data + self.table_config = table_config + self.lower = lower + self.upper = upper + def forward( self, nlist: torch.Tensor, @@ -711,12 +806,36 @@ def forward( rr_j = rr_j * mm_j[:, :, None] # nfnl x nt_i x nt_j env_ij = torch.einsum("ijm,ikm->ijk", rr_i, rr_j) - # nfnl x nt_i x nt_j x 1 - env_ij_reshape = env_ij.unsqueeze(-1) - # nfnl x nt_i x nt_j x ng - gg = ll.forward(env_ij_reshape) - # nfnl x nt_i x nt_j x ng - res_ij = torch.einsum("ijk,ijkm->im", env_ij, gg) + if self.compress: + ebd_env_ij = env_ij.view(-1, 1) + net = "filter_" + str(ti) + "_net_" + str(tj) + info = [ + self.lower[net], + self.upper[net], + self.upper[net] * self.table_config[0], + self.table_config[1], + self.table_config[2], + self.table_config[3], + ] + tensor_data = ( + self.table_data[net].to(env_ij.device).to(dtype=self.prec) + ) + ebd_env_ij = ebd_env_ij.to(dtype=self.prec) + env_ij = env_ij.to(dtype=self.prec) + res_ij = torch.ops.deepmd.tabulate_fusion_se_t( + tensor_data.contiguous(), + torch.tensor(info, dtype=self.prec, device="cpu").contiguous(), + ebd_env_ij.contiguous(), + env_ij.contiguous(), + self.filter_neuron[-1], + )[0] + else: + # nfnl x nt_i x nt_j x 1 + env_ij_reshape = env_ij.unsqueeze(-1) + # nfnl x nt_i x nt_j x ng + gg = ll.forward(env_ij_reshape) + # nfnl x nt_i x nt_j x ng + res_ij = torch.einsum("ijk,ijkm->im", env_ij, gg) res_ij = res_ij * (1.0 / float(nei_type_i) / float(nei_type_j)) result += res_ij # xyz_scatter /= (self.nnei * self.nnei) diff --git a/deepmd/pt/utils/tabulate.py b/deepmd/pt/utils/tabulate.py new file mode 
100644 index 0000000000..7394ac082d --- /dev/null +++ b/deepmd/pt/utils/tabulate.py @@ -0,0 +1,607 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from functools import ( + cached_property, +) + +import numpy as np +import torch + +import deepmd +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + ActivationFn, +) +from deepmd.utils.tabulate import ( + BaseTabulate, +) + +log = logging.getLogger(__name__) + +SQRT_2_PI = np.sqrt(2 / np.pi) +GGELU = 0.044715 + + +class DPTabulate(BaseTabulate): + r"""Class for tabulation. + + Compress a model, which includes tabulating the embedding-net. + The table is composed of fifth-order polynomial coefficients and is assembled from two sub-tables. The first table takes the stride (parameter) as its uniform stride, while the second table takes 10 * stride as its uniform stride. + The range of the first table is automatically detected by deepmd-kit, while the second table ranges from the first table's upper boundary (upper) to extrapolate (parameter) * upper. + + Parameters + ---------- + descrpt + Descriptor of the original model + neuron + Number of neurons in each hidden layer of the embedding net :math:`\\mathcal{N}` + type_one_side + Try to build N_types tables. Otherwise, build N_types^2 tables + exclude_types : List[List[int]] + The excluded pairs of types which have no interaction with each other. + For example, `[[0, 1]]` means no interaction between type 0 and type 1. + activation_function + The activation function in the embedding net. Supported options are {"tanh","gelu"} in common.ActivationFn. + """ + + def __init__( + self, + descrpt, + neuron: list[int], + type_one_side: bool = False, + exclude_types: list[list[int]] = [], + activation_fn: ActivationFn = ActivationFn("tanh"), + ) -> None: + super().__init__( + descrpt, + neuron, + type_one_side, + exclude_types, + True, + ) + self.descrpt_type = self._get_descrpt_type() + + supported_descrpt_type = ( + "Atten", + "A", + "T", + "R", + ) + + if self.descrpt_type in supported_descrpt_type: + self.sel_a = self.descrpt.get_sel() + self.rcut = self.descrpt.get_rcut() + self.rcut_smth = self.descrpt.get_rcut_smth() + else: + raise RuntimeError("Unsupported descriptor") + + # functype + activation_map = { + "tanh": 1, + "gelu": 2, + "gelu_tf": 2, + "relu": 3, + "relu6": 4, + "softplus": 5, + "sigmoid": 6, + } + + activation = activation_fn.activation + if activation in activation_map: + self.functype = activation_map[activation] + else: + raise RuntimeError("Unknown activation function type!") + + self.activation_fn = activation_fn + self.davg = self.descrpt.serialize()["@variables"]["davg"] + self.dstd = self.descrpt.serialize()["@variables"]["dstd"] + self.ntypes = self.descrpt.get_ntypes() + + self.embedding_net_nodes = self.descrpt.serialize()["embeddings"]["networks"] + + self.layer_size = self._get_layer_size() + self.table_size = self._get_table_size() + + self.bias = self._get_bias() + self.matrix = self._get_matrix() + + self.data_type = self._get_data_type() + self.last_layer_size = self._get_last_layer_size() + + def _make_data(self, xx, idx): + """Generate tabulation data for the given input. 
+ + Parameters + ---------- + xx : np.ndarray + Input values to tabulate + idx : int + Index for accessing the correct network parameters + + Returns + ------- + tuple[np.ndarray, np.ndarray, np.ndarray] + Values, first derivatives, and second derivatives + """ + xx = torch.from_numpy(xx).view(-1, 1).to(env.DEVICE) + for layer in range(self.layer_size): + if layer == 0: + xbar = torch.matmul( + xx, + torch.from_numpy(self.matrix["layer_" + str(layer + 1)][idx]).to( + env.DEVICE + ), + ) + torch.from_numpy(self.bias["layer_" + str(layer + 1)][idx]).to( + env.DEVICE + ) + if self.neuron[0] == 1: + yy = ( + self._layer_0( + xx, + self.matrix["layer_" + str(layer + 1)][idx], + self.bias["layer_" + str(layer + 1)][idx], + ) + + xx + ) + dy = unaggregated_dy_dx_s( + yy - xx, + self.matrix["layer_" + str(layer + 1)][idx], + xbar, + self.functype, + ) + torch.ones((1, 1), dtype=yy.dtype) # pylint: disable=no-explicit-device + dy2 = unaggregated_dy2_dx_s( + yy - xx, + dy, + self.matrix["layer_" + str(layer + 1)][idx], + xbar, + self.functype, + ) + elif self.neuron[0] == 2: + tt, yy = self._layer_1( + xx, + self.matrix["layer_" + str(layer + 1)][idx], + self.bias["layer_" + str(layer + 1)][idx], + ) + dy = unaggregated_dy_dx_s( + yy - tt, + self.matrix["layer_" + str(layer + 1)][idx], + xbar, + self.functype, + ) + torch.ones((1, 2), dtype=yy.dtype) # pylint: disable=no-explicit-device + dy2 = unaggregated_dy2_dx_s( + yy - tt, + dy, + self.matrix["layer_" + str(layer + 1)][idx], + xbar, + self.functype, + ) + else: + yy = self._layer_0( + xx, + self.matrix["layer_" + str(layer + 1)][idx], + self.bias["layer_" + str(layer + 1)][idx], + ) + dy = unaggregated_dy_dx_s( + yy, + self.matrix["layer_" + str(layer + 1)][idx], + xbar, + self.functype, + ) + dy2 = unaggregated_dy2_dx_s( + yy, + dy, + self.matrix["layer_" + str(layer + 1)][idx], + xbar, + self.functype, + ) + else: + ybar = torch.matmul( + yy, + torch.from_numpy(self.matrix["layer_" + str(layer + 1)][idx]).to( + env.DEVICE + ), + ) + torch.from_numpy(self.bias["layer_" + str(layer + 1)][idx]).to( + env.DEVICE + ) + if self.neuron[layer] == self.neuron[layer - 1]: + zz = ( + self._layer_0( + yy, + self.matrix["layer_" + str(layer + 1)][idx], + self.bias["layer_" + str(layer + 1)][idx], + ) + + yy + ) + dz = unaggregated_dy_dx( + zz - yy, + self.matrix["layer_" + str(layer + 1)][idx], + dy, + ybar, + self.functype, + ) + dy2 = unaggregated_dy2_dx( + zz - yy, + self.matrix["layer_" + str(layer + 1)][idx], + dy, + dy2, + ybar, + self.functype, + ) + elif self.neuron[layer] == 2 * self.neuron[layer - 1]: + tt, zz = self._layer_1( + yy, + self.matrix["layer_" + str(layer + 1)][idx], + self.bias["layer_" + str(layer + 1)][idx], + ) + dz = unaggregated_dy_dx( + zz - tt, + self.matrix["layer_" + str(layer + 1)][idx], + dy, + ybar, + self.functype, + ) + dy2 = unaggregated_dy2_dx( + zz - tt, + self.matrix["layer_" + str(layer + 1)][idx], + dy, + dy2, + ybar, + self.functype, + ) + else: + zz = self._layer_0( + yy, + self.matrix["layer_" + str(layer + 1)][idx], + self.bias["layer_" + str(layer + 1)][idx], + ) + dz = unaggregated_dy_dx( + zz, + self.matrix["layer_" + str(layer + 1)][idx], + dy, + ybar, + self.functype, + ) + dy2 = unaggregated_dy2_dx( + zz, + self.matrix["layer_" + str(layer + 1)][idx], + dy, + dy2, + ybar, + self.functype, + ) + dy = dz + yy = zz + + vv = zz.detach().cpu().numpy().astype(self.data_type) + dd = dy.detach().cpu().numpy().astype(self.data_type) + d2 = dy2.detach().cpu().numpy().astype(self.data_type) + return vv, dd, d2 + + 
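[For reference: the value/first-derivative/second-derivative triples returned by `_make_data` are packed by `_build_lower` (see `deepmd/utils/tabulate.py` below) into six quintic-spline coefficients per grid point and output neuron. The following is a minimal NumPy sketch of the lookup those tables enable; the helper name, the lower-bound clamping, and the missing far-end overflow check are illustrative assumptions, and the production path is the compiled `tabulate_fusion_se_*` custom op, not Python.]

import numpy as np

def eval_spline_table(table, lower, upper, extrapolate, stride0, stride1, x):
    """Sketch: evaluate one tabulated embedding net at a scalar input x."""
    nout = table.shape[1] // 6  # six polynomial coefficients per output neuron
    x = max(x, lower)  # simplification: clamp inputs below the table start
    if x < upper:
        # dense region [lower, upper), tabulated with uniform stride0
        k = int((x - lower) / stride0)
        t = x - (lower + k * stride0)
    else:
        # extrapolation region [upper, extrapolate * upper), uniform stride1;
        # the real op additionally guards against overflow past the table end
        n0 = int((upper - lower) / stride0)
        k = n0 + int((x - upper) / stride1)
        t = x - (upper + (k - n0) * stride1)
    a = table[k].reshape(nout, 6)  # per-neuron coefficients [a0, ..., a5]
    # Horner form of a0 + a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5
    return a[:, 0] + t * (
        a[:, 1] + t * (a[:, 2] + t * (a[:, 3] + t * (a[:, 4] + t * a[:, 5])))
    )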
def _layer_0(self, x, w, b): + w = torch.from_numpy(w).to(env.DEVICE) + b = torch.from_numpy(b).to(env.DEVICE) + return self.activation_fn(torch.matmul(x, w) + b) + + def _layer_1(self, x, w, b): + w = torch.from_numpy(w).to(env.DEVICE) + b = torch.from_numpy(b).to(env.DEVICE) + t = torch.cat([x, x], dim=1) + return t, self.activation_fn(torch.matmul(x, w) + b) + t + + def _get_descrpt_type(self): + if isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptDPA1): + return "Atten" + elif isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptSeA): + return "A" + elif isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptSeR): + return "R" + elif isinstance(self.descrpt, deepmd.pt.model.descriptor.DescrptSeT): + return "T" + raise RuntimeError(f"Unsupported descriptor {self.descrpt}") + + def _get_layer_size(self): + # get the number of layers in EmbeddingNet + layer_size = 0 + basic_size = 0 + if self.type_one_side: + basic_size = len(self.embedding_net_nodes) * len(self.neuron) + else: + basic_size = ( + len(self.embedding_net_nodes) + * len(self.embedding_net_nodes[0]) + * len(self.neuron) + ) + if self.descrpt_type == "Atten": + layer_size = len(self.embedding_net_nodes[0]["layers"]) + elif self.descrpt_type == "A": + layer_size = len(self.embedding_net_nodes[0]["layers"]) + if self.type_one_side: + layer_size = basic_size // (self.ntypes - self._n_all_excluded) + elif self.descrpt_type == "T": + layer_size = len(self.embedding_net_nodes[0]["layers"]) + # layer_size = basic_size // int(comb(self.ntypes + 1, 2)) + elif self.descrpt_type == "R": + layer_size = basic_size // ( + self.ntypes * self.ntypes - len(self.exclude_types) + ) + if self.type_one_side: + layer_size = basic_size // (self.ntypes - self._n_all_excluded) + else: + raise RuntimeError("Unsupported descriptor") + return layer_size + + def _get_network_variable(self, var_name: str) -> dict: + """Get network variables (weights or biases) for all layers. 
+ + Parameters + ---------- + var_name : str + Name of the variable to get ('w' for weights, 'b' for biases) + + Returns + ------- + dict + Dictionary mapping layer names to their variables + """ + result = {} + for layer in range(1, self.layer_size + 1): + result["layer_" + str(layer)] = [] + if self.descrpt_type == "Atten": + node = self.embedding_net_nodes[0]["layers"][layer - 1]["@variables"][ + var_name + ] + result["layer_" + str(layer)].append(node) + elif self.descrpt_type == "A": + if self.type_one_side: + for ii in range(0, self.ntypes): + if not self._all_excluded(ii): + node = self.embedding_net_nodes[ii]["layers"][layer - 1][ + "@variables" + ][var_name] + result["layer_" + str(layer)].append(node) + else: + result["layer_" + str(layer)].append(np.array([])) + else: + for ii in range(0, self.ntypes * self.ntypes): + if ( + ii // self.ntypes, + ii % self.ntypes, + ) not in self.exclude_types: + node = self.embedding_net_nodes[ + (ii % self.ntypes) * self.ntypes + ii // self.ntypes + ]["layers"][layer - 1]["@variables"][var_name] + result["layer_" + str(layer)].append(node) + else: + result["layer_" + str(layer)].append(np.array([])) + elif self.descrpt_type == "T": + for ii in range(self.ntypes): + for jj in range(ii, self.ntypes): + node = self.embedding_net_nodes[jj * self.ntypes + ii][ + "layers" + ][layer - 1]["@variables"][var_name] + result["layer_" + str(layer)].append(node) + elif self.descrpt_type == "R": + if self.type_one_side: + for ii in range(0, self.ntypes): + if not self._all_excluded(ii): + node = self.embedding_net_nodes[ii]["layers"][layer - 1][ + "@variables" + ][var_name] + result["layer_" + str(layer)].append(node) + else: + result["layer_" + str(layer)].append(np.array([])) + else: + for ii in range(0, self.ntypes * self.ntypes): + if ( + ii // self.ntypes, + ii % self.ntypes, + ) not in self.exclude_types: + node = self.embedding_net_nodes[ + (ii % self.ntypes) * self.ntypes + ii // self.ntypes + ]["layers"][layer - 1]["@variables"][var_name] + result["layer_" + str(layer)].append(node) + else: + result["layer_" + str(layer)].append(np.array([])) + else: + raise RuntimeError("Unsupported descriptor") + return result + + def _get_bias(self): + return self._get_network_variable("b") + + def _get_matrix(self): + return self._get_network_variable("w") + + def _convert_numpy_to_tensor(self): + """Convert self.data from np.ndarray to torch.Tensor.""" + for ii in self.data: + self.data[ii] = torch.tensor(self.data[ii], device=env.DEVICE) # pylint: disable=no-explicit-dtype + + @cached_property + def _n_all_excluded(self) -> int: + """The number of types that exclude all types.""" + return sum(int(self._all_excluded(ii)) for ii in range(0, self.ntypes)) + + +# customized op +def grad(xbar, y, functype): # functype=tanh, gelu, .. 
+ if functype == 1: + return 1 - y * y + elif functype == 2: + var = np.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3)) + return ( + 0.5 * SQRT_2_PI * xbar * (1 - var**2) * (3 * GGELU * xbar**2 + 1) + + 0.5 * var + + 0.5 + ) + elif functype == 3: + return 0.0 if xbar <= 0 else 1.0 + elif functype == 4: + return 0.0 if xbar <= 0 or xbar >= 6 else 1.0 + elif functype == 5: + return 1.0 - 1.0 / (1.0 + np.exp(xbar)) + elif functype == 6: + return y * (1 - y) + + raise ValueError(f"Unsupported function type: {functype}") + + +def grad_grad(xbar, y, functype): + if functype == 1: + return -2 * y * (1 - y * y) + elif functype == 2: + var1 = np.tanh(SQRT_2_PI * (xbar + GGELU * xbar**3)) + var2 = SQRT_2_PI * (1 - var1**2) * (3 * GGELU * xbar**2 + 1) + return ( + 3 * GGELU * SQRT_2_PI * xbar**2 * (1 - var1**2) + - SQRT_2_PI * xbar * var2 * (3 * GGELU * xbar**2 + 1) * var1 + + var2 + ) + elif functype in [3, 4]: + return 0 + elif functype == 5: + return np.exp(xbar) / ((1 + np.exp(xbar)) * (1 + np.exp(xbar))) + elif functype == 6: + return y * (1 - y) * (1 - 2 * y) + else: + return -1 + + +def unaggregated_dy_dx_s( + y: torch.Tensor, w_np: np.ndarray, xbar: torch.Tensor, functype: int +): + w = torch.from_numpy(w_np).to(env.DEVICE) + if y.dim() != 2: + raise ValueError("Dim of input y should be 2") + if w.dim() != 2: + raise ValueError("Dim of input w should be 2") + if xbar.dim() != 2: + raise ValueError("Dim of input xbar should be 2") + + length, width = y.shape + dy_dx = torch.zeros_like(y) + w = torch.flatten(w) + + for ii in range(length): + for jj in range(width): + dy_dx[ii, jj] = grad(xbar[ii, jj], y[ii, jj], functype) * w[jj] + + return dy_dx + + +def unaggregated_dy2_dx_s( + y: torch.Tensor, + dy: torch.Tensor, + w_np: np.ndarray, + xbar: torch.Tensor, + functype: int, +): + w = torch.from_numpy(w_np).to(env.DEVICE) + if y.dim() != 2: + raise ValueError("Dim of input y should be 2") + if dy.dim() != 2: + raise ValueError("Dim of input dy should be 2") + if w.dim() != 2: + raise ValueError("Dim of input w should be 2") + if xbar.dim() != 2: + raise ValueError("Dim of input xbar should be 2") + + length, width = y.shape + dy2_dx = torch.zeros_like(y) + w = torch.flatten(w) + + for ii in range(length): + for jj in range(width): + dy2_dx[ii, jj] = ( + grad_grad(xbar[ii, jj], y[ii, jj], functype) * w[jj] * w[jj] + ) + + return dy2_dx + + +def unaggregated_dy_dx( + z: torch.Tensor, + w_np: np.ndarray, + dy_dx: torch.Tensor, + ybar: torch.Tensor, + functype: int, +): + w = torch.from_numpy(w_np).to(env.DEVICE) + if z.dim() != 2: + raise ValueError("z tensor must have 2 dimensions") + if w.dim() != 2: + raise ValueError("w tensor must have 2 dimensions") + if dy_dx.dim() != 2: + raise ValueError("dy_dx tensor must have 2 dimensions") + if ybar.dim() != 2: + raise ValueError("ybar tensor must have 2 dimensions") + + length, width = z.shape + size = w.shape[0] + dy_dx = torch.flatten(dy_dx) + + dz_dx = torch.zeros_like(z) + + for kk in range(length): + for ii in range(width): + dz_drou = grad(ybar[kk, ii], z[kk, ii], functype) + accumulator = 0.0 + for jj in range(size): + accumulator += w[jj, ii] * dy_dx[kk * size + jj] + dz_drou *= accumulator + if width == 2 * size or width == size: + dz_drou += dy_dx[kk * size + ii % size] + dz_dx[kk, ii] = dz_drou + + return dz_dx + + +def unaggregated_dy2_dx( + z: torch.Tensor, + w_np: np.ndarray, + dy_dx: torch.Tensor, + dy2_dx: torch.Tensor, + ybar: torch.Tensor, + functype: int, +): + w = torch.from_numpy(w_np).to(env.DEVICE) + if z.dim() != 2: + raise 
ValueError("z tensor must have 2 dimensions") + if w.dim() != 2: + raise ValueError("w tensor must have 2 dimensions") + if dy_dx.dim() != 2: + raise ValueError("dy_dx tensor must have 2 dimensions") + if dy2_dx.dim() != 2: + raise ValueError("dy2_dx tensor must have 2 dimensions") + if ybar.dim() != 2: + raise ValueError("ybar tensor must have 2 dimensions") + + length, width = z.shape + size = w.shape[0] + dy_dx = torch.flatten(dy_dx) + dy2_dx = torch.flatten(dy2_dx) + + dz2_dx = torch.zeros_like(z) + + for kk in range(length): + for ii in range(width): + dz_drou = grad(ybar[kk, ii], z[kk, ii], functype) + accumulator1 = 0.0 + for jj in range(size): + accumulator1 += w[jj, ii] * dy2_dx[kk * size + jj] + dz_drou *= accumulator1 + accumulator2 = 0.0 + for jj in range(size): + accumulator2 += w[jj, ii] * dy_dx[kk * size + jj] + dz_drou += ( + grad_grad(ybar[kk, ii], z[kk, ii], functype) + * accumulator2 + * accumulator2 + ) + if width == 2 * size or width == size: + dz_drou += dy2_dx[kk * size + ii % size] + dz2_dx[kk, ii] = dz_drou + + return dz2_dx diff --git a/deepmd/tf/utils/tabulate.py b/deepmd/tf/utils/tabulate.py index 588ebdd55e..30171b12db 100644 --- a/deepmd/tf/utils/tabulate.py +++ b/deepmd/tf/utils/tabulate.py @@ -2,7 +2,6 @@ import logging from functools import ( cached_property, - lru_cache, ) from typing import ( Callable, @@ -28,11 +27,14 @@ get_embedding_net_nodes_from_graph_def, get_tensor_by_name_from_graph, ) +from deepmd.utils.tabulate import ( + BaseTabulate, +) log = logging.getLogger(__name__) -class DPTabulate: +class DPTabulate(BaseTabulate): r"""Class for tabulation. Compress a model, which including tabulating the embedding-net. @@ -71,13 +73,18 @@ def __init__( activation_fn: Callable[[tf.Tensor], tf.Tensor] = tf.nn.tanh, suffix: str = "", ) -> None: + super().__init__( + descrpt, + neuron, + type_one_side, + exclude_types, + False, + ) + + self.descrpt_type = self._get_descrpt_type() """Constructor.""" - self.descrpt = descrpt - self.neuron = neuron self.graph = graph self.graph_def = graph_def - self.type_one_side = type_one_side - self.exclude_types = exclude_types self.suffix = suffix # functype @@ -156,271 +163,25 @@ def __init__( self.upper = {} self.lower = {} - def build( - self, min_nbor_dist: float, extrapolate: float, stride0: float, stride1: float - ) -> tuple[dict[str, int], dict[str, int]]: - r"""Build the tables for model compression. 
- - Parameters - ---------- - min_nbor_dist - The nearest distance between neighbor atoms - extrapolate - The scale of model extrapolation - stride0 - The uniform stride of the first table - stride1 - The uniform stride of the second table - - Returns - ------- - lower : dict[str, int] - The lower boundary of environment matrix by net - upper : dict[str, int] - The upper boundary of environment matrix by net - """ - # tabulate range [lower, upper] with stride0 'stride0' - lower, upper = self._get_env_mat_range(min_nbor_dist) - if isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeAtten) or isinstance( - self.descrpt, deepmd.tf.descriptor.DescrptSeAEbdV2 - ): - uu = np.max(upper) - ll = np.min(lower) - xx = np.arange(ll, uu, stride0, dtype=self.data_type) - xx = np.append( - xx, - np.arange(uu, extrapolate * uu, stride1, dtype=self.data_type), - ) - xx = np.append(xx, np.array([extrapolate * uu], dtype=self.data_type)) - nspline = ((uu - ll) / stride0 + (extrapolate * uu - uu) / stride1).astype( - int - ) - self._build_lower( - "filter_net", xx, 0, uu, ll, stride0, stride1, extrapolate, nspline - ) - elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeA): - for ii in range(self.table_size): - if (self.type_one_side and not self._all_excluded(ii)) or ( - not self.type_one_side - and (ii // self.ntypes, ii % self.ntypes) not in self.exclude_types - ): - if self.type_one_side: - net = "filter_-1_net_" + str(ii) - # upper and lower should consider all types which are not excluded and sel>0 - idx = [ - (type_i, ii) not in self.exclude_types - and self.sel_a[type_i] > 0 - for type_i in range(self.ntypes) - ] - uu = np.max(upper[idx]) - ll = np.min(lower[idx]) - else: - ielement = ii // self.ntypes - net = ( - "filter_" + str(ielement) + "_net_" + str(ii % self.ntypes) - ) - uu = upper[ielement] - ll = lower[ielement] - xx = np.arange(ll, uu, stride0, dtype=self.data_type) - xx = np.append( - xx, - np.arange(uu, extrapolate * uu, stride1, dtype=self.data_type), - ) - xx = np.append( - xx, np.array([extrapolate * uu], dtype=self.data_type) - ) - nspline = ( - (uu - ll) / stride0 + (extrapolate * uu - uu) / stride1 - ).astype(int) - self._build_lower( - net, xx, ii, uu, ll, stride0, stride1, extrapolate, nspline - ) - elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeT): - xx_all = [] - for ii in range(self.ntypes): - xx = np.arange( - extrapolate * lower[ii], lower[ii], stride1, dtype=self.data_type - ) - xx = np.append( - xx, np.arange(lower[ii], upper[ii], stride0, dtype=self.data_type) - ) - xx = np.append( - xx, - np.arange( - upper[ii], - extrapolate * upper[ii], - stride1, - dtype=self.data_type, - ), - ) - xx = np.append( - xx, np.array([extrapolate * upper[ii]], dtype=self.data_type) - ) - xx_all.append(xx) - nspline = ( - (upper - lower) / stride0 - + 2 * ((extrapolate * upper - upper) / stride1) - ).astype(int) - idx = 0 - for ii in range(self.ntypes): - for jj in range(ii, self.ntypes): - net = "filter_" + str(ii) + "_net_" + str(jj) - self._build_lower( - net, - xx_all[ii], - idx, - upper[ii], - lower[ii], - stride0, - stride1, - extrapolate, - nspline[ii], - ) - idx += 1 - elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeR): - for ii in range(self.table_size): - if (self.type_one_side and not self._all_excluded(ii)) or ( - not self.type_one_side - and (ii // self.ntypes, ii % self.ntypes) not in self.exclude_types - ): - if self.type_one_side: - net = "filter_-1_net_" + str(ii) - # upper and lower should consider all types which are not excluded and 
sel>0 - idx = [ - (type_i, ii) not in self.exclude_types - and self.sel_a[type_i] > 0 - for type_i in range(self.ntypes) - ] - uu = np.max(upper[idx]) - ll = np.min(lower[idx]) - else: - ielement = ii // self.ntypes - net = ( - "filter_" + str(ielement) + "_net_" + str(ii % self.ntypes) - ) - uu = upper[ielement] - ll = lower[ielement] - xx = np.arange(ll, uu, stride0, dtype=self.data_type) - xx = np.append( - xx, - np.arange(uu, extrapolate * uu, stride1, dtype=self.data_type), - ) - xx = np.append( - xx, np.array([extrapolate * uu], dtype=self.data_type) - ) - nspline = ( - (uu - ll) / stride0 + (extrapolate * uu - uu) / stride1 - ).astype(int) - self._build_lower( - net, xx, ii, uu, ll, stride0, stride1, extrapolate, nspline - ) - else: - raise RuntimeError("Unsupported descriptor") - self._convert_numpy_to_tensor() - - return self.lower, self.upper - - def _build_lower( - self, net, xx, idx, upper, lower, stride0, stride1, extrapolate, nspline - ): - vv, dd, d2 = self._make_data(xx, idx) - self.data[net] = np.zeros( - [nspline, 6 * self.last_layer_size], dtype=self.data_type - ) - - # tt.shape: [nspline, self.last_layer_size] - if isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeA): - tt = np.full((nspline, self.last_layer_size), stride1) # pylint: disable=no-explicit-dtype - tt[: int((upper - lower) / stride0), :] = stride0 - elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeT): - tt = np.full((nspline, self.last_layer_size), stride1) # pylint: disable=no-explicit-dtype - tt[ - int((lower - extrapolate * lower) / stride1) + 1 : ( - int((lower - extrapolate * lower) / stride1) - + int((upper - lower) / stride0) - ), - :, - ] = stride0 - elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeR): - tt = np.full((nspline, self.last_layer_size), stride1) # pylint: disable=no-explicit-dtype - tt[: int((upper - lower) / stride0), :] = stride0 - else: - raise RuntimeError("Unsupported descriptor") - - # hh.shape: [nspline, self.last_layer_size] - hh = ( - vv[1 : nspline + 1, : self.last_layer_size] - - vv[:nspline, : self.last_layer_size] - ) - - self.data[net][:, : 6 * self.last_layer_size : 6] = vv[ - :nspline, : self.last_layer_size - ] - self.data[net][:, 1 : 6 * self.last_layer_size : 6] = dd[ - :nspline, : self.last_layer_size - ] - self.data[net][:, 2 : 6 * self.last_layer_size : 6] = ( - 0.5 * d2[:nspline, : self.last_layer_size] - ) - self.data[net][:, 3 : 6 * self.last_layer_size : 6] = ( - 1 / (2 * tt * tt * tt) - ) * ( - 20 * hh - - ( - 8 * dd[1 : nspline + 1, : self.last_layer_size] - + 12 * dd[:nspline, : self.last_layer_size] - ) - * tt - - ( - 3 * d2[:nspline, : self.last_layer_size] - - d2[1 : nspline + 1, : self.last_layer_size] - ) - * tt - * tt - ) - self.data[net][:, 4 : 6 * self.last_layer_size : 6] = ( - 1 / (2 * tt * tt * tt * tt) - ) * ( - -30 * hh - + ( - 14 * dd[1 : nspline + 1, : self.last_layer_size] - + 16 * dd[:nspline, : self.last_layer_size] - ) - * tt - + ( - 3 * d2[:nspline, : self.last_layer_size] - - 2 * d2[1 : nspline + 1, : self.last_layer_size] - ) - * tt - * tt - ) - self.data[net][:, 5 : 6 * self.last_layer_size : 6] = ( - 1 / (2 * tt * tt * tt * tt * tt) - ) * ( - 12 * hh - - 6 - * ( - dd[1 : nspline + 1, : self.last_layer_size] - + dd[:nspline, : self.last_layer_size] - ) - * tt - + ( - d2[1 : nspline + 1, : self.last_layer_size] - - d2[:nspline, : self.last_layer_size] - ) - * tt - * tt - ) - - self.upper[net] = upper - self.lower[net] = lower - def _load_sub_graph(self): sub_graph_def = tf.GraphDef() with 
tf.Graph().as_default() as sub_graph: tf.import_graph_def(sub_graph_def, name="") return sub_graph, sub_graph_def + def _get_descrpt_type(self): + if isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeAtten): + return "Atten" + elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeAEbdV2): + return "AEbdV2" + elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeA): + return "A" + elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeT): + return "T" + elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeR): + return "R" + raise RuntimeError(f"Unsupported descriptor {self.descrpt}") + def _get_bias(self): bias = {} for layer in range(1, self.layer_size + 1): @@ -711,36 +472,6 @@ def _layer_1(self, x, w, b): t = tf.concat([x, x], axis=1) return t, self.activation_fn(tf.matmul(x, w) + b) + t - # Change the embedding net range to sw / min_nbor_dist - def _get_env_mat_range(self, min_nbor_dist): - sw = self._spline5_switch(min_nbor_dist, self.rcut_smth, self.rcut) - if isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeA): - lower = -self.davg[:, 0] / self.dstd[:, 0] - upper = ((1 / min_nbor_dist) * sw - self.davg[:, 0]) / self.dstd[:, 0] - elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeT): - var = np.square(sw / (min_nbor_dist * self.dstd[:, 1:4])) - lower = np.min(-var, axis=1) - upper = np.max(var, axis=1) - elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeR): - lower = -self.davg[:, 0] / self.dstd[:, 0] - upper = ((1 / min_nbor_dist) * sw - self.davg[:, 0]) / self.dstd[:, 0] - else: - raise RuntimeError("Unsupported descriptor") - log.info("training data with lower boundary: " + str(lower)) - log.info("training data with upper boundary: " + str(upper)) - # returns element-wise lower and upper - return np.floor(lower), np.ceil(upper) - - def _spline5_switch(self, xx, rmin, rmax): - if xx < rmin: - vv = 1 - elif xx < rmax: - uu = (xx - rmin) / (rmax - rmin) - vv = uu * uu * uu * (-6 * uu * uu + 15 * uu - 10) + 1 - else: - vv = 0 - return vv - def _get_layer_size(self): layer_size = 0 if isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeAtten) or isinstance( @@ -776,54 +507,6 @@ def _n_all_excluded(self) -> int: """Then number of types excluding all types.""" return sum(int(self._all_excluded(ii)) for ii in range(0, self.ntypes)) - @lru_cache - def _all_excluded(self, ii: int) -> bool: - """Check if type ii excluds all types. 
- - Parameters - ---------- - ii : int - type index - - Returns - ------- - bool - if type ii excluds all types - """ - return all((ii, type_i) in self.exclude_types for type_i in range(self.ntypes)) - - def _get_table_size(self): - table_size = 0 - if isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeAtten) or isinstance( - self.descrpt, deepmd.tf.descriptor.DescrptSeAEbdV2 - ): - table_size = 1 - elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeA): - table_size = self.ntypes * self.ntypes - if self.type_one_side: - table_size = self.ntypes - elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeT): - table_size = int(comb(self.ntypes + 1, 2)) - elif isinstance(self.descrpt, deepmd.tf.descriptor.DescrptSeR): - table_size = self.ntypes * self.ntypes - if self.type_one_side: - table_size = self.ntypes - else: - raise RuntimeError("Unsupported descriptor") - return table_size - - def _get_data_type(self): - for item in self.matrix["layer_" + str(self.layer_size)]: - if len(item) != 0: - return type(item[0][0]) - return None - - def _get_last_layer_size(self): - for item in self.matrix["layer_" + str(self.layer_size)]: - if len(item) != 0: - return item.shape[1] - return 0 - def _convert_numpy_to_tensor(self): """Convert self.data from np.ndarray to tf.Tensor.""" for ii in self.data: diff --git a/deepmd/utils/tabulate.py b/deepmd/utils/tabulate.py new file mode 100644 index 0000000000..545b265b88 --- /dev/null +++ b/deepmd/utils/tabulate.py @@ -0,0 +1,458 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from abc import ( + ABC, + abstractmethod, +) +from functools import ( + lru_cache, +) + +import numpy as np +from scipy.special import ( + comb, +) + +log = logging.getLogger(__name__) + + +class BaseTabulate(ABC): + """A base class for pt and tf tabulation.""" + + def __init__( + self, + descrpt, + neuron, + type_one_side, + exclude_types, + is_pt, + ) -> None: + """Constructor.""" + super().__init__() + + """Shared attributes.""" + self.descrpt = descrpt + self.neuron = neuron + self.type_one_side = type_one_side + self.exclude_types = exclude_types + self.is_pt = is_pt + + """Need to be initialized in the subclass.""" + self.descrpt_type = "Base" + + self.sel_a = [] + self.rcut = 0.0 + self.rcut_smth = 0.0 + + self.davg = np.array([]) + self.dstd = np.array([]) + self.ntypes = 0 + + self.layer_size = 0 + self.table_size = 0 + + self.bias = {} + self.matrix = {} + + self.data_type = None + self.last_layer_size = 0 + + """Save the tabulation result.""" + self.data = {} + + self.upper = {} + self.lower = {} + + def build( + self, min_nbor_dist: float, extrapolate: float, stride0: float, stride1: float + ) -> tuple[dict[str, int], dict[str, int]]: + r"""Build the tables for model compression. 
+ + Parameters + ---------- + min_nbor_dist + The nearest distance between neighbor atoms + extrapolate + The scale of model extrapolation + stride0 + The uniform stride of the first table + stride1 + The uniform stride of the second table + + Returns + ------- + lower : dict[str, int] + The lower boundary of the environment matrix, keyed by net name + upper : dict[str, int] + The upper boundary of the environment matrix, keyed by net name + """ + # tabulate the range [lower, upper] with stride 'stride0' + lower, upper = self._get_env_mat_range(min_nbor_dist) + if self.descrpt_type in ("Atten", "AEbdV2"): + uu = np.max(upper) + ll = np.min(lower) + xx = np.arange(ll, uu, stride0, dtype=self.data_type) + xx = np.append( + xx, + np.arange(uu, extrapolate * uu, stride1, dtype=self.data_type), + ) + xx = np.append(xx, np.array([extrapolate * uu], dtype=self.data_type)) + nspline = ((uu - ll) / stride0 + (extrapolate * uu - uu) / stride1).astype( + int + ) + self._build_lower( + "filter_net", xx, 0, uu, ll, stride0, stride1, extrapolate, nspline + ) + elif self.descrpt_type == "A": + for ii in range(self.table_size): + if (self.type_one_side and not self._all_excluded(ii)) or ( + not self.type_one_side + and (ii // self.ntypes, ii % self.ntypes) not in self.exclude_types + ): + if self.type_one_side: + net = "filter_-1_net_" + str(ii) + # upper and lower should consider all types which are not excluded and sel>0 + idx = [ + (type_i, ii) not in self.exclude_types + and self.sel_a[type_i] > 0 + for type_i in range(self.ntypes) + ] + uu = np.max(upper[idx]) + ll = np.min(lower[idx]) + else: + ielement = ii // self.ntypes + net = ( + "filter_" + str(ielement) + "_net_" + str(ii % self.ntypes) + ) + if self.is_pt: + uu = np.max(upper[ielement]) + ll = np.min(lower[ielement]) + else: + uu = upper[ielement] + ll = lower[ielement] + xx = np.arange(ll, uu, stride0, dtype=self.data_type) + xx = np.append( + xx, + np.arange(uu, extrapolate * uu, stride1, dtype=self.data_type), + ) + xx = np.append( + xx, np.array([extrapolate * uu], dtype=self.data_type) + ) + nspline = ( + (uu - ll) / stride0 + (extrapolate * uu - uu) / stride1 + ).astype(int) + self._build_lower( + net, xx, ii, uu, ll, stride0, stride1, extrapolate, nspline + ) + elif self.descrpt_type == "T": + xx_all = [] + for ii in range(self.ntypes): + """PT and TF differ here; the 
Pt version is a two-dimensional array.""" + if self.is_pt: + uu = np.max(upper[ii]) + ll = np.min(lower[ii]) + else: + ll = lower[ii] + uu = upper[ii] + xx = np.arange(extrapolate * ll, ll, stride1, dtype=self.data_type) + xx = np.append(xx, np.arange(ll, uu, stride0, dtype=self.data_type)) + xx = np.append( + xx, + np.arange( + uu, + extrapolate * uu, + stride1, + dtype=self.data_type, + ), + ) + xx = np.append(xx, np.array([extrapolate * uu], dtype=self.data_type)) + xx_all.append(xx) + nspline = ( + (upper - lower) / stride0 + + 2 * ((extrapolate * upper - upper) / stride1) + ).astype(int) + idx = 0 + for ii in range(self.ntypes): + if self.is_pt: + uu = np.max(upper[ii]) + ll = np.min(lower[ii]) + else: + ll = lower[ii] + uu = upper[ii] + for jj in range(ii, self.ntypes): + net = "filter_" + str(ii) + "_net_" + str(jj) + self._build_lower( + net, + xx_all[ii], + idx, + uu, + ll, + stride0, + stride1, + extrapolate, + nspline[ii][0] if self.is_pt else nspline[ii], + ) + idx += 1 + elif self.descrpt_type == "R": + for ii in range(self.table_size): + if (self.type_one_side and not self._all_excluded(ii)) or ( + not self.type_one_side + and (ii // self.ntypes, ii % self.ntypes) not in self.exclude_types + ): + if self.type_one_side: + net = "filter_-1_net_" + str(ii) + # upper and lower should consider all types which are not excluded and sel>0 + idx = [ + (type_i, ii) not in self.exclude_types + and self.sel_a[type_i] > 0 + for type_i in range(self.ntypes) + ] + uu = np.max(upper[idx]) + ll = np.min(lower[idx]) + else: + ielement = ii // self.ntypes + net = ( + "filter_" + str(ielement) + "_net_" + str(ii % self.ntypes) + ) + uu = upper[ielement] + ll = lower[ielement] + xx = np.arange(ll, uu, stride0, dtype=self.data_type) + xx = np.append( + xx, + np.arange(uu, extrapolate * uu, stride1, dtype=self.data_type), + ) + xx = np.append( + xx, np.array([extrapolate * uu], dtype=self.data_type) + ) + nspline = ( + (uu - ll) / stride0 + (extrapolate * uu - uu) / stride1 + ).astype(int) + self._build_lower( + net, xx, ii, uu, ll, stride0, stride1, extrapolate, nspline + ) + else: + raise RuntimeError("Unsupported descriptor") + + self._convert_numpy_to_tensor() + if self.is_pt: + self._convert_numpy_float_to_int() + return self.lower, self.upper + + def _build_lower( + self, net, xx, idx, upper, lower, stride0, stride1, extrapolate, nspline + ): + vv, dd, d2 = self._make_data(xx, idx) + self.data[net] = np.zeros( + [nspline, 6 * self.last_layer_size], dtype=self.data_type + ) + + # tt.shape: [nspline, self.last_layer_size] + if self.descrpt_type in ("Atten", "A", "AEbdV2"): + tt = np.full((nspline, self.last_layer_size), stride1) # pylint: disable=no-explicit-dtype + tt[: int((upper - lower) / stride0), :] = stride0 + elif self.descrpt_type == "T": + tt = np.full((nspline, self.last_layer_size), stride1) # pylint: disable=no-explicit-dtype + tt[ + int((lower - extrapolate * lower) / stride1) + 1 : ( + int((lower - extrapolate * lower) / stride1) + + int((upper - lower) / stride0) + ), + :, + ] = stride0 + elif self.descrpt_type == "R": + tt = np.full((nspline, self.last_layer_size), stride1) # pylint: disable=no-explicit-dtype + tt[: int((upper - lower) / stride0), :] = stride0 + else: + raise RuntimeError("Unsupported descriptor") + + # hh.shape: [nspline, self.last_layer_size] + hh = ( + vv[1 : nspline + 1, : self.last_layer_size] + - vv[:nspline, : self.last_layer_size] + ) + + self.data[net][:, : 6 * self.last_layer_size : 6] = vv[ + :nspline, : self.last_layer_size + ] + self.data[net][:, 1 
: 6 * self.last_layer_size : 6] = dd[ + :nspline, : self.last_layer_size + ] + self.data[net][:, 2 : 6 * self.last_layer_size : 6] = ( + 0.5 * d2[:nspline, : self.last_layer_size] + ) + self.data[net][:, 3 : 6 * self.last_layer_size : 6] = ( + 1 / (2 * tt * tt * tt) + ) * ( + 20 * hh + - ( + 8 * dd[1 : nspline + 1, : self.last_layer_size] + + 12 * dd[:nspline, : self.last_layer_size] + ) + * tt + - ( + 3 * d2[:nspline, : self.last_layer_size] + - d2[1 : nspline + 1, : self.last_layer_size] + ) + * tt + * tt + ) + self.data[net][:, 4 : 6 * self.last_layer_size : 6] = ( + 1 / (2 * tt * tt * tt * tt) + ) * ( + -30 * hh + + ( + 14 * dd[1 : nspline + 1, : self.last_layer_size] + + 16 * dd[:nspline, : self.last_layer_size] + ) + * tt + + ( + 3 * d2[:nspline, : self.last_layer_size] + - 2 * d2[1 : nspline + 1, : self.last_layer_size] + ) + * tt + * tt + ) + self.data[net][:, 5 : 6 * self.last_layer_size : 6] = ( + 1 / (2 * tt * tt * tt * tt * tt) + ) * ( + 12 * hh + - 6 + * ( + dd[1 : nspline + 1, : self.last_layer_size] + + dd[:nspline, : self.last_layer_size] + ) + * tt + + ( + d2[1 : nspline + 1, : self.last_layer_size] + - d2[:nspline, : self.last_layer_size] + ) + * tt + * tt + ) + + self.upper[net] = upper + self.lower[net] = lower + + @abstractmethod + def _make_data(self, xx, idx) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Generate tabulation data for the given input. + + Parameters + ---------- + xx : np.ndarray + Input values to tabulate + idx : int + Index for accessing the correct network parameters + + Returns + ------- + tuple[np.ndarray, np.ndarray, np.ndarray] + Values, first derivatives, and second derivatives + """ + pass + + @lru_cache + def _all_excluded(self, ii: int) -> bool: + """Check if type ii excludes all types. + + Parameters + ---------- + ii : int + type index + + Returns + ------- + bool + if type ii excludes all types + """ + return all((ii, type_i) in self.exclude_types for type_i in range(self.ntypes)) + + @abstractmethod + def _get_descrpt_type(self): + """Get the descriptor type.""" + pass + + @abstractmethod + def _get_layer_size(self): + """Get the number of embedding layers.""" + pass + + def _get_table_size(self): + table_size = 0 + if self.descrpt_type in ("Atten", "AEbdV2"): + table_size = 1 + elif self.descrpt_type == "A": + table_size = self.ntypes * self.ntypes + if self.type_one_side: + table_size = self.ntypes + elif self.descrpt_type == "T": + table_size = int(comb(self.ntypes + 1, 2)) + elif self.descrpt_type == "R": + table_size = self.ntypes * self.ntypes + if self.type_one_side: + table_size = self.ntypes + else: + raise RuntimeError("Unsupported descriptor") + return table_size + + def _get_data_type(self): + for item in self.matrix["layer_" + str(self.layer_size)]: + if len(item) != 0: + return type(item[0][0]) + return None + + def _get_last_layer_size(self): + for item in self.matrix["layer_" + str(self.layer_size)]: + if len(item) != 0: + return item.shape[1] + return 0 + + @abstractmethod + def _get_bias(self): + """Get the bias of the embedding net.""" + pass + + @abstractmethod + def _get_matrix(self): + """Get the weight matrix of the embedding net.""" + pass + + @abstractmethod + def _convert_numpy_to_tensor(self): + """Convert self.data from np.ndarray to framework tensors.""" + pass + + def _convert_numpy_float_to_int(self): + """Convert self.lower and self.upper from np.float32 or np.float64 to int.""" + self.lower = {k: int(v) for k, v in self.lower.items()} + self.upper = {k: int(v) for k, v in self.upper.items()} + + def _get_env_mat_range(self, 
min_nbor_dist): + """Change the embedding net range to sw / min_nbor_dist.""" + sw = self._spline5_switch(min_nbor_dist, self.rcut_smth, self.rcut) + if self.descrpt_type in ("Atten", "A", "AEbdV2"): + lower = -self.davg[:, 0] / self.dstd[:, 0] + upper = ((1 / min_nbor_dist) * sw - self.davg[:, 0]) / self.dstd[:, 0] + elif self.descrpt_type == "T": + var = np.square(sw / (min_nbor_dist * self.dstd[:, 1:4])) + lower = np.min(-var, axis=1) + upper = np.max(var, axis=1) + elif self.descrpt_type == "R": + lower = -self.davg[:, 0] / self.dstd[:, 0] + upper = ((1 / min_nbor_dist) * sw - self.davg[:, 0]) / self.dstd[:, 0] + else: + raise RuntimeError("Unsupported descriptor") + log.info("training data with lower boundary: " + str(lower)) + log.info("training data with upper boundary: " + str(upper)) + # returns element-wise lower and upper + return np.floor(lower), np.ceil(upper) + + def _spline5_switch(self, xx, rmin, rmax): + if xx < rmin: + vv = 1 + elif xx < rmax: + uu = (xx - rmin) / (rmax - rmin) + vv = uu * uu * uu * (-6 * uu * uu + 15 * uu - 10) + 1 + else: + vv = 0 + return vv diff --git a/source/op/pt/tabulate_multi_device.cc b/source/op/pt/tabulate_multi_device.cc index bdc6f63f94..5c710f5c37 100644 --- a/source/op/pt/tabulate_multi_device.cc +++ b/source/op/pt/tabulate_multi_device.cc @@ -905,7 +905,7 @@ class TabulateFusionSeROp std::vector<torch::Tensor> tabulate_fusion_se_a( const torch::Tensor& table_tensor, - const torch::Tensor& table_info_tensor, + const torch::Tensor& table_info_tensor, // only cpu const torch::Tensor& em_x_tensor, const torch::Tensor& em_tensor, int64_t last_layer_size) { @@ -915,7 +915,7 @@ std::vector<torch::Tensor> tabulate_fusion_se_atten( const torch::Tensor& table_tensor, - const torch::Tensor& table_info_tensor, + const torch::Tensor& table_info_tensor, // only cpu const torch::Tensor& em_x_tensor, const torch::Tensor& em_tensor, const torch::Tensor& two_embed_tensor, @@ -928,7 +928,7 @@ std::vector<torch::Tensor> tabulate_fusion_se_t( const torch::Tensor& table_tensor, - const torch::Tensor& table_info_tensor, + const torch::Tensor& table_info_tensor, // only cpu const torch::Tensor& em_x_tensor, const torch::Tensor& em_tensor, int64_t last_layer_size) { @@ -938,7 +938,7 @@ std::vector<torch::Tensor> tabulate_fusion_se_r( const torch::Tensor& table_tensor, - const torch::Tensor& table_info_tensor, + const torch::Tensor& table_info_tensor, // only cpu const torch::Tensor& em_tensor, int64_t last_layer_size) { return TabulateFusionSeROp::apply(table_tensor, table_info_tensor, em_tensor, diff --git a/source/tests/pt/model/test_compressed_descriptor_se_a.py b/source/tests/pt/model/test_compressed_descriptor_se_a.py new file mode 100644 index 0000000000..14d82a452c --- /dev/null +++ b/source/tests/pt/model/test_compressed_descriptor_se_a.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np +import torch + +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.utils.env import DEVICE as PT_DEVICE +from deepmd.pt.utils.nlist import build_neighbor_list as build_neighbor_list_pt +from deepmd.pt.utils.nlist import ( + extend_coord_with_ghosts as extend_coord_with_ghosts_pt, +) + +from ...consistent.common import ( + parameterized, +) + + +def eval_pt_descriptor( + pt_obj: Any, natoms, coords, atype, box, mixed_types: bool = False +) -> Any: + 
ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt( + torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), + torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), + torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3), + pt_obj.get_rcut(), + ) + nlist = build_neighbor_list_pt( + ext_coords, + ext_atype, + natoms[0], + pt_obj.get_rcut(), + pt_obj.get_sel(), + distinguish_types=(not mixed_types), + ) + result, _, _, _, _ = pt_obj(ext_coords, ext_atype, nlist, mapping=mapping) + return result + + +@parameterized(("float32", "float64"), (True, False)) +class TestDescriptorSeA(unittest.TestCase): + def setUp(self): + (self.dtype, self.type_one_side) = self.param + if self.dtype == "float32": + self.atol = 1e-5 + elif self.dtype == "float64": + self.atol = 1e-10 + self.seed = 21 + self.sel = [9, 10] + self.rcut_smth = 5.80 + self.rcut = 6.00 + self.neuron = [6, 12, 24] + self.axis_neuron = 3 + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + self.se_a = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + self.neuron, + self.axis_neuron, + type_one_side=self.type_one_side, + seed=21, + precision=self.dtype, + ) + + def test_compressed_forward(self): + result_pt = eval_pt_descriptor( + self.se_a, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + self.se_a.enable_compression(0.5) + result_pt_compressed = eval_pt_descriptor( + self.se_a, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + self.assertEqual(result_pt.shape, result_pt_compressed.shape) + torch.testing.assert_close( + result_pt, + result_pt_compressed, + atol=self.atol, + rtol=self.atol, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/model/test_compressed_descriptor_se_atten.py b/source/tests/pt/model/test_compressed_descriptor_se_atten.py new file mode 100644 index 0000000000..a439255396 --- /dev/null +++ b/source/tests/pt/model/test_compressed_descriptor_se_atten.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np +import torch + +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.pt.model.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.pt.utils.env import DEVICE as PT_DEVICE +from deepmd.pt.utils.nlist import build_neighbor_list as build_neighbor_list_pt +from deepmd.pt.utils.nlist import ( + extend_coord_with_ghosts as extend_coord_with_ghosts_pt, +) + +from ...consistent.common import ( + parameterized, +) + + +def eval_pt_descriptor( + pt_obj: Any, natoms, coords, atype, box, mixed_types: bool = False +) -> Any: + ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt( + torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), + torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), + torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3), + pt_obj.get_rcut(), + ) + nlist = build_neighbor_list_pt( + ext_coords, + ext_atype, + natoms[0], + pt_obj.get_rcut(), + pt_obj.get_sel(), + distinguish_types=(not mixed_types), + ) + result, _, _, _, _ = pt_obj(ext_coords, ext_atype, nlist, mapping=mapping) + return result + + 
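+# eval_pt_descriptor is the same helper as in test_compressed_descriptor_se_a:
+# extend the periodic cell with ghost atoms, build a neighbor list within the
+# descriptor cutoff, then run the forward pass, so the uncompressed and
+# compressed paths see identical inputs. The DPA1 descriptor below is built
+# with tebd_input_mode="strip", the variant exercised for compression here.
+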
+@parameterized(("float32", "float64"), (True, False)) +class TestDescriptorSeAtten(unittest.TestCase): + def setUp(self): + (self.dtype, self.type_one_side) = self.param + if self.dtype == "float32": + self.atol = 1e-5 + elif self.dtype == "float64": + self.atol = 1e-10 + self.seed = 21 + self.sel = [10] + self.rcut_smth = 5.80 + self.rcut = 6.00 + self.neuron = [6, 12, 24] + self.axis_neuron = 3 + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + self.se_atten = DescrptDPA1( + self.rcut, + self.rcut_smth, + self.sel, + self.ntypes, + self.neuron, + self.axis_neuron, + 4, + attn=8, + attn_layer=0, + seed=21, + precision=self.dtype, + type_one_side=self.type_one_side, + tebd_input_mode="strip", + ) + + def test_compressed_forward(self): + result_pt = eval_pt_descriptor( + self.se_atten, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + if self.dtype == "float32": + result_pt = result_pt.to(torch.float32) + elif self.dtype == "float64": + result_pt = result_pt.to(torch.float64) + + self.se_atten.enable_compression(0.5) + result_pt_compressed = eval_pt_descriptor( + self.se_atten, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + self.assertEqual(result_pt.shape, result_pt_compressed.shape) + torch.testing.assert_close( + result_pt, + result_pt_compressed, + atol=self.atol, + rtol=self.atol, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/model/test_compressed_descriptor_se_r.py b/source/tests/pt/model/test_compressed_descriptor_se_r.py new file mode 100644 index 0000000000..156cb9a06d --- /dev/null +++ b/source/tests/pt/model/test_compressed_descriptor_se_r.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np +import torch + +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.pt.model.descriptor.se_r import ( + DescrptSeR, +) +from deepmd.pt.utils.env import DEVICE as PT_DEVICE +from deepmd.pt.utils.nlist import build_neighbor_list as build_neighbor_list_pt +from deepmd.pt.utils.nlist import ( + extend_coord_with_ghosts as extend_coord_with_ghosts_pt, +) + +from ...consistent.common import ( + parameterized, +) + + +def eval_pt_descriptor( + pt_obj: Any, natoms, coords, atype, box, mixed_types: bool = False +) -> Any: + ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt( + torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), + torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), + torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3), + pt_obj.get_rcut(), + ) + nlist = build_neighbor_list_pt( + ext_coords, + ext_atype, + natoms[0], + pt_obj.get_rcut(), + pt_obj.get_sel(), + distinguish_types=(not mixed_types), + ) + result, _, _, _, _ = pt_obj(ext_coords, ext_atype, nlist, mapping=mapping) + return result + + +@parameterized(("float32", "float64")) +class TestDescriptorSeR(unittest.TestCase): + def setUp(self): + (self.dtype,) = self.param + if self.dtype == "float32": + self.atol = 1e-5 + elif self.dtype == "float64": + self.atol = 1e-10 + self.seed = 21 + self.sel = [9, 10] + self.rcut_smth = 5.80 
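+        # rcut_smth < rcut: _spline5_switch decays from 1 at rcut_smth to 0 at
+        # rcut, and _get_env_mat_range uses that switch value to bound the table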
+ self.rcut = 6.00 + self.neuron = [6, 12, 24] + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + self.se_r = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + self.neuron, + seed=21, + precision=self.dtype, + ) + + def test_compressed_forward(self): + result_pt = eval_pt_descriptor( + self.se_r, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + self.se_r.enable_compression(0.5) + result_pt_compressed = eval_pt_descriptor( + self.se_r, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + self.assertEqual(result_pt.shape, result_pt_compressed.shape) + torch.testing.assert_close( + result_pt, + result_pt_compressed, + atol=self.atol, + rtol=self.atol, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/model/test_compressed_descriptor_se_t.py b/source/tests/pt/model/test_compressed_descriptor_se_t.py new file mode 100644 index 0000000000..aa3054bc0d --- /dev/null +++ b/source/tests/pt/model/test_compressed_descriptor_se_t.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np +import torch + +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.pt.model.descriptor.se_t import ( + DescrptSeT, +) +from deepmd.pt.utils.env import DEVICE as PT_DEVICE +from deepmd.pt.utils.nlist import build_neighbor_list as build_neighbor_list_pt +from deepmd.pt.utils.nlist import ( + extend_coord_with_ghosts as extend_coord_with_ghosts_pt, +) + +from ...consistent.common import ( + parameterized, +) + + +def eval_pt_descriptor( + pt_obj: Any, natoms, coords, atype, box, mixed_types: bool = False +) -> Any: + ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt( + torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), + torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), + torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3), + pt_obj.get_rcut(), + ) + nlist = build_neighbor_list_pt( + ext_coords, + ext_atype, + natoms[0], + pt_obj.get_rcut(), + pt_obj.get_sel(), + distinguish_types=(not mixed_types), + ) + result, _, _, _, _ = pt_obj(ext_coords, ext_atype, nlist, mapping=mapping) + return result + + +@parameterized(("float32", "float64")) +class TestDescriptorSeT(unittest.TestCase): + def setUp(self): + (self.dtype,) = self.param + if self.dtype == "float32": + self.atol = 1e-5 + elif self.dtype == "float64": + self.atol = 1e-10 + self.seed = 21 + self.sel = [9, 10] + self.rcut_smth = 5.80 + self.rcut = 6.00 + self.neuron = [6, 12, 24] + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + self.se_t = DescrptSeT( + self.rcut, + self.rcut_smth, + self.sel, + self.neuron, + seed=21, + precision=self.dtype, + ) + + def 
test_compressed_forward(self): + result_pt = eval_pt_descriptor( + self.se_t, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + self.se_t.enable_compression(0.5) + result_pt_compressed = eval_pt_descriptor( + self.se_t, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + self.assertEqual(result_pt.shape, result_pt_compressed.shape) + torch.testing.assert_close( + result_pt, + result_pt_compressed, + atol=self.atol, + rtol=self.atol, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_tabulate.py b/source/tests/pt/test_tabulate.py new file mode 100644 index 0000000000..c03773827d --- /dev/null +++ b/source/tests/pt/test_tabulate.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.tabulate import ( + unaggregated_dy2_dx, + unaggregated_dy2_dx_s, + unaggregated_dy_dx, + unaggregated_dy_dx_s, +) +from deepmd.tf.env import ( + op_module, + tf, +) + + +def setUpModule(): + tf.compat.v1.enable_eager_execution() + + +def tearDownModule(): + tf.compat.v1.disable_eager_execution() + + +class TestDPTabulate(unittest.TestCase): + def setUp(self): + self.w = np.array( + [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]], + dtype=np.float64, + ) + + self.x = np.array( + [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [1.0, 1.1, 1.2]], + dtype=np.float64, # 4 x 3 + ) + + self.b = np.array([[0.1], [0.2], [0.3], [0.4]], dtype=np.float64) # 4 x 1 + + self.xbar = np.matmul(self.x, self.w) + self.b # 4 x 4 + + self.y = np.tanh(self.xbar) + + def test_ops(self): + dy_tf = op_module.unaggregated_dy_dx_s( + tf.constant(self.y, dtype="double"), + tf.constant(self.w, dtype="double"), + tf.constant(self.xbar, dtype="double"), + tf.constant(1), + ) + + dy_pt = unaggregated_dy_dx_s( + torch.from_numpy(self.y), + self.w, + torch.from_numpy(self.xbar), + 1, + ) + + dy_tf_numpy = dy_tf.numpy() + dy_pt_numpy = dy_pt.detach().numpy() + + np.testing.assert_almost_equal(dy_tf_numpy, dy_pt_numpy, decimal=10) + + dy2_tf = op_module.unaggregated_dy2_dx_s( + tf.constant(self.y, dtype="double"), + dy_tf, + tf.constant(self.w, dtype="double"), + tf.constant(self.xbar, dtype="double"), + tf.constant(1), + ) + + dy2_pt = unaggregated_dy2_dx_s( + torch.from_numpy(self.y), + dy_pt, + self.w, + torch.from_numpy(self.xbar), + 1, + ) + + dy2_tf_numpy = dy2_tf.numpy() + dy2_pt_numpy = dy2_pt.detach().numpy() + + np.testing.assert_almost_equal(dy2_tf_numpy, dy2_pt_numpy, decimal=10) + + dz_tf = op_module.unaggregated_dy_dx( + tf.constant(self.y, dtype="double"), + tf.constant(self.w, dtype="double"), + dy_tf, + tf.constant(self.xbar, dtype="double"), + tf.constant(1), + ) + + dz_pt = unaggregated_dy_dx( + torch.from_numpy(self.y).to(env.DEVICE), + self.w, + dy_pt, + torch.from_numpy(self.xbar).to(env.DEVICE), + 1, + ) + + dz_tf_numpy = dz_tf.numpy() + dz_pt_numpy = dz_pt.detach().cpu().numpy() + + np.testing.assert_almost_equal(dz_tf_numpy, dz_pt_numpy, decimal=10) + + dy2_tf = op_module.unaggregated_dy2_dx( + tf.constant(self.y, dtype="double"), + tf.constant(self.w, dtype="double"), + dy_tf, + dy2_tf, + tf.constant(self.xbar, dtype="double"), + tf.constant(1), + ) + + dy2_pt = unaggregated_dy2_dx( + torch.from_numpy(self.y).to(env.DEVICE), + self.w, + dy_pt, + dy2_pt, + torch.from_numpy(self.xbar).to(env.DEVICE), + 1, + ) + + dy2_tf_numpy = dy2_tf.numpy() + dy2_pt_numpy = dy2_pt.detach().cpu().numpy() + + 
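+        # like the first-derivative checks above, the PT second-derivative op
+        # must match the TF reference kernel to ten decimal places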
np.testing.assert_almost_equal(dy2_tf_numpy, dy2_pt_numpy, decimal=10) + + +if __name__ == "__main__": + unittest.main()
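A note on the spline tables built by `_build_lower` above: each table row packs six coefficients per output channel (stride 6 within a row of `self.data[net]`): the knot value, its first derivative, half its second derivative, and three closing terms. These are exactly the coefficients of a quintic Hermite piece that matches the value and the first two derivatives at both ends of each interval of width `tt` (`stride0` inside `[lower, upper]`, `stride1` in the extrapolated region). The following self-contained numpy sketch checks those formulas on a single interval; the target function and all names here are illustrative only, not deepmd API:

```python
import numpy as np

# Toy stand-in for one embedding-net output channel, with known derivatives.
def f(x):
    return np.tanh(x)

def f1(x):  # first derivative of f
    return 1.0 - np.tanh(x) ** 2

def f2(x):  # second derivative of f
    return -2.0 * np.tanh(x) * (1.0 - np.tanh(x) ** 2)

x0, tt = 0.3, 0.05  # one knot and the local stride (the per-row "tt" above)
vv0, vv1 = f(x0), f(x0 + tt)    # "vv" of _build_lower: values at the knots
dd0, dd1 = f1(x0), f1(x0 + tt)  # "dd": first derivatives
d20, d21 = f2(x0), f2(x0 + tt)  # "d2": second derivatives
hh = vv1 - vv0                  # "hh": value increment over the interval

# The six stored coefficients, written exactly as in _build_lower.
a0 = vv0
a1 = dd0
a2 = 0.5 * d20
a3 = (20 * hh - (8 * dd1 + 12 * dd0) * tt - (3 * d20 - d21) * tt * tt) / (2 * tt**3)
a4 = (-30 * hh + (14 * dd1 + 16 * dd0) * tt + (3 * d20 - 2 * d21) * tt * tt) / (2 * tt**4)
a5 = (12 * hh - 6 * (dd1 + dd0) * tt + (d21 - d20) * tt * tt) / (2 * tt**5)

def spline(dx):
    """Horner evaluation of the quintic piece on [0, tt]."""
    return a0 + dx * (a1 + dx * (a2 + dx * (a3 + dx * (a4 + dx * a5))))

# The piece reproduces value and first/second derivatives at both knots,
# so adjacent pieces built from shared knot data join C2-continuously.
eps = 1e-6
assert np.isclose(spline(0.0), vv0) and np.isclose(spline(tt), vv1)
assert np.isclose((spline(eps) - spline(-eps)) / (2 * eps), dd0)
assert np.isclose((spline(tt + eps) - spline(tt - eps)) / (2 * eps), dd1)
```

Because `tt` switches from `stride0` to `stride1` past `upper`, the same six-coefficient layout serves both the interpolation and extrapolation regions of each table.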