Skip to content

Commit

Permalink
fix n_confs from changing in reshape_conformation_size
Browse files Browse the repository at this point in the history
  • Loading branch information
kntkb committed Mar 13, 2024
1 parent 0284c10 commit 79152f9
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions espfit/utils/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,12 +518,18 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):
"""Reshape conformation size.
This is a workaround to handle different graph sizes (shapes). DGL requires at least one dimension with the same size.
DGLError: `Expect all graphs to have the same schema on nodes["g"].data, but graph 1 got ...`
Here, we will modify the graphs so that each graph has the same number of conformations instead of concatenating
graphs into heterogeneous graphs with the same number of conformations. This allows shuffling and mini-batching per
graph (molecule).
Only g.nodes['g'].data['u_ref'], g.nodes['g'].data['u_ref_relative'], and g.nodes['n1'].data['xyz'] will be updated.
Note that this was also intended to augment datasets with less molecular diversity but with more conformers from
RNA nucleosides and trinucleotides.
Parameters
----------
n_confs : int, default=50
Expand All @@ -535,6 +541,10 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):
Returns
-------
None
TODO
----
* Better way to handle different graph sizes
"""
_logger.info(f'Reshape graph size')

Expand All @@ -548,16 +558,17 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):
new_graphs = []
n_confs_cache = n_confs
for i, g in enumerate(self.graphs):
n_confs = n_confs_cache # restore
n = g.nodes['n1'].data['xyz'].shape[1]

if n == n_confs:
_logger.info(f"Mol #{i} ({n} conformers)")
_logger.info(f"Mol #{i} ({n} confs): Pass")
new_graphs.append(g)

elif n < n_confs:
random.seed(self.random_seed)
index_random = random.choices(range(0, n), k=n_confs-n)
_logger.info(f"Randomly select {len(index_random)} conformers from Mol #{i} ({n} conformers)")
_logger.info(f"Mol #{i} ({n} confs): Randomly select {len(index_random)} conformers")

_g = copy.deepcopy(g)
_g.nodes["g"].data["u_ref"] = torch.cat((_g.nodes['g'].data['u_ref'], _g.nodes['g'].data['u_ref'][:, index_random]), dim=-1)
Expand All @@ -572,10 +583,10 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):
# Get index for minimum energy conformer
if include_min_energy_conf:
index_min = [g.nodes['g'].data['u_ref'].argmin().item()]
n_confs = n_confs_cache - 1
_logger.info(f"Shuffe Mol #{i} ({n} conformers) and split into {n_confs} conformers and add minimum energy conformer (index #{index_min[0]})")
n_confs = n_confs - 1
_logger.info(f"Mol #{i} ({n} conformers): Shuffle and split into {n_confs} conformers and add minimum energy conformer (index #{index_min[0]})")
else:
_logger.info(f"Shuffe Mol #{i} ({n} conformers) and split into {n_confs} conformers")
_logger.info(f"Mol #{i} ({n} conformers): Split into {n_confs} conformers")

for j in range(n // n_confs + 1):
_g = copy.deepcopy(g)
Expand Down

0 comments on commit 79152f9

Please sign in to comment.