Skip to content

Commit

Permalink
Merge pull request #25 from choderalab/fix-boltzmann-mean
Browse files Browse the repository at this point in the history
Fix Boltzmann mean
  • Loading branch information
kaminow authored Aug 17, 2023
2 parents 07ea82b + 6a17958 commit aafd56b
Showing 1 changed file with 77 additions and 31 deletions.
108 changes: 77 additions & 31 deletions mtenn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from itertools import permutations
import os
import torch
from typing import Optional


class Model(torch.nn.Module):
Expand Down Expand Up @@ -328,55 +329,98 @@ def forward(self, predictions: torch.Tensor):
return torch.mean(predictions)


class BoltzmannCombination(Combination):
class MaxCombination(Combination):
"""
Combine a list of deltaG predictions according to their Boltzmann weight.
Approximate max/min of the predictions using the LogSumExp function for smoothness.
"""

def __init__(self):
super(BoltzmannCombination, self).__init__()
def __init__(self, neg=True, scale=1000.0):
"""
Parameters
----------
neg : bool, default=True
Negate the predictions before calculating the LSE, effectively finding
the min. Preds are negated again before being returned
scale : float, default=1000.0
Fixed positive value to scale predictions by before taking the LSE. This
tightens the bounds of the LSE approximation
"""
super(MaxCombination, self).__init__()

self.neg = -1 * neg
self.scale = scale

from simtk.unit import (
BOLTZMANN_CONSTANT_kB as kB,
elementary_charge,
coulomb,
def forward(self, predictions: torch.Tensor):
return (
self.neg
* torch.logsumexp(self.neg * self.scale * predictions, dim=0)
/ self.scale
)

## Convert kB to eV (calibrate to SchNet predictions)
electron_volt = elementary_charge.conversion_factor_to(coulomb)

self.kT = (kB / electron_volt * 298.0)._value
class BoltzmannCombination(Combination):
"""
Combine a list of deltaG predictions according to their Boltzmann weight. Use LSE
approximation of min energy to improve numerical stability. Treat energy in implicit
kT units.
"""

def __init__(self):
super(BoltzmannCombination, self).__init__()

def forward(self, predictions: torch.Tensor):
return -self.kT * torch.logsumexp(-predictions, dim=0)
# First calculate LSE (no scale here bc math)
lse = torch.logsumexp(-predictions, dim=0)
# Calculate Boltzmann weights for each prediction
w = torch.exp(-predictions - lse)

return torch.dot(w, predictions)


class PIC50Readout(Readout):
"""
Readout implementation to convert delta G values to pIC50 values.
Readout implementation to convert delta G values to pIC50 values. This new
implementation assumes implicit energy units, WHICH WILL INVALIDATE MODELS TRAINED
PRIOR TO v0.3.0.
Assuming implicit energy units:
deltaG = ln(Ki)
Ki = exp(deltaG)
Using the Cheng-Prusoff equation:
Ki = IC50 / (1 + [S]/Km)
exp(deltaG) = IC50 / (1 + [S]/Km)
IC50 = exp(deltaG) * (1 + [S]/Km)
pIC50 = -log10(exp(deltaG) * (1 + [S]/Km))
pIC50 = -log10(exp(deltaG)) - log10(1 + [S]/Km)
pIC50 = -ln(exp(deltaG))/ln(10) - log10(1 + [S]/Km)
pIC50 = -deltaG/ln(10) - log10(1 + [S]/Km)
Estimating Ki as the IC50 value:
Ki = IC50
IC50 = exp(deltaG)
pIC50 = -log10(exp(deltaG))
pIC50 = -ln(exp(deltaG))/ln(10)
pIC50 = -deltaG/ln(10)
"""

def __init__(self, T=298.0):
def __init__(self, substrate: Optional[float] = None, Km: Optional[float] = None):
"""
Initialize conversion with specified T (assume 298 K).
Initialize conversion with specified substrate concentration and Km. If either
is left blank, the IC50 approximation will be used.
Parameters
----------
T : float, default=298
Temperature for conversion.
substrate : float, optional
Substrate concentration for use in the Cheng-Prusoff equation. Assumed to be
in the same units as Km
Km : float, optional
Km value for use in the Cheng-Prusoff equation. Assumed to be in the same
units as substrate
"""
super(PIC50Readout, self).__init__()

from simtk.unit import (
BOLTZMANN_CONSTANT_kB as kB,
elementary_charge,
coulomb,
)

## Convert kB to eV (calibrate to SchNet predictions)
electron_volt = elementary_charge.conversion_factor_to(coulomb)

self.kT = (kB / electron_volt * T)._value
if substrate and Km:
self.cp_val = 1 + substrate / Km
else:
self.cp_val = None

def forward(self, delta_g):
"""
Expand All @@ -392,7 +436,9 @@ def forward(self, delta_g):
float
Calculated pIC50 value.
"""
## IC50 value = exp(dG/kT) => pic50 = -log10(exp(dg/kT))
## Rearrange a bit more to avoid disappearing floats:
## pic50 = -dg/kT / ln(10)
return -delta_g / self.kT / torch.log(torch.tensor(10, dtype=delta_g.dtype))
pic50 = -delta_g / torch.log(torch.tensor(10, dtype=delta_g.dtype))
# Using Cheng-Prusoff
if self.cp_val:
pic50 -= torch.log10(torch.tensor(self.cp_val, dtype=delta_g.dtype))

return pic50

0 comments on commit aafd56b

Please sign in to comment.