Fix/use u qm to reshape graphs #7

Merged · 6 commits · Mar 16, 2024
10 changes: 4 additions & 6 deletions README.md
@@ -12,20 +12,18 @@ Infrastructure to train espaloma with experimental observables
### Installation
>mamba create -n espfit python=3.11
>mamba install -c conda-forge espaloma=0.3.2
>#uninstall openff-toolkit and install a customized version to support dgl graphs created using openff-toolkit=0.10.6
>conda uninstall --force openff-toolkit
>pip install git+https://github.com/kntkb/openff-toolkit.git@7e9d0225782ef723083407a1cbf1f4f70631f934
>#install openeye-toolkit
>mamba install openeye-toolkits -c openeye
>#uninstall openmmforcefields if < 0.12.0
>conda uninstall --force openmmforcefields
>#use pip instead of mamba to avoid dependency issues with ambertools and python
>pip install git+https://github.com/openmm/[email protected]
>#install openmmtools
>mamba install openmmtools
>#install barnaba
>mamba install barnaba

#### Notes
- `openff-toolkit` is reinstalled from a customized fork to support DGL graphs created with `openff-toolkit=0.10.6`
- `openmmforcefields` is reinstalled with pip if its version is `<0.12.0`, to avoid dependency issues with `ambertools` and `python`; espaloma functionality is better supported in versions `>=0.12.0`.


### Quick Usage
```python
56 changes: 52 additions & 4 deletions espfit/app/train.py
@@ -349,7 +349,7 @@ def report_loss(self, epoch, loss_dict):

        log_file_path = os.path.join(self.output_directory_path, 'reporter.log')
        df_new = pd.DataFrame.from_dict(loss_dict, orient='index').T
-       df_new = df_new.mul(100)  # Multiply each loss component by 100
+       df_new = df_new.mul(100)  # Multiply each loss component by 100. Is this large enough?
        df_new.insert(0, 'epoch', epoch)

        if os.path.exists(log_file_path):
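For context, `report_loss` keeps a cumulative record by writing one row per epoch to `reporter.log`, and the ×100 scaling above only changes what is logged, not the loss that drives training. Below is a minimal sketch of the append logic, assuming a tab-separated log file; the helper name and the exact file format are illustrative, not espfit's:

```python
import os
import pandas as pd

def append_loss_row(log_file_path, epoch, loss_dict):
    """Append one row of scaled loss components to a running log (illustrative sketch)."""
    df_new = pd.DataFrame.from_dict(loss_dict, orient='index').T.mul(100)
    df_new.insert(0, 'epoch', epoch)
    # Write the header only when the file is first created, then keep appending.
    header = not os.path.exists(log_file_path)
    df_new.to_csv(log_file_path, mode='a', sep='\t', header=header, index=False)
```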
@@ -455,7 +455,14 @@ def train_sampler(self, sampler_patience=800, neff_threshold=0.2, sampler_weight
        with torch.autograd.set_detect_anomaly(True):
            for i in range(self.restart_epoch, self.epochs):
                epoch = i + 1  # Start from 1 (not zero-indexing)

                """
                # torch.cuda.OutOfMemoryError: CUDA out of memory.
                # Tried to allocate 80.00 MiB (GPU 0; 10.75 GiB total capacity;
                # 9.76 GiB already allocated; 7.62 MiB free; 10.40 GiB reserved in total by PyTorch)
                # If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.
                # See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

                loss = torch.tensor(0.0)
                if torch.cuda.is_available():
                    loss = loss.cuda("cuda:0")
@@ -496,7 +503,48 @@
                # Back propagate
                loss.backward()
                optimizer.step()
                """

                # Gradient accumulation
                accumulation_steps = len(ds_tr_loader)
                for g in ds_tr_loader:
                    optimizer.zero_grad()
                    if torch.cuda.is_available():
                        g = g.to("cuda:0")
                    g.nodes["n1"].data["xyz"].requires_grad = True

                    loss, loss_dict = self.net(g)
                    loss = loss / accumulation_steps
                    loss.backward()

                if epoch > self.sampler_patience:
                    # Save checkpoint as local model (net.pt)
                    # `neff_min` is -1 if SamplerReweight.samplers is None
                    samplers = self._setup_local_samplers(epoch, net_copy, debug)
                    neff_min = SamplerReweight.get_effective_sample_size(temporary_samplers=samplers)
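                    # `neff_min` is an effective sample size for the cached sampler frames. For
                    # normalized importance weights w_i over N frames, a common (Kish) estimate is
                    #     n_eff = 1 / (N * sum_i w_i**2)
                    # which is 1.0 for uniform weights and approaches 0 when a few frames dominate
                    # (the exact estimator inside SamplerReweight is not shown in this diff).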

                    # If the effective sample size is below the threshold, update SamplerReweight.samplers and re-run the simulation
                    if neff_min < self.neff_threshold:
                        _logger.info(f'Minimum effective sample size ({neff_min:.3f}) below threshold ({self.neff_threshold})')
                        SamplerReweight.samplers = samplers
                        SamplerReweight.run()
                    del samplers

                    # Compute sampler loss
                    loss_list = SamplerReweight.compute_loss()  # list of torch.tensor
                    for sampler_index, sampler_loss in enumerate(loss_list):
                        sampler = SamplerReweight.samplers[sampler_index]
                        loss += sampler_loss * sampler_weight
                        loss_dict[f'{sampler.target_name}'] = sampler_loss.item()
                    loss.backward()
                    loss_dict['neff'] = neff_min

                loss_dict['loss'] = loss.item()
                self.report_loss(epoch, loss_dict)

                # Update
                optimizer.step()

                if epoch % self.checkpoint_frequency == 0:
                    # Note: returned loss is a joint loss of different units.
                    #_loss = HARTREE_TO_KCALPERMOL * loss.pow(0.5).item()
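The training loop above trades the old single backward pass for gradient accumulation: each graph's loss is scaled by `1/len(ds_tr_loader)` and back-propagated immediately, so the accumulated gradients equal the mean over the loader without ever holding the whole epoch's graphs on the GPU, which is exactly the out-of-memory failure quoted in the stringified block. A minimal, self-contained sketch of the pattern (hypothetical `model`, `loader`, and `optimizer`; not the espfit API):

```python
import torch

def train_epoch(model, loader, optimizer):
    """One epoch with gradient accumulation: backward() per batch, step() once."""
    optimizer.zero_grad()                 # clear gradients once per epoch
    accumulation_steps = len(loader)
    for batch in loader:
        loss = model(batch)               # per-batch scalar loss
        loss = loss / accumulation_steps  # scale so the summed grads average out
        loss.backward()                   # gradients add into .grad buffers
    optimizer.step()                      # single parameter update per epoch
```

Note that gradients only accumulate if `zero_grad()` runs once per epoch as above; calling it inside the batch loop would discard every contribution except the last batch's.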
@@ -577,7 +625,7 @@ def _save_local_model(self, epoch, net_copy):
        _logger.info(f'Save ckpt{epoch}.pt as temporary espaloma model (net.pt)')
        self._save_checkpoint(epoch)
        local_model = os.path.join(self.output_directory_path, f"ckpt{epoch}.pt")
-       self.save_model(net=net_copy, best_model=local_model, model_name=f"net.pt", output_directory_path=self.output_directory_path)
+       self.save_model(net=net_copy, checkpoint_file=local_model, output_model=f"net.pt", output_directory_path=self.output_directory_path)


    def _setup_local_samplers(self, epoch, net_copy, debug):
28 changes: 22 additions & 6 deletions espfit/utils/graphs.py
@@ -33,7 +33,7 @@ class CustomGraphDataset(GraphDataset):
    compute_baseline_energy_force(forcefield_list=['openff-2.1.0']):
        Compute energies and forces using other force fields.

    reshape_conformation_size(n_confs=50, include_min_energy_conf=False, keyname='u_ref'):
        Reshape conformation size.

    compute_relative_energy():
@@ -514,7 +514,7 @@ def compute_relative_energy(self):
        del new_graphs


-   def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):
+   def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False, keyname='u_ref'):
"""Reshape conformation size.

This is a work around to handle different graph size (shape). DGL requires at least one dimension with same size.
@@ -539,6 +539,11 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):
        include_min_energy_conf : boolean, default=False
            If True, the minimum energy conformer will be included in all split graphs.

        keyname : str, default='u_ref'
            Key name used to define the energy minimum. This is usually `u_ref` or `u_qm`.
            Note that, depending on how the dataset was prepared, nonbonded energies could have been
            subtracted from `u_ref`, whereas `u_qm` could be the raw QM energies.

        Returns
        -------
        None
@@ -553,8 +558,9 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):
        import copy
        import torch

-       # Remove node features that are not used during training
-       self._remove_node_features()
+       # Check that keyname is supported
+       if include_min_energy_conf and keyname not in ['u_ref', 'u_qm']:
+           raise Exception(f'Key name {keyname} is not supported. Supported key names are u_ref and u_qm.')

        new_graphs = []
        n_confs_cache = n_confs
@@ -584,7 +590,14 @@

            # Get the index of the minimum energy conformer
            if include_min_energy_conf:
-               index_min = [g.nodes['g'].data['u_ref'].argmin().item()]
+               index_min = [g.nodes['g'].data[keyname].argmin().item()]

                # DEBUG PURPOSE
                #_index_min_uref = [g.nodes['g'].data['u_ref'].argmin().item()]
                #_index_min_uqm = [g.nodes['g'].data['u_qm'].argmin().item()]
                #_logger.info(f'(u_ref:{_index_min_uref[0]} and u_qm:{_index_min_uqm[0]})')
                #_logger.info(f'Index of the minimum energy conformer ({keyname}): {index_min[0]}')

                n_confs = n_confs - 1
                _logger.info(f"Mol #{i} ({n} conformers): Shuffle and split into {n_confs} conformers and add the minimum energy conformer (index #{index_min[0]})")
            else:
@@ -603,7 +616,7 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):
_logger.debug(f"Iteration {j}: Randomly select {len(index_random)} conformers and add minimum energy conformer")
else:
_logger.debug(f"Iteration {j}: Randomly select {len(index_random)} conformers")

_g.nodes["g"].data["u_ref"] = torch.cat((_g.nodes['g'].data['u_ref'][:, index], _g.nodes['g'].data['u_ref'][:, index_random]), dim=-1)
_g.nodes["g"].data["u_ref_relative"] = torch.cat((_g.nodes['g'].data['u_ref_relative'][:, index], _g.nodes['g'].data['u_ref_relative'][:, index_random]), dim=-1)
_g.nodes["n1"].data["xyz"] = torch.cat((_g.nodes['n1'].data['xyz'][:, index, :], _g.nodes['n1'].data['xyz'][:, index_random, :]), dim=1)
@@ -628,6 +641,9 @@ def reshape_conformation_size(self, n_confs=50, include_min_energy_conf=False):

        # Update in place
        self.graphs = new_graphs
+       # Remove node features that are not used during training
+       self._remove_node_features()

        del new_graphs
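With the new `keyname` argument, callers choose whether `u_ref` or the raw QM energies in `u_qm` define the pinned minimum energy conformer. A usage sketch in the spirit of the README's quick-usage section; the dataset path is a placeholder:

```python
from espfit.utils.graphs import CustomGraphDataset

path = '/path/to/dgl/graphs'   # placeholder path to DGL graphs
ds = CustomGraphDataset.load(path)

# Split each molecule into chunks of 50 conformers and append the minimum
# energy conformer, defined here by the raw QM energies (u_qm).
ds.reshape_conformation_size(n_confs=50, include_min_energy_conf=True, keyname='u_qm')
ds.compute_relative_energy()
```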

