[WIP] biophysical_with_ic50 #113

Open · wants to merge 5 commits into base: master
41 changes: 2 additions & 39 deletions pinot/net.py
@@ -237,49 +237,12 @@ def condition(self, g, *args, **kwargs):
if 'sampler' in kwargs:
sampler = kwargs.pop('sampler')

if 'n_samples' in kwargs:
n_samples = kwargs.pop('n_samples')

if self.has_exact_gp is True:
h_last = self.representation(self.g_last)
kwargs = {**{"x_tr": h_last, "y_tr": self.y_last}, **kwargs}

if sampler is None:
return self._condition(h, *args, **kwargs)

if not hasattr(sampler, "sample_params"):
return self._condition(h, *args, **kwargs)

# initialize a list of distributions
distributions = []

for _ in range(n_samples):
if sampler is not None and hasattr(sampler, "sample_params"):
sampler.sample_params()
distributions.append(self._condition(g, *args, **kwargs))

# get the parameter of these distributions
# NOTE: this is not necessarily the most efficient solution
# since we don't know the memory footprint of
# torch.distributions
mus, sigmas = zip(
*[
(distribution.loc, distribution.scale)
for distribution in distributions
]
)

# concat parameters together
# (n_samples, batch_size, measurement_dimension)
mu = torch.stack(mus).cpu()  # torch.distributions objects are kept on CPU
sigma = torch.stack(sigmas).cpu()
return self._condition(h, *args, **kwargs)

# construct the distribution
distribution = torch.distributions.normal.Normal(loc=mu, scale=sigma)

# make it mixture
distribution = torch.distributions.mixture_same_family.MixtureSameFamily(
torch.distributions.Categorical(torch.ones(mu.shape[0],)),
torch.distributions.Independent(distribution, 2),
)

return distribution
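For reference, the removed lines built a uniform mixture of Normals from posterior parameter draws; a minimal self-contained sketch of that construction (toy shapes assumed):

import torch

n_samples, batch_size, measurement_dimension = 4, 8, 1
mus = [torch.randn(batch_size, measurement_dimension) for _ in range(n_samples)]
sigmas = [torch.rand(batch_size, measurement_dimension) + 0.1 for _ in range(n_samples)]

mu = torch.stack(mus)        # (n_samples, batch_size, measurement_dimension)
sigma = torch.stack(sigmas)

# one Normal per sampled parameter set
components = torch.distributions.normal.Normal(loc=mu, scale=sigma)

# uniform weights over samples; Independent(..., 2) folds
# (batch_size, measurement_dimension) into the event shape
mixture = torch.distributions.mixture_same_family.MixtureSameFamily(
    torch.distributions.Categorical(torch.ones(n_samples,)),
    torch.distributions.Independent(components, 2),
)

y = torch.randn(batch_size, measurement_dimension)
print(mixture.log_prob(y))  # scalar: log-density of y under the mixture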
173 changes: 91 additions & 82 deletions pinot/regressors/biophysical_regressor.py
@@ -6,11 +6,12 @@
import abc
import math
import numpy as np
from pinot.regressors.base_regressor import BaseRegressor

# =============================================================================
# MODULE CLASSES
# =============================================================================
class BiophysicalRegressor(torch.nn.Module):
class BiophysicalRegressor(BaseRegressor):
r""" Biophysically inspired model

Parameters
@@ -23,95 +24,103 @@ class BiophysicalRegressor(torch.nn.Module):

"""

def __init__(self, base_regressor=None, *args, **kwargs):
def __init__(self, base_regressor_class=None, *args, **kwargs):
super(BiophysicalRegressor, self).__init__()
self.base_regressor = base_regressor
# get the base regressor
self.base_regressor_class = base_regressor_class
self.base_regressor = base_regressor_class(
*args, **kwargs
)

# initialize measurement parameter
self.log_sigma_measurement = torch.nn.Parameter(torch.zeros(1))
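A brief aside on this parameter: storing the log of the measurement noise and exponentiating at the point of use (log_sigma_measurement.exp() below) keeps the scale strictly positive under unconstrained gradient updates. A tiny self-contained sketch of the pattern:

import torch

log_sigma = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([log_sigma], lr=0.1)

y = torch.tensor([0.3])
for _ in range(10):
    # scale = exp(log_sigma) > 0 for any real-valued log_sigma
    dist = torch.distributions.normal.Normal(torch.zeros(1), log_sigma.exp())
    loss = -dist.log_prob(y).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

assert (log_sigma.exp() > 0).all()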

def g(self, func_value=None, test_ligand_concentration=1e-3):
return 1 / (1 + torch.exp(-func_value) / test_ligand_concentration)
def _get_measurement(self, delta_g, concentration=1e-3):
Member: @yuanqing-wang: What units are you using here? It would be good to document them in the docstring!

Collaborator: that's my code, I will add them

""" Translate ..math:: \Delta G to percentage inhibition.

Parameters
----------
delta_g : torch.Tensor, shape=(number_of_graphs, 1)
Binding free energy.

concentration : torch.Tensor, shape=(,) or (1, number_of_concentrations)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document units!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again, I'm to blame.

Concentration of ligand.

Returns
-------
measurement : torch.Tensor, shape=(number_of_graphs, 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there's a typo in this equation.

If you want percent inhibition, you'll want something like

100 * (1 - 1.0 / (1.0 + concentration / torch.exp(delta_g))

where concentration must be in Molar units.

This means that when delta_g << 0 (a tight binder), we have 100% inhibition, and when delta_g ~ 0, we should have 0% inhibition.

Copy link
Collaborator

@karalets karalets Jul 19, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So what I did was copy the function g from Josh's notebook ( https://gist.github.com/maxentile/fdc42d76c1125aff33f912fbc28edb1f )
Concentration is indeed assumed to be in Molar units in this code, and in our experiments.

Let me check out how your comments map to that code so I understand your comment better.

The original code in Josh's notebook is the following, with the math as explained beneat:

reference_concentration = 1 # molar
test_ligand_concentration = 1e-3

def g(f, c=test_ligand_concentration):
    return 1 / (1 + np.exp(-f) / c)

Screen Shot 2020-07-18 at 10 27 13 PM

@jchodera Can you explain the deviation from what you wrote above so I can parse my mistake?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, while we're at it, I only ever got Josh's notebook as a reference for this function. Where can I look up the published version of this so I can check the math independently? I should have asked for this earlier to debug my own code, I rushed this too much.

or (number_of_graphs, number_of_concentrations)
Percentage of inhibition.

"""
return 1.0 / (1.0 + torch.exp(-delta_g) / concentration)
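To check the algebra in this thread concretely: the two forms agree up to a sign convention on the free energy and the factor of 100. A small sketch (assumptions: concentration in Molar, delta_g dimensionless, i.e. in units of kT):

import torch

def g_notebook(f, c=1e-3):
    # the form used in this PR / Josh's notebook: fraction inhibited
    return 1.0 / (1.0 + torch.exp(-f) / c)

def percent_inhibition_review(delta_g, c=1e-3):
    # the form suggested in review: percent inhibited,
    # with delta_g < 0 for tight binders
    return 100.0 * (1.0 - 1.0 / (1.0 + c / torch.exp(delta_g)))

delta_g = torch.linspace(-10.0, 10.0, 5)
assert torch.allclose(
    percent_inhibition_review(delta_g),
    100.0 * g_notebook(-delta_g),
)

So the two expressions appear to differ only in the sign convention on delta_g and the scaling to percent, not in functional form; the units question raised above is the substantive one.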

def _condition_delta_g(self, x_te, *args, **kwargs):
""" Returns predictive distribution of binding free energy. """
return self.base_regressor.condition(x_te, *args, **kwargs)

def _condition_measurement(self, x_te=None, concentration=1e-3, delta_g_sample=None):
""" Returns predictive distribution of percentage of inhibtiion. """
# sample delta g if not specified
if delta_g_sample is None and x_te is not None:
delta_g_sample = self._condition_delta_g(x_te).rsample()

def condition(
self, h=None, test_ligand_concentration=1e-3, *args, **kwargs
):
distribution_base_regressor = self.base_regressor.condition(
h, *args, **kwargs
)
# we sample from the latent f to push things through the likelihood
# NOTE: if we do this, we may need to draw multiple
# samples in order to get good estimates of the log-likelihood
f_sample = distribution_base_regressor.rsample()
mu_m = self.g(
func_value=f_sample,
test_ligand_concentration=test_ligand_concentration,
)
sigma_m = torch.exp(self.log_sigma_measurement)
distribution_measurement = torch.distributions.normal.Normal(
loc=mu_m, scale=sigma_m
loc=self._get_measurement(delta_g_sample, concentration=concentration),
scale=self.log_sigma_measurement.exp()
)
# import pdb; pdb.set_trace()
return distribution_measurement
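Because _condition_measurement pushes a single delta_g draw through the likelihood, one log_prob is a noisy estimate; as the note above says, several draws may be needed. A sketch of one way to average them (hypothetical helper, not part of this PR; it uses the public condition dispatcher defined further down in this diff):

import math
import torch

def estimate_marginal_log_prob(regressor, x_te, y, n_samples=16, concentration=1e-3):
    # hypothetical helper: approximates log E_{delta_g}[ p(y | delta_g) ]
    log_probs = []
    for _ in range(n_samples):
        # each call draws a fresh delta_g sample internally
        distribution = regressor.condition(
            x_te, output="measurement", concentration=concentration,
        )
        log_probs.append(distribution.log_prob(y))
    # average in probability space via logsumexp, not in log space
    return torch.logsumexp(torch.stack(log_probs), dim=0) - math.log(n_samples)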

def loss(
self, h=None, y=None, test_ligand_concentration=None, *args, **kwargs
def _condition_ic50(
self,
x_te=None,
delta_g_sample=None,
concentration_low=0.0,
concentration_high=1.0,
number_of_concentrations=1024,
):

Member: What do these parameters mean? Can you document them in a docstring?

We definitely do not want a uniform distribution over IC50 as a prior. It should be uniform over log10(IC50), over a meaningful physical range. You can see what a sensible range is from this paper:
https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0061007

[Image: distribution of measured affinities from the linked paper]

So maybe 0-14 in -log10(IC50).

Collaborator: Thanks John, I think this was something @yuanqing-wang just wanted to ask me about that he hacked in just today.
@yuanqing-wang can you elaborate on your thinking? Ideally also with PR comments when you add a new function like this?

Member Author: Sorry @jchodera, this is just a very immature draft I had in order to discuss some ideas. For the IC50 I was only thinking about how we can integrate that into the graphical model without using the Cheng-Prusoff equation. So I'm thinking about doing the titration curve here from concentration_low to concentration_high. But I realized that I can only do sampling this way, not get the distribution.
# import pdb; pdb.set_trace()
distribution_measurement = self.condition(
h=h,
test_ligand_concentration=test_ligand_concentration,
*args,
**kwargs
)
loss_measurement = -distribution_measurement.log_prob(y).sum()
# import pdb; pdb.set_trace()
return loss_measurement
""" Returns predictive distribution of ic50 """
# sample delta g if not specified
if delta_g_sample is None and x_te is not None:
delta_g_sample = self._condition_delta_g(x_te).rsample()

def marginal_sample(
self, h=None, n_samples=100, test_ligand_concentration=1e-3, **kwargs
):
distribution_base_regressor = self.base_regressor.condition(
h, **kwargs
# build the array of candidate concentrations
concentration = torch.linspace(
start=concentration_low,
end=concentration_high,
steps=number_of_concentrations)
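Given the review comment above that the prior should be uniform in log10(IC50) over a physically meaningful range, a log-spaced grid may be more appropriate than linspace; a minimal sketch (Molar units and the 0-14 range in -log10 space are assumptions taken from the review comment):

import torch

# log10-uniform grid from 1e-14 M to 1 M,
# i.e. 0-14 in -log10(concentration)
number_of_concentrations = 1024
concentration = torch.logspace(
    start=-14.0,
    end=0.0,
    steps=number_of_concentrations,
)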

distribution_measurement = torch.distributions.normal.Normal(
loc=self._get_measurement(delta_g_sample, concentration=concentration),
scale=self.log_sigma_measurement.exp(),
)
samples_measurement = []
for ns in range(n_samples):
f_sample = distribution_base_regressor.rsample()
mu_m = self.g(
func_value=f_sample,
test_ligand_concentration=test_ligand_concentration,
)
sigma_m = torch.exp(self.log_sigma_measurement)
distribution_measurement = torch.distributions.normal.Normal(
loc=mu_m, scale=sigma_m
)
samples_measurement.append(distribution_measurement.sample())
return samples_measurement

def marginal_loss(
self,
h=None,
y=None,
test_ligand_concentration=1e-3,
n_samples=10,
**kwargs

# (number_of_graphs, number_of_concentrations)
measurement_sample = distribution_measurement.rsample()
measurement_sample_sorted, _ = measurement_sample.sort(dim=1)
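Presumably the sorted curve is a step toward locating the 50%-inhibition crossing; a sketch of reading off one IC50 draw per graph (hypothetical continuation, assuming monotone curves and the log-spaced grid sketched above):

import torch

# hypothetical continuation: one IC50 draw per graph, read off where
# the (assumed monotone) titration curve crosses 50% inhibition
number_of_graphs, number_of_concentrations = 8, 1024
concentration = torch.logspace(-14.0, 0.0, steps=number_of_concentrations)

# stand-in for measurement_sample_sorted: monotone curves in [0, 1]
curve = torch.rand(number_of_graphs, number_of_concentrations).cumsum(dim=1)
curve = curve / curve[:, -1:]

half = 0.5 * torch.ones(number_of_graphs, 1)
crossing = torch.searchsorted(curve, half).clamp(max=number_of_concentrations - 1)
ic50_sample = concentration[crossing.squeeze(-1)]  # shape: (number_of_graphs,)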
def condition(
self,
*args,
output="measurement",
**kwargs,
):
"""
sample from the loss n_samples times in order to get a better approximation
"""
distribution_base_regressor = self.base_regressor.condition(
h, **kwargs
)
marginal_loss_measurement = 0
for ns in range(n_samples):
f_sample = distribution_base_regressor.rsample()
mu_m = self.g(
func_value=f_sample,
test_ligand_concentration=test_ligand_concentration,
)
sigma_m = torch.exp(self.log_sigma_measurement)
distribution_measurement = torch.distributions.normal.Normal(
loc=mu_m, scale=sigma_m
)
marginal_loss_measurement += -distribution_measurement.log_prob(y)
marginal_loss_measurement /= n_samples
return marginal_loss_measurement
""" Public method for predictive distribution construction. """
if output == "measurement":
return self._condition_measurement(*args, **kwargs)

elif output == "delta_g":
return self._condition_delta_g(*args, **kwargs)

elif output == "ic50":
return self._condition_ic50(*args, **kwargs)

else:
raise RuntimeError('output must be one of "measurement", "delta_g", or "ic50".')
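Putting the pieces together, a usage sketch of the new public condition dispatcher (a minimal sketch: the latent input h is a stand-in tensor here; in pinot it would come from a representation network):

import torch
import pinot

regressor = pinot.regressors.biophysical_regressor.BiophysicalRegressor(
    base_regressor_class=pinot.regressors.NeuralNetworkRegressor,
    in_features=32,
)

h = torch.randn(5, 32)  # stand-in latent representation for 5 graphs

delta_g_distribution = regressor.condition(h, output="delta_g")
measurement_distribution = regressor.condition(
    h, output="measurement", concentration=1e-3,
)

# one draw of predicted fraction inhibition at 1 mM
print(measurement_distribution.sample())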


3 changes: 2 additions & 1 deletion pinot/regressors/tests/test_biophysical_regressor.py
@@ -14,7 +14,8 @@ def test_init():
def net():
import pinot
regressor = pinot.regressors.biophysical_regressor.BiophysicalRegressor(
base_regressor=pinot.regressors.NeuralNetworkRegressor(32),
base_regressor_class=pinot.regressors.NeuralNetworkRegressor,
in_features=32,
)

representation = pinot.representation.Sequential(