
Commit

First tests committor
EnricoTrizio committed Dec 6, 2024
1 parent 8797481 commit 031aa7f
Showing 6 changed files with 50,482 additions and 78 deletions.
141 changes: 89 additions & 52 deletions mlcolvar/core/loss/committor_loss.py
@@ -15,7 +15,9 @@
# =============================================================================

import torch
from torch_scatter import scatter
from torch_scatter import scatter, scatter_sum
from typing import Union
import torch_geometric

# =============================================================================
# LOSS FUNCTIONS
@@ -26,19 +28,19 @@ class CommittorLoss(torch.nn.Module):
"""Compute a loss function based on Kolmogorov's variational principle for the determination of the committor function"""

def __init__(self,
mass: torch.Tensor,
atomic_masses: torch.Tensor,
alpha: float,
cell: float = None,
gamma: float = 10000,
delta_f: float = 0,
gamma: float = 10000.0,
delta_f: float = 0.0,
separate_boundary_dataset : bool = True,
descriptors_derivatives : torch.nn.Module = None
descriptors_derivatives : torch.nn.Module = None,
log_var: bool = True
):
"""Compute Kolmogorov's variational principle loss and impose boundary conditions on the metastable states
Parameters
----------
mass : torch.Tensor
atomic_masses : torch.Tensor
Atomic masses of the atoms in the system
alpha : float
Hyperparameter that scales the boundary conditions contribution to the loss, i.e. alpha*(loss_bound_A + loss_bound_B)
@@ -57,44 +59,48 @@ def __init__(self,
"""
super().__init__()
self.register_buffer("mass", mass)
self.register_buffer("atomic_masses", atomic_masses)
self.alpha = alpha
self.cell = cell
self.gamma = gamma
self.delta_f = delta_f
self.descriptors_derivatives=descriptors_derivatives
self.descriptors_derivatives = descriptors_derivatives
self.separate_boundary_dataset = separate_boundary_dataset

def forward(
self, x: torch.Tensor, q: torch.Tensor, labels: torch.Tensor, w: torch.Tensor, create_graph: bool = True
self.log_var = log_var

def forward(self,
x: Union[torch.Tensor, torch_geometric.data.Batch],
q: torch.Tensor,
labels: torch.Tensor,
w: torch.Tensor,
create_graph: bool = True
) -> torch.Tensor:
return committor_loss(x=x,
q=q,
labels=labels,
w=w,
mass=self.mass,
atomic_masses=self.atomic_masses,
alpha=self.alpha,
gamma=self.gamma,
delta_f=self.delta_f,
create_graph=create_graph,
cell=self.cell,
separate_boundary_dataset=self.separate_boundary_dataset,
descriptors_derivatives=self.descriptors_derivatives
descriptors_derivatives=self.descriptors_derivatives,
log_var = self.log_var
)
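A minimal usage sketch of the class above, for illustration only: it assumes a hypothetical `model` that maps flattened positions to a committor estimate in [0, 1] and a toy 4-atom system; none of these names are defined in this module.

    atomic_masses = torch.ones(4 * 3)                # one entry per Cartesian component
    loss_fn = CommittorLoss(atomic_masses=atomic_masses, alpha=1e-1)
    x = torch.randn(8, 4 * 3, requires_grad=True)    # 8 configurations of 4 atoms in 3D
    q = model(x)                                     # committor estimates, shape [8, 1]
    labels = torch.tensor([0., 0., 1., 1., 2., 2., 2., 2.])  # 0 = state A, 1 = state B, >1 = variational data
    w = torch.ones(8)                                # reweighting factors
    losses = loss_fn(x, q, labels, w)                # total loss plus the individual terms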


def committor_loss(x: torch.Tensor,
q: torch.Tensor,
labels: torch.Tensor,
w: torch.Tensor,
mass: torch.Tensor,
atomic_masses: torch.Tensor,
alpha: float,
gamma: float = 10000,
delta_f: float = 0,
create_graph: bool = True,
cell: float = None,
separate_boundary_dataset : bool = True,
descriptors_derivatives : torch.nn.Module = None
separate_boundary_dataset: bool = True,
descriptors_derivatives: torch.nn.Module = None,
log_var: bool = True
):
"""Compute variational loss for committor optimization with boundary conditions
@@ -108,7 +114,7 @@ def committor_loss(x: torch.Tensor,
Labels of the states; the A and B states are used for the boundary conditions
w : torch.Tensor
Reweighting factors to the Boltzmann distribution. This should depend on the simulation in which the data were collected.
mass : torch.Tensor
atomic_masses : torch.Tensor
List of masses of all the atoms used; each mass is repeated three times, once per x, y, z component.
Can be created using `committor.utils.initialize_committor_masses`
alpha : float
@@ -138,63 +144,94 @@
The boundary loss term on basin A
gamma*alpha*loss_B : torch.Tensor
The boundary loss term on basin B
"""
"""
# ============================== SETUP ==============================
# check if input is graph
_is_graph_data = False
if isinstance(x, torch_geometric.data.batch.Batch):
batch = torch.clone(x['batch'])
node_types = torch.where(x['node_attrs'])[1]
x = x['positions']
_is_graph_data = True


# inherit right device
device = x.device
dtype = x.dtype

mass = mass.to(device)

# Create masks to access different states data
mask_A = torch.nonzero(labels.squeeze() == 0, as_tuple=True)
mask_B = torch.nonzero(labels.squeeze() == 1, as_tuple=True)
mask_A = torch.nonzero(labels.squeeze() == 0)
mask_B = torch.nonzero(labels.squeeze() == 1)
if separate_boundary_dataset:
mask_var = torch.nonzero(labels.squeeze() > 1, as_tuple=True)
if _is_graph_data:
# this needs to be on the batch index, not only the labels
mask_var = torch.nonzero(labels.squeeze() > 1)
aux = torch.where(mask_var)[0]
mask_var_batches = torch.isin(batch, aux)
mask_var_batches = (batch[mask_var_batches])
else:
mask_var = torch.nonzero(labels.squeeze() > 1)
mask_var_batches = mask_var
else:
mask_var = torch.ones(len(x), dtype=torch.bool)

# Update weights of basin B using the information on the delta_f
mask_var = torch.ones(len(x), dtype=torch.bool)
mask_var_batches = mask_var
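As a concrete illustration of the masks (toy labels, not taken from the code): with labels = [0, 0, 1, 2, 2] and separate_boundary_dataset=True, the variational term is evaluated only on the two configurations labelled 2, while states A and B enter only through the boundary terms.

    labels = torch.tensor([0., 0., 1., 2., 2.])
    mask_A = torch.nonzero(labels == 0)      # indices [[0], [1]]
    mask_B = torch.nonzero(labels == 1)      # index   [[2]]
    mask_var = torch.nonzero(labels > 1)     # indices [[3], [4]]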

# setup atomic masses
atomic_masses = atomic_masses.to(dtype).to(device)
# mass should have size [1, n_atoms*spatial_dims]
if _is_graph_data:

atomic_masses = atomic_masses[node_types[mask_var_batches]].unsqueeze(-1)

else:
atomic_masses = atomic_masses.unsqueeze(0)
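A sketch of the non-graph mass tensor expected here (toy values; in practice it can be built with `committor.utils.initialize_committor_masses`, as mentioned in the docstring): each atomic mass is repeated once per Cartesian component and the result is broadcast against the per-sample gradients.

    masses = torch.tensor([15.999, 1.008, 1.008])   # e.g. a 3-atom, water-like toy system
    atomic_masses = masses.repeat_interleave(3)     # shape [9]: one entry per x, y, z
    atomic_masses = atomic_masses.unsqueeze(0)      # shape [1, 9]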

# Update the weights of the boundary-condition configurations using the delta_f information
delta_f = torch.Tensor([delta_f])
if delta_f < 0: # B higher in energy --> A-B < 0
# B higher in energy --> A-B < 0
if delta_f < 0:
w[mask_B] = w[mask_B] * torch.exp(delta_f.to(device))
elif delta_f > 0: # A higher in energy --> A-B > 0
# A higher in energy --> A-B > 0
elif delta_f > 0:
w[mask_A] = w[mask_A] * torch.exp(-delta_f.to(device))
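A worked example of this reweighting (toy number; the exponent is applied directly, so delta_f is assumed to be expressed in kT units): with delta_f = -2.0, i.e. state B lying higher in free energy than A, the weights of the B configurations are scaled by exp(-2.0) ≈ 0.135, while the A and variational weights are left unchanged.

    delta_f = torch.Tensor([-2.0])
    w_B_scaled = w[mask_B] * torch.exp(delta_f)   # ≈ 0.135 * original weights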

###### VARIATIONAL PRINCIPLE LOSS ######
# weights should have size [n_batch, 1]
w = w.unsqueeze(-1)
# ============================== LOSS ==============================
# Each loss contribution is scaled by the number of samples

# We need the gradient of q(x)
# 1. VARIATIONAL LOSS
# Compute gradients of q(x) wrt x
grad_outputs = torch.ones_like(q[mask_var])
grad = torch.autograd.grad(q[mask_var], x, grad_outputs=grad_outputs, retain_graph=True, create_graph=create_graph)[0]
grad = grad[mask_var]

# TODO this fixes cell size issue
if cell is not None:
grad = grad / cell

grad = grad[mask_var_batches]
if descriptors_derivatives is not None:
# we use the precomputed derivatives from descriptors to pos
grad_square = descriptors_derivatives(grad)
else:
# we get the square of grad(q) and we multiply by the weight
grad_square = torch.pow(grad, 2)

# we sanitize the shapes of mass and weights tensors
# mass should have size [1, n_atoms*spatial_dims]
mass = mass.unsqueeze(0)
# weights should have size [n_batch, 1]
w = w.unsqueeze(-1)

grad_square = torch.sum((grad_square * (1/mass)), axis=1, keepdim=True)
grad_square = torch.sum((grad_square * (1/atomic_masses)), axis=1, keepdim=True)

if _is_graph_data:
# we need to sum on the right batch first
grad_square = scatter_sum(grad_square, mask_var_batches, dim=0)

grad_square = grad_square * w[mask_var]

# variational contribution to loss: average over the batch
loss_var = torch.mean(grad_square)
if log_var:
loss_var = loss_var.log()

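The gradient computation above relies on seeding torch.autograd.grad with a tensor of ones so that each row of the output holds the gradient of the corresponding sample's q with respect to its own coordinates; a self-contained sketch of the same pattern on a toy function (not the committor model):

    x = torch.randn(5, 3, requires_grad=True)
    q = (x ** 2).sum(dim=1, keepdim=True)        # stand-in for q(x), shape [5, 1]
    grad = torch.autograd.grad(q, x, grad_outputs=torch.ones_like(q),
                               create_graph=True)[0]
    # grad has shape [5, 3] and equals 2 * x: one gradient per sample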

# 2. BOUNDARY LOSS
loss_A = torch.mean( torch.pow(q[mask_A], 2))
loss_B = torch.mean( torch.pow( (q[mask_B] - 1) , 2))

# boundary conditions
q_A = q[mask_A]
q_B = q[mask_B]
loss_A = torch.mean( torch.pow(q_A, 2))
loss_B = torch.mean( torch.pow( (q_B - 1) , 2))
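A small numeric check of the boundary terms (toy values): predictions q_A = [0.1, 0.2] on state A and q_B = [0.9] on state B give loss_A = mean([0.01, 0.04]) = 0.025 and loss_B ≈ 0.01; both terms vanish only when q is exactly 0 on A and 1 on B.

    q_A = torch.tensor([0.1, 0.2])
    q_B = torch.tensor([0.9])
    loss_A = torch.mean(torch.pow(q_A, 2))         # 0.025
    loss_B = torch.mean(torch.pow(q_B - 1, 2))     # ≈ 0.010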

# 3. TOTAL LOSS
loss = gamma*( loss_var + alpha*(loss_A + loss_B) )

# TODO maybe there is no need to detach them for logging
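For completeness, a self-contained sketch of calling the functional form directly on synthetic tensors (toy shapes and values; the returned terms are the ones listed in the Returns section of the docstring above):

    torch.manual_seed(0)
    n_atoms = 4
    x = torch.randn(6, n_atoms * 3, requires_grad=True)
    q = torch.sigmoid(x.sum(dim=1, keepdim=True))        # toy committor guess that depends on x
    labels = torch.tensor([0., 0., 1., 1., 2., 2.])
    w = torch.ones(6)
    atomic_masses = torch.ones(n_atoms * 3)
    losses = committor_loss(x=x, q=q, labels=labels, w=w,
                            atomic_masses=atomic_masses, alpha=1e-1)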
