-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] biophysical_with_ic50 #113
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,11 +6,12 @@ | |
import abc | ||
import math | ||
import numpy as np | ||
from pinot.regressors.base_regressor import BaseRegressor | ||
|
||
# ============================================================================= | ||
# MODULE CLASSES | ||
# ============================================================================= | ||
class BiophysicalRegressor(torch.nn.Module): | ||
class BiophysicalRegressor(BaseRegressor): | ||
r""" Biophysically inspired model | ||
|
||
Parameters | ||
|
@@ -23,95 +24,103 @@ class BiophysicalRegressor(torch.nn.Module): | |
|
||
""" | ||
|
||
def __init__(self, base_regressor=None, *args, **kwargs): | ||
def __init__(self, base_regressor_class=None, *args, **kwargs): | ||
super(BiophysicalRegressor, self).__init__() | ||
self.base_regressor = base_regressor | ||
# get the base regressor | ||
self.base_regressor_class = base_regressor_class | ||
self.base_regressor = base_regressor_class( | ||
*args, **kwargs | ||
) | ||
|
||
# initialize measurement parameter | ||
self.log_sigma_measurement = torch.nn.Parameter(torch.zeros(1)) | ||
|
||
def g(self, func_value=None, test_ligand_concentration=1e-3): | ||
return 1 / (1 + torch.exp(-func_value) / test_ligand_concentration) | ||
def _get_measurement(self, delta_g, concentration=1e-3): | ||
""" Translate ..math:: \Delta G to percentage inhibition. | ||
|
||
Parameters | ||
---------- | ||
delta_g : torch.Tensor, shape=(number_of_graphs, 1) | ||
Binding free energy. | ||
|
||
concentration : torch.Tensor, shape=(,) or (1, number_of_concentrations) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Document units! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. again, I'm to blame. |
||
Concentration of ligand. | ||
|
||
Returns | ||
------- | ||
measurement : torch.Tensor, shape=(number_of_graphs, 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think there's a typo in this equation. If you want percent inhibition, you'll want something like 100 * (1 - 1.0 / (1.0 + concentration / torch.exp(delta_g)) where This means that when There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So what I did was copy the function g from Josh's notebook ( https://gist.github.com/maxentile/fdc42d76c1125aff33f912fbc28edb1f ) Let me check out how your comments map to that code so I understand your comment better. The original code in Josh's notebook is the following, with the math as explained beneat:
@jchodera Can you explain the deviation from what you wrote above so I can parse my mistake? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, while we're at it, I only ever got Josh's notebook as a reference for this function. Where can I look up the published version of this so I can check the math independently? I should have asked for this earlier to debug my own code, I rushed this too much. |
||
or (number_of_graphs, number_of_concentrations) | ||
Percentage of inhibition. | ||
|
||
""" | ||
return 1.0 / (1.0 + torch.exp(-delta_g) / concentration) | ||
|
||
def _condition_delta_g(self, x_te, *args, **kwargs): | ||
""" Returns predictive distribution of binding free energy. """ | ||
return self.base_regressor.condition(x_te, *args, **kwargs) | ||
|
||
def _condition_measurement(self, x_te=None, concentration=1e-3, delta_g_sample=None): | ||
""" Returns predictive distribution of percentage of inhibtiion. """ | ||
# sample delta g if not specified | ||
if delta_g_sample is None and x_te is not None: | ||
delta_g_sample = self._condition_delta_g(x_te).rsample() | ||
|
||
def condition( | ||
self, h=None, test_ligand_concentration=1e-3, *args, **kwargs | ||
): | ||
distribution_base_regressor = self.base_regressor.condition( | ||
h, *args, **kwargs | ||
) | ||
# we sample from the latent f to push things through the likelihood | ||
# Note: if we do this, | ||
# in order to get good estimates of LLK | ||
# we may need to draw multiple samples | ||
f_sample = distribution_base_regressor.rsample() | ||
mu_m = self.g( | ||
func_value=f_sample, | ||
test_ligand_concentration=test_ligand_concentration, | ||
) | ||
sigma_m = torch.exp(self.log_sigma_measurement) | ||
distribution_measurement = torch.distributions.normal.Normal( | ||
loc=mu_m, scale=sigma_m | ||
loc=self._get_measurement(delta_g_sample, concentration=concentration), | ||
scale=self.log_sigma_measurement.exp() | ||
) | ||
# import pdb; pdb.set_trace() | ||
return distribution_measurement | ||
|
||
def loss( | ||
self, h=None, y=None, test_ligand_concentration=None, *args, **kwargs | ||
def _condition_ic_50( | ||
self, | ||
x_te=None, | ||
delta_g_sample=None, | ||
concentration_low=0.0, | ||
concentration_high=1.0, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do these parameters mean? Can you document them in a docstring? We definitely do not want a uniform distribution over IC50 as a prior. It should be uniform over So maybe 0-14 in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks John, I think this was sth @yuanqing-wang just wanted to ask me stuff about that he hacked in just today. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry @jchodera this is just a very immature draft I had in order to discuss some ideas. For the IC50 I was only thinking about how can we integrate that into the graphical model without using the Cheng-Prussoff equation. So I'm thinking about doing the titration curve here from |
||
number_of_concentrations=1024, | ||
): | ||
# import pdb; pdb.set_trace() | ||
distribution_measurement = self.condition( | ||
h=h, | ||
test_ligand_concentration=test_ligand_concentration, | ||
*args, | ||
**kwargs | ||
) | ||
loss_measurement = -distribution_measurement.log_prob(y).sum() | ||
# import pdb; pdb.set_trace() | ||
return loss_measurement | ||
""" Returns predictive distribution of ic50 """ | ||
# sample delta g if not specified | ||
if delta_g_sample is None and x_te is not None: | ||
delta_g_sample = self._condition_delta_g(x_te).rsample() | ||
|
||
def marginal_sample( | ||
self, h=None, n_samples=100, test_ligand_concentration=1e-3, **kwargs | ||
): | ||
distribution_base_regressor = self.base_regressor.condition( | ||
h, **kwargs | ||
# get the possible array of concentration | ||
concentration = torch.linspace( | ||
start=concentration_low, | ||
end=concentration_high, | ||
steps=number_of_concentrations) | ||
|
||
distribution_measurement = torch.distributions.normal.Normal( | ||
loc=self._get_measurement(delta_g_sample, concentration=concentration), | ||
scale=self.log_sigma_measurement.exp(), | ||
) | ||
samples_measurement = [] | ||
for ns in range(n_samples): | ||
f_sample = distribution_base_regressor.rsample() | ||
mu_m = self.g( | ||
func_value=f_sample, | ||
test_ligand_concentration=test_ligand_concentration, | ||
) | ||
sigma_m = torch.exp(self.log_sigma_measurement) | ||
distribution_measurement = torch.distributions.normal.Normal( | ||
loc=mu_m, scale=sigma_m | ||
) | ||
samples_measurement.append(distribution_measurement.sample()) | ||
return samples_measurement | ||
|
||
def marginal_loss( | ||
self, | ||
h=None, | ||
y=None, | ||
test_ligand_concentration=1e-3, | ||
n_samples=10, | ||
**kwargs | ||
|
||
# (number_of_graphs, number_of_concentrations) | ||
measurement_sample = distribution_measurement.rsample() | ||
measurement_sample_sorted = measurement_sample.sort(dimension=1) | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
def condition( | ||
self, | ||
*args, | ||
output="measurement", | ||
**kwargs, | ||
): | ||
""" | ||
sample n_samples often from loss in order to get a better approximation | ||
""" | ||
distribution_base_regressor = self.base_regressor.condition( | ||
h, **kwargs | ||
) | ||
marginal_loss_measurement = 0 | ||
for ns in range(n_samples): | ||
f_sample = distribution_base_regressor.rsample() | ||
mu_m = self.g( | ||
func_value=f_sample, | ||
test_ligand_concentration=test_ligand_concentration, | ||
) | ||
sigma_m = torch.exp(self.log_sigma_measurement) | ||
distribution_measurement = torch.distributions.normal.Normal( | ||
loc=mu_m, scale=sigma_m | ||
) | ||
marginal_loss_measurement += -distribution_measurement.log_prob(y) | ||
marginal_loss_measurement /= n_samples | ||
return marginal_loss_measurement | ||
""" Public method for predictive distribution construction. """ | ||
if output == "measurement": | ||
return self._condition_measurement(*args, **kwargs) | ||
|
||
elif output == "delta_g": | ||
return self._condition_delta_g(*args, **kwargs) | ||
|
||
elif output == "ic50": | ||
return self._condition_ic50(*args, **kwargs) | ||
|
||
else: | ||
raise RuntimeError('We only support condition measurement and delta g') | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@yuanqing-wang : What units are you using here? It would be good to document them in the docstring!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that's my code, I will add them