From 79152f936e8052d78c4106ae57d71c30cb847fc7 Mon Sep 17 00:00:00 2001 From: kt Date: Wed, 13 Mar 2024 15:35:08 -0400 Subject: [PATCH] fix n_confs from changing in reshape_conformation_size --- espfit/utils/graphs.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/espfit/utils/graphs.py b/espfit/utils/graphs.py index a9f7c20..9673724 100644 --- a/espfit/utils/graphs.py +++ b/espfit/utils/graphs.py @@ -518,12 +518,18 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False): """Reshape conformation size. This is a work around to handle different graph size (shape). DGL requires at least one dimension with same size. + + DGLError: `Expect all graphs to have the same schema on nodes["g"].data, but graph 1 got ...` + Here, we will modify the graphs so that each graph has the same number of conformations instead of concatenating graphs into heterogeneous graphs with the same number of conformations. This allows shuffling and mini-batching per graph (molecule). Only g.nodes['g'].data['u_ref'], g.nodes['g'].data['u_ref_relative'], and g.nodes['n1'].data['xyz'] will be updated. + Note that this was also intended to augment datasets with less molecular diversity but with more conformers from + RNA nucleosides and trinucleotides. 
+ Parameters ---------- n_confs : int, default=50 @@ -535,6 +541,10 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False): Returns ------- None + + TODO + ---- + * Better way to handle different graph size """ _logger.info(f'Reshape graph size') @@ -548,16 +558,17 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False): new_graphs = [] n_confs_cache = n_confs for i, g in enumerate(self.graphs): + n_confs = n_confs_cache # restore n = g.nodes['n1'].data['xyz'].shape[1] if n == n_confs: - _logger.info(f"Mol #{i} ({n} conformers)") + _logger.info(f"Mol #{i} ({n} confs): Pass") new_graphs.append(g) elif n < n_confs: random.seed(self.random_seed) index_random = random.choices(range(0, n), k=n_confs-n) - _logger.info(f"Randomly select {len(index_random)} conformers from Mol #{i} ({n} conformers)") + _logger.info(f"Mol #{i} ({n} confs): Randomly select {len(index_random)} conformers") _g = copy.deepcopy(g) _g.nodes["g"].data["u_ref"] = torch.cat((_g.nodes['g'].data['u_ref'], _g.nodes['g'].data['u_ref'][:, index_random]), dim=-1) @@ -572,10 +583,10 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False): # Get index for minimum energy conformer if include_min_energy_conf: index_min = [g.nodes['g'].data['u_ref'].argmin().item()] - n_confs = n_confs_cache - 1 - _logger.info(f"Shuffe Mol #{i} ({n} conformers) and split into {n_confs} conformers and add minimum energy conformer (index #{index_min[0]})") + n_confs = n_confs - 1 + _logger.info(f"Mol #{i} ({n} conformers): Shuffle and split into {n_confs} conformers and add minimum energy conformer (index #{index_min[0]})") else: - _logger.info(f"Shuffe Mol #{i} ({n} conformers) and split into {n_confs} conformers") + _logger.info(f"Mol #{i} ({n} conformers): Split into {n_confs} conformers") for j in range(n // n_confs + 1): _g = copy.deepcopy(g)