
Committor updates #151

Merged: 32 commits, Sep 27, 2024.

Commits
6acc928
Added smart derivatives util class and torch-scatter requirement
EnricoTrizio Jul 31, 2024
572d609
Added separation between boundary and variational dataset
EnricoTrizio Jul 31, 2024
2d383dd
Added smart derivatives to tests and inits
EnricoTrizio Jul 31, 2024
2644b06
Improved pairwise distances for a list of pairs
EnricoTrizio Jul 31, 2024
a78695f
Added utils to all import
EnricoTrizio Jul 31, 2024
56687e7
removed wrong backup file
EnricoTrizio Jul 31, 2024
c243466
added utils to all import
EnricoTrizio Jul 31, 2024
9b54143
Added note on datamodule settings
EnricoTrizio Aug 2, 2024
3a9b394
Added colors kwarg option to plot isolines
EnricoTrizio Aug 12, 2024
b20a7f7
Made the committor tutorial more tutorial-like
EnricoTrizio Sep 20, 2024
b49fcd4
First draft committor notebook with smart derivatives
EnricoTrizio Sep 20, 2024
6f3c01e
Added doc to compute descriptors derivatives utils
EnricoTrizio Sep 20, 2024
4e08112
Merge branch 'main' of https://github.com/luigibonati/mlcolvar into s…
EnricoTrizio Sep 23, 2024
9135c72
fixed test-env torch-scatter requirement
EnricoTrizio Sep 24, 2024
a5af03b
added pyg channel to test conda env
EnricoTrizio Sep 24, 2024
aaf201c
use pip to install torch-scatter, remove pyg from channels
EnricoTrizio Sep 24, 2024
e62ec55
fixed typings
EnricoTrizio Sep 24, 2024
d11fc13
Fixed typings
EnricoTrizio Sep 24, 2024
13efaff
Updated committor notebook
EnricoTrizio Sep 25, 2024
4bf0c4d
Added checks and warnings
EnricoTrizio Sep 25, 2024
d10d5d4
Reduce number of epochs for testing
EnricoTrizio Sep 25, 2024
eb51129
Added checkpoint to tutorial
EnricoTrizio Sep 25, 2024
4f3ae2f
Fixed typos
EnricoTrizio Sep 25, 2024
6dc6f44
Updated docs tutorial structure
EnricoTrizio Sep 25, 2024
540d100
changed source for committor tutorial
EnricoTrizio Sep 25, 2024
459a18a
Added wrapper function for smart derivatives
EnricoTrizio Sep 26, 2024
a67f9a4
Improved doc
EnricoTrizio Sep 26, 2024
53a1651
Compacted distances code
EnricoTrizio Sep 26, 2024
d30a034
Fixed typo
EnricoTrizio Sep 26, 2024
6c36fb7
Updated micromamba
EnricoTrizio Sep 26, 2024
ac20551
Force older mamba
EnricoTrizio Sep 26, 2024
65f9b6e
Added reference for new micromamba github action
EnricoTrizio Sep 27, 2024
5 changes: 3 additions & 2 deletions .github/workflows/CI.yaml
@@ -35,9 +35,10 @@ jobs:
df -h
ulimit -a

# More info on options: https://github.com/mamba-org/provision-with-micromamba#migration-to-setup-micromamba%60
- uses: mamba-org/setup-micromamba@main
# More info on options: https://github.com/mamba-org/setup-micromamba
- uses: mamba-org/setup-micromamba@v1
with:
micromamba-version: '1.5.10-0'
environment-file: devtools/conda-envs/test_env.yaml
environment-name: test
# channels: conda-forge,defaults
1 change: 1 addition & 0 deletions devtools/conda-envs/test_env.yaml
@@ -29,4 +29,5 @@ dependencies:
- pip:
- KDEpy
- nbmake
- torch-scatter

1,029 changes: 880 additions & 149 deletions docs/notebooks/examples/ex_committor.ipynb

Large diffs are not rendered by default.

452 changes: 452 additions & 0 deletions docs/notebooks/tutorials/cvs_committor.ipynb

Large diffs are not rendered by default.

10,002 changes: 0 additions & 10,002 deletions docs/notebooks/tutorials/data/muller-brown/biased/committor/iter_1/COLVAR_A

This file was deleted.

10,002 changes: 0 additions & 10,002 deletions docs/notebooks/tutorials/data/muller-brown/biased/committor/iter_1/COLVAR_B

This file was deleted.

10,002 changes: 0 additions & 10,002 deletions docs/notebooks/tutorials/data/muller-brown/biased/committor/iter_2/COLVAR_A

This file was deleted.

10,002 changes: 0 additions & 10,002 deletions docs/notebooks/tutorials/data/muller-brown/biased/committor/iter_2/COLVAR_B

This file was deleted.

8 changes: 7 additions & 1 deletion docs/tutorials_cvs.rst
@@ -18,4 +18,10 @@ Methods for CVs optimization
:caption: Time-informed setting
:maxdepth: 1

notebooks/tutorials/cvs_DeepTICA.ipynb
notebooks/tutorials/cvs_DeepTICA.ipynb

.. toctree::
:caption: Committor-based setting
:maxdepth: 1

notebooks/tutorials/cvs_committor.ipynb
6 changes: 4 additions & 2 deletions mlcolvar/core/loss/__init__.py
@@ -12,7 +12,9 @@
"FisherDiscriminantLoss",
"fisher_discriminant_loss",
"CommittorLoss",
"committor_loss"
"committor_loss",
"SmartDerivatives",
"compute_descriptors_derivatives"
]

from .mse import MSELoss, mse_loss
@@ -21,4 +23,4 @@
from .elbo import ELBOGaussiansLoss, elbo_gaussians_loss
from .autocorrelation import AutocorrelationLoss, autocorrelation_loss
from .fisher import FisherDiscriminantLoss, fisher_discriminant_loss
from .committor_loss import CommittorLoss, committor_loss
from .committor_loss import CommittorLoss, committor_loss, SmartDerivatives, compute_descriptors_derivatives
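
With these exports in place, the new committor utilities can be imported alongside the loss itself. A minimal import sketch, limited to the names added in this diff:

    from mlcolvar.core.loss import (
        CommittorLoss,
        committor_loss,
        SmartDerivatives,
        compute_descriptors_derivatives,
    )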
329 changes: 316 additions & 13 deletions mlcolvar/core/loss/committor_loss.py

Large diffs are not rendered by default.

41 changes: 30 additions & 11 deletions mlcolvar/core/transform/descriptors/pairwise_distances.py
@@ -1,7 +1,7 @@
import torch

from mlcolvar.core.transform import Transform
from mlcolvar.core.transform.descriptors.utils import compute_distances_matrix
from mlcolvar.core.transform.descriptors.utils import compute_distances_matrix, compute_distances_pairs

from typing import Union

@@ -54,22 +54,32 @@ def __init__(self,
self.slicing_pairs = slicing_pairs

def compute_pairwise_distances(self, pos):
dist = compute_distances_matrix(pos=pos,
n_atoms=self.n_atoms,
PBC=self.PBC,
cell=self.cell,
scaled_coords=self.scaled_coords)
batch_size = dist.shape[0]
# if we compute all distances we use the matrix trick
if self.slicing_pairs is None:
dist = compute_distances_matrix(pos=pos,
n_atoms=self.n_atoms,
PBC=self.PBC,
cell=self.cell,
scaled_coords=self.scaled_coords)
batch_size = dist.shape[0]
device = pos.device
# mask out diagonal elements
aux_mask = torch.ones_like(dist, device=device) - torch.eye(dist.shape[-1], device=device)
# keep upper triangular part to avoid duplicates
unique = aux_mask.triu().nonzero(as_tuple=True)
pairwise_distances = dist[unique].reshape((batch_size, -1))
return pairwise_distances

# if we only compute a few selected distances we do that explicitly
else:
return dist[:, self.slicing_pairs[:, 0], self.slicing_pairs[:, 1]]
dist = compute_distances_pairs(pos=pos,
n_atoms=self.n_atoms,
PBC=self.PBC,
cell=self.cell,
scaled_coords=self.scaled_coords,
slicing_pairs=self.slicing_pairs)

pairwise_distances = dist
return pairwise_distances


def forward(self, x: torch.Tensor):
@@ -82,11 +92,14 @@ def test_pairwise_distances():
-0.0553, 1.4940, 1.4990, -0.2403, 1.4780, -1.4173, -0.3363, -1.4243,
-1.4093, -0.4293, 1.3530, -1.4313, -0.4183, 1.3060, 1.4750, -0.4333,
1.2970, -1.3233, -0.4643, 1.1670, -1.3253, -0.5354]])
pos_abs.requires_grad = True

cell = torch.Tensor([3.0233])

pos_scaled = pos_abs / cell
pos_scaled = torch.clone(pos_abs) / cell

pos_abs.requires_grad = True
pos_scaled.requires_grad = True


ref_distances = torch.Tensor([[0.1521, 0.2335, 0.2412, 0.3798, 0.4733, 0.4649, 0.4575, 0.5741, 0.6815,
0.1220, 0.1323, 0.2495, 0.3407, 0.3627, 0.3919, 0.4634, 0.5885, 0.2280,
@@ -114,6 +127,12 @@ def test_pairwise_distances():
assert(torch.allclose(out, ref_distances, atol=1e-3))
out.sum().backward()

# PBC and scaled coords slicing
model = PairwiseDistances(n_atoms=10, PBC=True, cell=cell, scaled_coords=True, slicing_pairs=[[0, 1], [0, 2]])
out = model(pos_scaled)
assert(torch.allclose(out, ref_distances[:, [0, 1]], atol=1e-3))
out.sum().backward()


if __name__ == "__main__":
test_pairwise_distances()
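
The updated test doubles as a usage recipe for the new slicing_pairs mode: when a list of pairs is given, only those distances are computed instead of the full matrix. A minimal sketch with random positions (values assumed for illustration):

    import torch
    from mlcolvar.core.transform.descriptors.pairwise_distances import PairwiseDistances

    pos = torch.randn((4, 30))     # 4 frames, 10 atoms, flattened as x1,y1,z1,...
    cell = torch.Tensor([3.0233])

    # compute only the 0-1 and 0-2 distances instead of all 45 unique pairs
    model = PairwiseDistances(n_atoms=10, PBC=True, cell=cell,
                              slicing_pairs=[[0, 1], [0, 2]])
    out = model(pos)               # shape: (4, 2)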
117 changes: 102 additions & 15 deletions mlcolvar/core/transform/descriptors/utils.py
@@ -1,5 +1,5 @@
import torch
from typing import Union
from typing import Union, List, Tuple

def sanitize_positions_shape(pos: torch.Tensor,
n_atoms: int):
@@ -60,6 +60,23 @@ def sanitize_cell_shape(cell: Union[float, torch.Tensor, list]):

return cell

def _apply_pbc_distances(dist_components, pbc_cell):
shifts = torch.zeros_like(dist_components)
# avoid loop if cell is cubic
if pbc_cell[0]==pbc_cell[1] and pbc_cell[1]==pbc_cell[2]:
shifts = torch.div(dist_components, pbc_cell[0]/2, rounding_mode='trunc')
shifts = torch.div(shifts + 1*torch.sign(shifts), 2, rounding_mode='trunc' )*pbc_cell[0]

else:
# loop over dimensions of the pbc_cell
for d in range(3):
shifts[:, d, :, :] = torch.div(dist_components[:, d, :, :], pbc_cell[d]/2, rounding_mode='trunc')
shifts[:, d, :, :] = torch.div(shifts[:, d, :, :] + 1*torch.sign(shifts[:, d, :, :]), 2, rounding_mode='trunc' )*pbc_cell[d]/2

# apply shifts
dist_components = dist_components - shifts
return dist_components
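
The helper applies the usual minimum-image convention via truncating division. A small standalone check of the cubic branch (values chosen for illustration):

    import torch

    L = 3.0                                # cubic cell side
    d = torch.tensor([2.0, 1.4, -2.9])     # raw distance components
    shifts = torch.div(d, L / 2, rounding_mode='trunc')
    shifts = torch.div(shifts + torch.sign(shifts), 2, rounding_mode='trunc') * L
    print(d - shifts)                      # tensor([-1.0000,  1.4000,  0.1000])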

def compute_distances_matrix(pos: torch.Tensor,
n_atoms: int,
PBC: bool,
@@ -122,20 +139,7 @@ def compute_distances_matrix(pos: torch.Tensor,

# get PBC shifts
if PBC:
shifts = torch.zeros_like(dist_components)
# avoid loop if cell is cubic
if pbc_cell[0]==pbc_cell[1] and pbc_cell[1]==pbc_cell[2]:
shifts = torch.div(dist_components, pbc_cell[0]/2, rounding_mode='trunc')
shifts = torch.div(shifts + 1*torch.sign(shifts), 2, rounding_mode='trunc' )*pbc_cell[0]

else:
# loop over dimensions of the pbc_cell
for d in range(3):
shifts[:, d, :, :] = torch.div(dist_components[:, d, :, :], pbc_cell[d]/2, rounding_mode='trunc')
shifts[:, d, :, :] = torch.div(shifts[:, d, :, :] + 1*torch.sign(shifts[:, d, :, :]), 2, rounding_mode='trunc' )*pbc_cell[d]/2

# apply shifts
dist_components = dist_components - shifts
dist_components = _apply_pbc_distances(dist_components=dist_components, pbc_cell=pbc_cell)

# if we used scaled coords we need to get back to real distances
if scaled_coords:
@@ -153,6 +157,89 @@
dist[mask_diag] = torch.sqrt( dist[mask_diag])
return dist


def compute_distances_pairs(pos: torch.Tensor,
n_atoms: int,
PBC: bool,
cell: Union[float, list],
slicing_pairs: List[Tuple[int, int]],
vector: bool = False,
scaled_coords: bool = False,
) -> torch.Tensor:
"""Compute the pairwise distances for a list of atom pairs from batches of atomic coordinates.
Optionally can return the vector distances.

Parameters
----------
pos : torch.Tensor
Positions of the atoms, they can be given with shapes:
- Shape: (n_batch (optional), n_atoms * 3), i.e [ [x1,y1,z1, x2,y2,z2, .... xn,yn,zn] ]
- Shape: (n_batch (optional), n_atoms, 3), i.e [ [ [x1,y1,z1], [x2,y2,z2], .... [xn,yn,zn] ] ]
n_atoms : int
Number of atoms
PBC : bool
Switch for Periodic Boundary Conditions use
cell : Union[float, list]
Dimensions of the real cell, orthorhombic-like cells only
slicing_pairs : list[tuple[int, int]]
List of the indices of the pairs for which to compute the distances
vector : bool, optional
Switch to return vector distances
scaled_coords : bool, optional
Switch for coordinates scaled on cell's vectors use

Returns
-------
torch.Tensor
Pairwise distances for the selected atom pairs
If `vector=True`, the vector components of the distances are returned instead
"""
# ======================= CHECKS =======================
pos, batch_size = sanitize_positions_shape(pos, n_atoms)
cell = sanitize_cell_shape(cell)

# Set which cell to be used for PBC
if scaled_coords:
pbc_cell = torch.Tensor([1., 1., 1.])
else:
pbc_cell = cell

_device = pos.device
cell = cell.to(_device)

# ======================= COMPUTE =======================
pos = torch.reshape(pos, (batch_size, n_atoms, 3)) # this preserves the order when the pos are passed as a list
pos = torch.transpose(pos, 1, 2)
pos = pos.reshape((batch_size, 3, n_atoms))

# Initialize tensor to hold distances
if vector:
distances = torch.zeros((batch_size, len(slicing_pairs), 3), device=_device)
else:
distances = torch.zeros((batch_size, len(slicing_pairs)), device=_device)

# we create two tensors for starting and ending positions
pos_a = pos[:, :, slicing_pairs[:, 0]]
pos_b = pos[:, :, slicing_pairs[:, 1]]

# compute the distance components for all the pairs
dist_components = pos_b - pos_a

# get PBC shifts
if PBC:
dist_components = _apply_pbc_distances(dist_components=dist_components, pbc_cell=pbc_cell)

# if we used scaled coords we need to get back to real distances
if scaled_coords:
dist_components = torch.einsum('bij,i->bij', dist_components, cell)

if vector:
distances = dist_components
else:
distances = torch.sqrt(torch.sum(dist_components ** 2, dim=1))

return distances
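
A minimal call sketch for the new function (shapes follow the docstring; passing slicing_pairs as a tensor is an assumption suggested by the [:, 0] indexing in the body):

    import torch
    from mlcolvar.core.transform.descriptors.utils import compute_distances_pairs

    pos = torch.randn((4, 10, 3))             # 4 frames, 10 atoms
    pairs = torch.tensor([[0, 1], [0, 2]])    # assumed tensor form
    dist = compute_distances_pairs(pos, n_atoms=10, PBC=True, cell=3.0,
                                   slicing_pairs=pairs)
    print(dist.shape)                         # torch.Size([4, 2])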

def apply_cutoff(x: torch.Tensor,
cutoff: float,
mode: str = 'continuous',
27 changes: 22 additions & 5 deletions mlcolvar/cvs/committor/committor.py
@@ -41,6 +41,8 @@ def __init__(
gamma: float = 10000,
delta_f: float = 0,
cell: float = None,
separate_boundary_dataset : bool = True,
descriptors_derivatives : torch.nn.Module = None,
options: dict = None,
**kwargs,
):
@@ -62,6 +64,11 @@
State B is supposed to be higher in energy.
cell : float, optional
CUBIC cell size length, used to scale the positions from reduce coordinates to real coordinates, by default None
separate_boundary_dataset : bool, optional
Switch to exclude boundary-condition-labeled data from the variational loss, by default True
descriptors_derivatives : torch.nn.Module, optional
`SmartDerivatives` object to save memory and time when using descriptors.
See also mlcolvar.core.loss.committor_loss.SmartDerivatives
options : dict[str, Any], optional
Options for the building blocks of the model, by default {}.
Available blocks: ['nn'] .
@@ -73,7 +80,9 @@
alpha=alpha,
gamma=gamma,
delta_f=delta_f,
cell=cell
cell=cell,
separate_boundary_dataset=separate_boundary_dataset,
descriptors_derivatives=descriptors_derivatives
)

# ======= OPTIONS =======
@@ -132,14 +141,15 @@ def test_committor():

# create two fake atoms and use their fake positions
atomic_masses = initialize_committor_masses(atom_types=[0,1], masses=[15.999, 1.008])
model = Committor(layers=[6, 4, 2, 1], mass=atomic_masses, alpha=1e-1, delta_f=0)
# create dataset
samples = 50
X = torch.randn((2*samples, 6))
X = torch.randn((4*samples, 6))

# create labels
y = torch.zeros(X.shape[0])
y[samples:] += 1
y[int(2*samples):] += 1
y[int(3*samples):] += 1

# create weights
w = torch.ones(X.shape[0])
@@ -149,12 +159,19 @@

# train model
trainer = lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0)

# dataset separation
model = Committor(layers=[6, 4, 2, 1], mass=atomic_masses, alpha=1e-1, delta_f=0)
trainer.fit(model, datamodule)

model(X).sum().backward()

bias_model = KolmogorovBias(input_model=model, beta=1, epsilon=1e-6, lambd=1)
bias_model(X)

# naive whole dataset
trainer = lightning.Trainer(max_epochs=5, logger=None, enable_checkpointing=False, limit_val_batches=0, num_sanity_val_steps=0)
model = Committor(layers=[6, 4, 2, 1], mass=atomic_masses, alpha=1e-1, delta_f=0, separate_boundary_dataset=False)
trainer.fit(model, datamodule)
model(X).sum().backward()

if __name__ == "__main__":
test_committor()
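
Condensing the test above, the two training modes introduced by this PR can be set up as follows (import paths are assumed and may differ from the actual package layout):

    # assumed import paths
    from mlcolvar.cvs import Committor
    from mlcolvar.cvs.committor.utils import initialize_committor_masses

    atomic_masses = initialize_committor_masses(atom_types=[0, 1], masses=[15.999, 1.008])

    # default: boundary-labeled data are excluded from the variational loss
    model = Committor(layers=[6, 4, 2, 1], mass=atomic_masses, alpha=1e-1, delta_f=0)

    # naive mode: the whole dataset also enters the variational term
    model_naive = Committor(layers=[6, 4, 2, 1], mass=atomic_masses, alpha=1e-1,
                            delta_f=0, separate_boundary_dataset=False)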