From 77129758da04293d082fa1f3e1a63df9e715bdc0 Mon Sep 17 00:00:00 2001
From: CheukHinHoJerry
Date: Mon, 28 Oct 2024 12:49:44 -0700
Subject: [PATCH] Finish up symmetric/non-symmetric CP/Tucker formats and
 small cleanup

---
 mace/modules/blocks.py                |  32 +---
 mace/modules/models.py                |   2 +-
 mace/modules/symmetric_contraction.py | 254 +++++++++++---------------
 mace/tools/cg.py                      |   5 +-
 4 files changed, 122 insertions(+), 171 deletions(-)

diff --git a/mace/modules/blocks.py b/mace/modules/blocks.py
index 677948de..c104b924 100644
--- a/mace/modules/blocks.py
+++ b/mace/modules/blocks.py
@@ -228,6 +228,9 @@ def tensor_power_einsum(tensor, N):
 
 @compile_mode("script")
 class TensorFormatBlock(torch.nn.Module):
+    """
+    Reshapes the message tensor for the chosen tensor format; currently a no-op hook for later optimizations.
+    """
     def __init__(self, tensor_format, correlation):
         super().__init__()
 
@@ -239,17 +242,10 @@ def __init__(self, tensor_format, correlation):
 
     def forward(self, message) -> torch.Tensor:
         batch_size, dim, features = message.shape
-        if self.tensor_format == "symmetric_cp":
+        if self.tensor_format in ["symmetric_cp", "symmetric_tucker"]:
             return message
-        elif self.tensor_format == "symmetric_tucker":
+        elif self.tensor_format in ["non_symmetric_cp", "non_symmetric_tucker"]:
             return message
-        # message = [message] * correlation
-        # message = torch.einsum(eq, *tensors)
-        # # K = message.shape[-2]
-        # # for i in range(self.correlation - 1):
-        # #     message = message.unsqueeze(-2)
-        # #     message = message.repeat([1, 1, ] + [K, ] * (self.correlation - 1) + [1])
-        # return message.reshape(batch_size, dim ** correlation, features)
 
 
 @compile_mode("script")
@@ -281,7 +277,7 @@ def __init__(
             internal_weights=True,
             shared_weights=True,
         )
-        elif tensor_format == "symmetric_tucker":
+        elif tensor_format in ["symmetric_tucker", "non_symmetric_tucker"]:
             tucker_irreps = make_tucker_irreps(target_irreps, correlation)
             self.linear = o3.Linear(
                 tucker_irreps,
@@ -687,7 +683,7 @@ def _setup(self) -> None:
         irreps_mid = irreps_mid.simplify()
         self.irreps_out = self.target_irreps
 
-        if self.tensor_format == "symmetric_cp":
+        if self.tensor_format in ["symmetric_cp", "symmetric_tucker"]:
             self.linear = o3.Linear(
                 irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
             )
@@ -697,7 +693,7 @@ def _setup(self) -> None:
             )
             self.reshape = reshape_irreps(self.irreps_out)
 
-        elif self.tensor_format == "non_symmetric_cp":
+        elif self.tensor_format in ["non_symmetric_cp", "non_symmetric_tucker"]:
             self.linear = []
             # Selector TensorProduct
             self.skip_tp = o3.FullyConnectedTensorProduct(
@@ -710,16 +706,6 @@ def _setup(self) -> None:
             ))
             self.reshape.append(reshape_irreps(self.irreps_out))
 
-        elif self.tensor_format == "symmetric_tucker":
-            self.linear = o3.Linear(
-                irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
-            )
-            # Selector TensorProduct
-            self.skip_tp = o3.FullyConnectedTensorProduct(
-                self.node_feats_irreps, self.node_attrs_irreps, self.hidden_irreps
-            )
-            self.reshape = reshape_irreps(self.irreps_out)
-
-        # self.reshape = reshape_irreps(self.irreps_out)
         self.tensor_format_layer = TensorFormatBlock(self.tensor_format, self.correlation)
@@ -750,7 +736,7 @@ def forward(
                 self.tensor_format_layer(self.reshape(message)),
                 sc,
             )  # symmetric_cp: [n_nodes, channels, (lmax + 1)**2]
-        elif self.tensor_format == "non_symmetric_cp":
+        elif self.tensor_format in ["non_symmetric_cp", "non_symmetric_tucker"]:
             message = self.reshape[0](self.linear[0](original_message))
             message = message.unsqueeze(-1)
             for idx in range(1, self.correlation):
diff --git a/mace/modules/models.py b/mace/modules/models.py
index 2991c8a6..2878b8de 100644
--- a/mace/modules/models.py
+++ b/mace/modules/models.py
@@ -149,7 +149,7 @@ def __init__(
             self.readouts.append(
                 LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
             )
-        elif tensor_format == "symmetric_tucker":
+        elif tensor_format in ["symmetric_tucker", "non_symmetric_tucker"]:
             self.readouts.append(
                 # LinearReadoutBlock(make_tp_irreps(hidden_irreps, correlation[0]), o3.Irreps(f"{len(heads)}x0e"))
                 LinearReadoutBlock(hidden_irreps, o3.Irreps(f"{len(heads)}x0e"))
diff --git a/mace/modules/symmetric_contraction.py b/mace/modules/symmetric_contraction.py
index 494c2d6a..14ba3ce1 100644
--- a/mace/modules/symmetric_contraction.py
+++ b/mace/modules/symmetric_contraction.py
@@ -101,7 +101,6 @@ def __init__(
         tensor_format: str = "symmetric_cp",
     ) -> None:
         super().__init__()
-
         self.num_features = irreps_in.count((0, 1))
         self.coupling_irreps = o3.Irreps([irrep.ir for irrep in irreps_in])
         self.correlation = correlation
@@ -125,17 +124,13 @@ def __init__(
 
         # Create weight for product basis
         self.weights = torch.nn.ParameterList([])
-        # for tucker decomposition
         for i in range(correlation, 0, -1):
-
-            print(f"initing i = {i}")
             # Shape definitions
             num_params = self.U_tensors(i).size()[-1]
             num_equivariance = 2 * irrep_out.lmax + 1
             num_ell = self.U_tensors(i).size()[-2]
 
             if i == correlation:
-
                 if tensor_format in ["symmetric_cp", "non_symmetric_cp"]:
                     channel_idx = "c"
                     sample_x = torch.randn((BATCH_EXAMPLE, self.num_features, num_ell))
@@ -144,45 +139,42 @@ def __init__(
                         / num_params
                     )
                 elif tensor_format == "symmetric_tucker":
-                    # channel_idx = "".join([CHANNEL_ALPHANET[j] for j in range(self.correlation)])
-                    channel_idx = "c"
-                    sample_x = torch.randn([BATCH_EXAMPLE, ] + [self.num_features, ] + [num_ell, ])
                     w = torch.nn.Parameter(
                         torch.randn([num_elements, num_params,] + [self.num_features,])
                         / num_params
                     )
-
-                parse_subscript_main = (
-                    [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)]
-                    + [f"ik,ek{channel_idx},b{channel_idx}i,be -> b{channel_idx}"]
-                    + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)]
-                )
-                print("".join(parse_subscript_main))
-                graph_module_main = torch.fx.symbolic_trace(
-                    lambda x, y, w, z: torch.einsum(
-                        "".join(parse_subscript_main), x, y, w, z
-                    )
-                )
-                # Optimizing the contractions
-                self.graph_opt_main = opt_einsum_fx.optimize_einsums_full(
-                    model=graph_module_main,
-                    example_inputs=(
-                        torch.randn(
-                            [num_equivariance] + [num_ell] * i + [num_params]
-                        ).squeeze(0),
-                        # torch.randn((num_elements, num_params, self.num_features)),
-                        # torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)),
-                        torch.randn(w.shape),
-                        sample_x,
-                        torch.randn((BATCH_EXAMPLE, num_elements)),
-                    ),
-                )
+                elif tensor_format == "non_symmetric_tucker":
+                    w = torch.nn.Parameter(
+                        torch.randn([num_elements, num_params,] + [self.num_features,] * self.correlation)
+                        / num_params
+                    )
+
+                # the optimized contraction is only implemented for the CP formats
+                if "cp" in self.tensor_format:
+                    parse_subscript_main = (
+                        [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)]
+                        + [f"ik,ek{channel_idx},b{channel_idx}i,be -> b{channel_idx}"]
+                        + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)]
+                    )
+                    graph_module_main = torch.fx.symbolic_trace(
+                        lambda x, y, w, z: torch.einsum(
+                            "".join(parse_subscript_main), x, y, w, z
+                        )
+                    )
+                    # Optimizing the contractions
+                    self.graph_opt_main = opt_einsum_fx.optimize_einsums_full(
+                        model=graph_module_main,
+                        example_inputs=(
+                            torch.randn(
+                                [num_equivariance] + [num_ell] * i + [num_params]
+                            ).squeeze(0),
+                            torch.randn(w.shape),
+                            sample_x,
+                            torch.randn((BATCH_EXAMPLE, num_elements)),
+                        ),
+                    )
 
                 # Parameters for the product basis
-                # w = torch.nn.Parameter(
-                #     torch.randn((num_elements, num_params, self.num_features))
-                #     / num_params
-                # )
                 self.weights_max = w
 
             else:
@@ -198,85 +190,67 @@ def __init__(
                         / num_params
                     )
                 elif tensor_format == "symmetric_tucker":
-                    out_channel_idx = "".join([CHANNEL_ALPHANET[j] for j in range(self.correlation - i + 1)])
-                    # in_channel_idx = out_channel_idx[:-1]
-                    # channel_idx = "c"
-                    sample_x = torch.randn([BATCH_EXAMPLE, self.num_features, num_ell])
-                    # order of features is of length out_channel_idx - 1
-                    sample_x2 = torch.randn(
-                        [BATCH_EXAMPLE,] + [self.num_features,] * (self.correlation - i) + [num_equivariance, ]
-                        + [num_ell] * i
-                    ).squeeze(self.correlation - i + 1)
-                    print("sample_x.shape: ", sample_x.shape)
-                    print("sample_x2.shap: ", sample_x2.shape)
+                    # outer-producted in model.forward to form the symmetrized parameter tensor;
+                    # this can be improved
                     w = torch.nn.Parameter(
                         torch.randn((num_elements, num_params, self.num_features))
                         / num_params
                     )
-                # Generate optimized contractions equations
-                parse_subscript_weighting = (
-                    [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))]
-                    + [f"k,ekc,be->bc"]
-                    + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))]
-                )
-                if tensor_format in ["symmetric_cp", "non_symmetric_cp"]:
+                elif tensor_format == "non_symmetric_tucker":
+                    # the weight tensor carries one channel axis per correlation order (i axes here)
+                    w = torch.nn.Parameter(
+                        torch.randn((num_elements, num_params, *([self.num_features] * i)))
+                        / num_params
+                    )
+
+                # the optimized contraction is implemented for the CP formats only
+                if "cp" in tensor_format:
+                    # Generate optimized contraction equations
+                    parse_subscript_weighting = (
+                        [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))]
+                        + [f"k,ekc,be->bc"]
+                        + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))]
+                    )
                     parse_subscript_features = (
                         [f"bc"]
                         + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))]
                         + [f"i,bci->bc"]
                         + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))]
                     )
-                elif tensor_format == "symmetric_tucker":
-                    parse_subscript_features = (
-                        [f"b{out_channel_idx[:-1]}"]
-                        + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))]
-                        + [f"i,b{out_channel_idx[-1]}i->b{out_channel_idx}"]
-                        + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))]
-                    )
-                print("weighting: ", "".join(parse_subscript_weighting))
-                print("features: ", "".join(parse_subscript_features))
 
-                # Symbolic tracing of contractions
-                graph_module_weighting = torch.fx.symbolic_trace(
-                    lambda x, y, z: torch.einsum(
-                        "".join(parse_subscript_weighting), x, y, z
-                    )
-                )
-                graph_module_features = torch.fx.symbolic_trace(
-                    lambda x, y: torch.einsum("".join(parse_subscript_features), x, y)
-                )
+                    # Symbolic tracing of contractions
+                    graph_module_weighting = torch.fx.symbolic_trace(
+                        lambda x, y, z: torch.einsum(
+                            "".join(parse_subscript_weighting), x, y, z
+                        )
+                    )
+                    graph_module_features = torch.fx.symbolic_trace(
+                        lambda x, y: torch.einsum("".join(parse_subscript_features), x, y)
+                    )
 
+                    # Optimizing the contractions
+                    graph_opt_weighting = opt_einsum_fx.optimize_einsums_full(
+                        model=graph_module_weighting,
+                        example_inputs=(
+                            torch.randn(
+                                [num_equivariance] + [num_ell] * i + [num_params]
+                            ).squeeze(0),
+                            # torch.randn((num_elements, num_params, self.num_features)),
+                            torch.randn(w.shape),
+                            torch.randn((BATCH_EXAMPLE, num_elements)),
+                        ),
+                    )
+                    graph_opt_features = opt_einsum_fx.optimize_einsums_full(
+                        model=graph_module_features,
+                        example_inputs=(
+                            sample_x2,
+                            sample_x,
+                        ),
+                    )
+                    self.contractions_weighting.append(graph_opt_weighting)
+                    self.contractions_features.append(graph_opt_features)
 
-                # Optimizing the contractions
-                graph_opt_weighting = opt_einsum_fx.optimize_einsums_full(
-                    model=graph_module_weighting,
-                    example_inputs=(
-                        torch.randn(
-                            [num_equivariance] + [num_ell] * i + [num_params]
-                        ).squeeze(0),
-                        # torch.randn((num_elements, num_params, self.num_features)),
-                        torch.randn(w.shape),
-                        torch.randn((BATCH_EXAMPLE, num_elements)),
-                    ),
-                )
-                graph_opt_features = opt_einsum_fx.optimize_einsums_full(
-                    model=graph_module_features,
-                    example_inputs=(
-                        # torch.randn(
-                        #     [BATCH_EXAMPLE, self.num_features, num_equivariance]
-                        #     + [num_ell] * i
-                        # ).squeeze(2),
-                        # torch.randn((BATCH_EXAMPLE, self.num_features, num_ell)),
-                        sample_x2,
-                        sample_x,
-                    ),
-                )
-                self.contractions_weighting.append(graph_opt_weighting)
-                self.contractions_features.append(graph_opt_features)
 
-                # Parameters for the product basis
-                # w = torch.nn.Parameter(
-                #     torch.randn((num_elements, num_params, self.num_features))
-                #     / num_params
-                # )
+                # coefficients for the product basis
                 self.weights.append(w)
         if not internal_weights:
             self.weights = weights[:-1]
@@ -284,31 +258,30 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
         irrep_out = self.irrep_out
         num_equivariance = 2 * irrep_out.lmax + 1
-        # print("=== into forward ===")
-        if self.tensor_format == "symmetric_tucker":
+        if "tucker" in self.tensor_format:
             outs = dict()
             out_channel_idx = "".join([CHANNEL_ALPHANET[j] for j in range(self.correlation)])
             for nu in range(self.correlation, 0, -1):
                 num_params = self.U_tensors(nu).size()[-1]
                 num_ell = self.U_tensors(nu).size()[-2]
-
+                # non-symmetric Tucker carries one feature tensor per correlation order;
+                # pick the one attached to the order-nu basis
+                x_nu = x[:, :, :, nu - 1] if self.tensor_format == "non_symmetric_tucker" else x
                 # channel index of the final message:
                 # in symmetric Tucker we cannot sum over bases of different orders
                 # and will just produce "m_\tilde{k}LM" with \tilde{k} as a tuple
                 if nu == self.correlation:
-                    einsum_str = "".join((
-                        [ALPHABET[j] for j in range(nu + min(irrep_out.lmax, 1) - 1)]
-                        + [f"ik,b{out_channel_idx[-1]}i-> b{out_channel_idx[-1]}"]
-                        + [ALPHABET[j] for j in range(nu + min(irrep_out.lmax, 1) - 1)]
-                        + ["k"]
-
-                    ))
-                    outs[nu] = torch.einsum(einsum_str, self.U_tensors(nu), x)
+                    outs[nu] = torch.einsum("".join((
+                        [ALPHABET[j] for j in range(nu + min(irrep_out.lmax, 1) - 1)]
+                        + [f"ik,b{out_channel_idx[-1]}i-> b{out_channel_idx[-1]}"]
+                        + [ALPHABET[j] for j in range(nu + min(irrep_out.lmax, 1) - 1)]
+                        + ["k"]
+                    )),
+                    self.U_tensors(nu),
+                    x_nu)
                 else:
                     # contractions to be done for U_tensors(nu)
                     for nu2 in range(self.correlation, nu - 1, -1):
-                        # print("Utensor nu2 .shape", self.U_tensors(nu).shape)
                        # contraction for current nu;
                        # [ALPHABET[j] for j in range(nu + min(irrep_out.lmax, 1) - 1)]
                        # denotes indices that are preserved for further contractions
                                 )),
                                 self.U_tensors(nu),
-                                x
+                                x_nu
                             )
                         # also contract previous nu and expand the tensor product basis
-                        else:
-                            # print(f"---doing nu = {nu}, nu2 = {nu2}---")
-                            # for nu_ii in range(self.correlation, 0, -1):
-                            #     try:
-                            #         print(f"outs{nu_ii}.shape", outs[nu_ii].shape)
-                            #     except:
-                            #         print(f"dict no {nu_ii}")
-
+                        else:
                             outs[nu2] = torch.einsum(
                                 "".join([f"b{out_channel_idx[-(nu2 - nu):]}"]
                                         + [ALPHABET[j] for j in range(nu + min(irrep_out.lmax, 1) - 1)]
@@ -341,30 +307,36 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
                                         + ["k"]
                                         ),
                                 outs[nu2],
-                                x
+                                x_nu
                             )
-            # for nu_ii in range(self.correlation, 0, -1):
-            #     try:
-            #         print(f"outs{nu_ii}.shape", outs[nu_ii].shape)
-            #     except:
-            #         print(f"dict no {nu_ii}")
 
             # product basis coefficients layer
-            # print("=== done contract===")
-            # for nu_ii in range(self.correlation, 0, -1):
-            #     print(f"outs{nu_ii}.shape", outs[nu_ii].shape)
-
             for nu in range(self.correlation, 0, -1):
                 if nu == self.correlation:
-                    c_tensor = torch.einsum("ekc,be->bkc", self.weights_max, y)
+                    if self.tensor_format == "non_symmetric_tucker":
+                        c_tensor = torch.einsum(f"ek{out_channel_idx[:nu]},be->bk{out_channel_idx[:nu]}", self.weights_max, y)
+                    elif self.tensor_format == "symmetric_tucker":
+                        c_tensor = torch.einsum("ekc,be->bkc", self.weights_max, y)
+                        # outer product to symmetrize the tensor
+                        c_tensor = torch.einsum("".join([f"bk{out_channel_idx[i]}," for i in range(nu-1)]
+                                                        + [f"bk{out_channel_idx[nu-1]}"]
+                                                        + [f"->bk{out_channel_idx[:nu]}"]
+                                                        ),
+                                                *[c_tensor for _ in range(nu)])
                 else:
-                    c_tensor = torch.einsum("ekc,be->bkc", self.weights[self.correlation - nu - 1], y)
+                    if self.tensor_format == "non_symmetric_tucker":
+                        c_tensor = torch.einsum(f"ek{out_channel_idx[:nu]},be->bk{out_channel_idx[:nu]}", self.weights[self.correlation - nu - 1], y)
+                    elif self.tensor_format == "symmetric_tucker":
+                        c_tensor = torch.einsum("ekc,be->bkc", self.weights[self.correlation - nu - 1], y)
+                        # outer product to symmetrize the tensor
+                        c_tensor = torch.einsum("".join([f"bk{out_channel_idx[i]}," for i in range(nu-1)]
+                                                        + [f"bk{out_channel_idx[nu-1]}"]
+                                                        + [f"->bk{out_channel_idx[:nu]}"]
+                                                        ),
+                                                *[c_tensor for _ in range(nu)])
-                c_tensor = torch.einsum("".join([f"bk{out_channel_idx[i]}," for i in range(nu-1)]
-                                                + [f"bk{out_channel_idx[nu-1]}"]
-                                                + [f"->bk{out_channel_idx[:nu]}"]
-                                                ),
-                                        *[c_tensor for _ in range(nu)])
+
                 outs[nu] = torch.einsum(
                     "".join(
                         [f"b{out_channel_idx[:nu]}"]
@@ -376,10 +348,6 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
                     outs[nu],
                     c_tensor,
                 )
-            # print("=== done linear")
-            # for nu_ii in range(self.correlation, 0, -1):
-            #     print(f"outs{nu_ii}.shape", outs[nu_ii].shape)
-
             for nu in range(self.correlation, 0, -1):
                 shape_outnu = [outs[nu].shape[0]] + [self.num_features] * nu
                 if irrep_out.lmax > 0:
diff --git a/mace/tools/cg.py b/mace/tools/cg.py
index ec624bf5..dc56a071 100644
--- a/mace/tools/cg.py
+++ b/mace/tools/cg.py
@@ -128,7 +128,4 @@ def U_matrix_real(
         else:
             current_ir = ir
     out += [last_ir, stack]
-    if tensor_format in ["symmetric_cp", "non_symmetric_cp"]:
-        return out
-    elif tensor_format == "symmetric_tucker":
-        return out
+    return out
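
Reviewer note, not part of the patch: in the symmetric Tucker format the patch stores a single per-channel coefficient block "ekc" per correlation order and symmetrizes it inside forward() by contracting nu copies of the same tensor into one tensor with nu channel axes (the *[c_tensor for _ in range(nu)] einsum above). A minimal standalone sketch of that outer product; the sizes are illustrative only and the CHANNEL_ALPHANET value is an assumption mirroring the module's channel-index alphabet:

import torch

CHANNEL_ALPHANET = "cdefgh"  # assumed channel-index alphabet, mirroring the module

batch, num_params, channels, nu = 2, 3, 4, 2         # illustrative sizes only
c_tensor = torch.randn(batch, num_params, channels)  # per-channel coefficients "bkc"

# build e.g. "bkc,bkd->bkcd" for nu = 2, as the "".join(...) in forward() does
idx = CHANNEL_ALPHANET[:nu]
eq = ",".join(f"bk{a}" for a in idx) + f"->bk{idx}"
c_sym = torch.einsum(eq, *([c_tensor] * nu))         # same tensor repeated nu times

assert c_sym.shape == (batch, num_params, channels, channels)
# every factor is the same tensor, so the result is invariant under
# any permutation of the channel axes
assert torch.allclose(c_sym, c_sym.transpose(-1, -2))

Because the symmetrized tensor is built on the fly, symmetric_tucker keeps one (num_elements, num_params, num_features) parameter block per order, whereas non_symmetric_tucker materializes all num_features ** i weight entries directly.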