From 171bcf45bf650e0f5a87383ad8bf9db77ccc68ab Mon Sep 17 00:00:00 2001 From: yuanqing-wang Date: Sat, 18 Jul 2020 22:58:12 -0400 Subject: [PATCH 1/5] biophysical regressor --- pinot/regressors/biophysical_regressor.py | 117 +++++++--------------- 1 file changed, 34 insertions(+), 83 deletions(-) diff --git a/pinot/regressors/biophysical_regressor.py b/pinot/regressors/biophysical_regressor.py index 850556ab..979798d7 100644 --- a/pinot/regressors/biophysical_regressor.py +++ b/pinot/regressors/biophysical_regressor.py @@ -6,11 +6,12 @@ import abc import math import numpy as np +from pinot.regressors.base_regressor import BaseRegressor # ============================================================================= # MODULE CLASSES # ============================================================================= -class BiophysicalRegressor(torch.nn.Module): +class BiophysicalRegressor(BaseRegressor): r""" Biophysically inspired model Parameters @@ -23,95 +24,45 @@ class BiophysicalRegressor(torch.nn.Module): """ - def __init__(self, base_regressor=None, *args, **kwargs): + def __init__(self, base_regressor_class=None, *args, **kwargs): super(BiophysicalRegressor, self).__init__() - self.base_regressor = base_regressor + # get the base regressor + self.base_regressor_class = base_regressor_class + self.base_regressor = base_regressor_class( + *args, **kwargs + ) + + # initialize measurement parameter self.log_sigma_measurement = torch.nn.Parameter(torch.zeros(1)) - def g(self, func_value=None, test_ligand_concentration=1e-3): - return 1 / (1 + torch.exp(-func_value) / test_ligand_concentration) + def _get_measurement(self, delta_g, concentration=1e-3): + return 1.0 / (1.0 + torch.exp(-delta_g) / concentration) + + def _condition_delta_g(self, x_te, *args, **kwargs): + return self.base_regressor.condition(x_te, *args, **kwargs) + + def _condition_measurement(self, concentration=1e-3, _delta_g_sample=None): + if _delta_g_sample is None: + _delta_g_sample = self._condition_delta_g() - def condition( - self, h=None, test_ligand_concentration=1e-3, *args, **kwargs - ): - distribution_base_regressor = self.base_regressor.condition( - h, *args, **kwargs - ) - # we sample from the latent f to push things through the likelihood - # Note: if we do this, - # in order to get good estimates of LLK - # we may need to draw multiple samples - f_sample = distribution_base_regressor.rsample() - mu_m = self.g( - func_value=f_sample, - test_ligand_concentration=test_ligand_concentration, - ) - sigma_m = torch.exp(self.log_sigma_measurement) distribution_measurement = torch.distributions.normal.Normal( - loc=mu_m, scale=sigma_m + loc=self._get_measurement(delta_g, concentration=concentration), + scale=self.log_sigma_measurement.exp() ) - # import pdb; pdb.set_trace() - return distribution_measurement - def loss( - self, h=None, y=None, test_ligand_concentration=None, *args, **kwargs + def condition( + self, + *args, + output="measurement", + **kwargs. ): - # import pdb; pdb.set_trace() - distribution_measurement = self.condition( - h=h, - test_ligand_concentration=test_ligand_concentration, - *args, - **kwargs - ) - loss_measurement = -distribution_measurement.log_prob(y).sum() - # import pdb; pdb.set_trace() - return loss_measurement + if output == "measurement": + return self._condition_measurement(self, *args, **kwargs) - def marginal_sample( - self, h=None, n_samples=100, test_ligand_concentration=1e-3, **kwargs - ): - distribution_base_regressor = self.base_regressor.condition( - h, **kwargs - ) - samples_measurement = [] - for ns in range(n_samples): - f_sample = distribution_base_regressor.rsample() - mu_m = self.g( - func_value=f_sample, - test_ligand_concentration=test_ligand_concentration, - ) - sigma_m = torch.exp(self.log_sigma_measurement) - distribution_measurement = torch.distributions.normal.Normal( - loc=mu_m, scale=sigma_m - ) - samples_measurement.append(distribution_measurement.sample()) - return samples_measurement + elif output == "delta_g": + return self._condition_delta_g(self, *args, **kwargs) - def marginal_loss( - self, - h=None, - y=None, - test_ligand_concentration=1e-3, - n_samples=10, - **kwargs - ): - """ - sample n_samples often from loss in order to get a better approximation - """ - distribution_base_regressor = self.base_regressor.condition( - h, **kwargs - ) - marginal_loss_measurement = 0 - for ns in range(n_samples): - f_sample = distribution_base_regressor.rsample() - mu_m = self.g( - func_value=f_sample, - test_ligand_concentration=test_ligand_concentration, - ) - sigma_m = torch.exp(self.log_sigma_measurement) - distribution_measurement = torch.distributions.normal.Normal( - loc=mu_m, scale=sigma_m - ) - marginal_loss_measurement += -distribution_measurement.log_prob(y) - marginal_loss_measurement /= n_samples - return marginal_loss_measurement + else: + raise RuntimeError('We only support condition measurement and delta g") + + From e4ff3a8aaff42c579e0637b3694985a780b19b3a Mon Sep 17 00:00:00 2001 From: yuanqing-wang Date: Sun, 19 Jul 2020 00:15:32 -0400 Subject: [PATCH 2/5] bug fix --- pinot/regressors/biophysical_regressor.py | 16 ++++++++-------- .../tests/test_biophysical_regressor.py | 3 ++- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/pinot/regressors/biophysical_regressor.py b/pinot/regressors/biophysical_regressor.py index 979798d7..6bd850e3 100644 --- a/pinot/regressors/biophysical_regressor.py +++ b/pinot/regressors/biophysical_regressor.py @@ -41,12 +41,12 @@ def _get_measurement(self, delta_g, concentration=1e-3): def _condition_delta_g(self, x_te, *args, **kwargs): return self.base_regressor.condition(x_te, *args, **kwargs) - def _condition_measurement(self, concentration=1e-3, _delta_g_sample=None): - if _delta_g_sample is None: - _delta_g_sample = self._condition_delta_g() + def _condition_measurement(self, x_te, concentration=1e-3, delta_g_sample=None): + if delta_g_sample is None: + delta_g_sample = self._condition_delta_g(x_te).rsample() distribution_measurement = torch.distributions.normal.Normal( - loc=self._get_measurement(delta_g, concentration=concentration), + loc=self._get_measurement(delta_g_sample, concentration=concentration), scale=self.log_sigma_measurement.exp() ) @@ -54,15 +54,15 @@ def condition( self, *args, output="measurement", - **kwargs. + **kwargs, ): if output == "measurement": - return self._condition_measurement(self, *args, **kwargs) + return self._condition_measurement(*args, **kwargs) elif output == "delta_g": - return self._condition_delta_g(self, *args, **kwargs) + return self._condition_delta_g(*args, **kwargs) else: - raise RuntimeError('We only support condition measurement and delta g") + raise RuntimeError('We only support condition measurement and delta g') diff --git a/pinot/regressors/tests/test_biophysical_regressor.py b/pinot/regressors/tests/test_biophysical_regressor.py index d9b5ee60..90b2973b 100644 --- a/pinot/regressors/tests/test_biophysical_regressor.py +++ b/pinot/regressors/tests/test_biophysical_regressor.py @@ -14,7 +14,8 @@ def test_init(): def net(): import pinot regressor = pinot.regressors.biophysical_regressor.BiophysicalRegressor( - base_regressor=pinot.regressors.NeuralNetworkRegressor(32), + base_regressor_class=pinot.regressors.NeuralNetworkRegressor, + in_features=32, ) representation = pinot.representation.Sequential( From f0cfeecbefcb6602bf353b77192dec8617df4fdf Mon Sep 17 00:00:00 2001 From: yuanqing-wang Date: Sun, 19 Jul 2020 00:36:01 -0400 Subject: [PATCH 3/5] ic50 integration --- pinot/regressors/biophysical_regressor.py | 61 ++++++++++++++++++++++- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/pinot/regressors/biophysical_regressor.py b/pinot/regressors/biophysical_regressor.py index 6bd850e3..617a09b0 100644 --- a/pinot/regressors/biophysical_regressor.py +++ b/pinot/regressors/biophysical_regressor.py @@ -36,13 +36,33 @@ def __init__(self, base_regressor_class=None, *args, **kwargs): self.log_sigma_measurement = torch.nn.Parameter(torch.zeros(1)) def _get_measurement(self, delta_g, concentration=1e-3): + """ Translate ..math:: \Delta G to percentage inhibition. + + Parameters + ---------- + delta_g : torch.Tensor, shape=(number_of_graphs, 1) + Binding free energy. + + concentration : torch.Tensor, shape=(,) or (1, number_of_concentrations) + Concentration of ligand. + + Returns + ------- + measurement : torch.Tensor, shape=(number_of_graphs, 1) + or (number_of_graphs, number_of_concentrations) + Percentage of inhibition. + + """ return 1.0 / (1.0 + torch.exp(-delta_g) / concentration) def _condition_delta_g(self, x_te, *args, **kwargs): + """ Returns predictive distribution of binding free energy. """ return self.base_regressor.condition(x_te, *args, **kwargs) - def _condition_measurement(self, x_te, concentration=1e-3, delta_g_sample=None): - if delta_g_sample is None: + def _condition_measurement(self, x_te=None, concentration=1e-3, delta_g_sample=None): + """ Returns predictive distribution of percentage of inhibtiion. """ + # sample delta g if not specified + if delta_g_sample is None and x_te is not None: delta_g_sample = self._condition_delta_g(x_te).rsample() distribution_measurement = torch.distributions.normal.Normal( @@ -50,18 +70,55 @@ def _condition_measurement(self, x_te, concentration=1e-3, delta_g_sample=None): scale=self.log_sigma_measurement.exp() ) + def _condition_ic_50( + self, + x_te=None, + delta_g_sample=None, + concentration_low=0.0, + concentration_high=1.0, + number_of_concentrations=1024, + ): + """ Returns predictive distribution of ic50 """ + # sample delta g if not specified + if delta_g_sample is None and x_te is not None: + delta_g_sample = self._condition_delta_g(x_te).rsample() + + # get the possible array of concentration + concentration = torch.linspace( + start=concentration_low, + end=concentration_high, + steps=number_of_concentrations) + + distribution_measurement = torch.distributions.normal.Normal( + loc=self._get_measurement(delta_g_sample, concentration=concentration), + scale=self.log_sigma_measurement.exp(), + ) + + # (number_of_concentrations, 1) + measurement_sample = distribution_measurement.rsample() + + + + + + + def condition( self, *args, output="measurement", **kwargs, ): + """ Public method for predictive distribution construction. """ if output == "measurement": return self._condition_measurement(*args, **kwargs) elif output == "delta_g": return self._condition_delta_g(*args, **kwargs) + elif output == "ic50": + return self._condition_ic50(*args, **kwargs) + else: raise RuntimeError('We only support condition measurement and delta g') From 51878e1a2228b0c406addc1ac2f673b8f34aad9d Mon Sep 17 00:00:00 2001 From: yuanqing-wang Date: Sun, 19 Jul 2020 00:38:59 -0400 Subject: [PATCH 4/5] ic50 integration --- pinot/regressors/biophysical_regressor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pinot/regressors/biophysical_regressor.py b/pinot/regressors/biophysical_regressor.py index 617a09b0..0bfb714a 100644 --- a/pinot/regressors/biophysical_regressor.py +++ b/pinot/regressors/biophysical_regressor.py @@ -94,8 +94,9 @@ def _condition_ic_50( scale=self.log_sigma_measurement.exp(), ) - # (number_of_concentrations, 1) + # (number_of_graphs, number_of_concentrations) measurement_sample = distribution_measurement.rsample() + measurement_sample_sorted = measurement_sample.sort(dimension=1) From fffca5bae2579a12df42ef108353a7d936f57fdd Mon Sep 17 00:00:00 2001 From: yuanqing-wang Date: Mon, 20 Jul 2020 20:12:52 -0400 Subject: [PATCH 5/5] sampler --- pinot/net.py | 41 ++--------------------------------------- 1 file changed, 2 insertions(+), 39 deletions(-) diff --git a/pinot/net.py b/pinot/net.py index 3c25bf08..89b714d4 100644 --- a/pinot/net.py +++ b/pinot/net.py @@ -237,49 +237,12 @@ def condition(self, g, *args, **kwargs): if 'sampler' in kwargs: sampler = kwargs.pop('sampler') - if 'n_samples' in kwargs: - n_samples = kwargs.pop('n_samples') - if self.has_exact_gp is True: h_last = self.representation(self.g_last) kwargs = {**{"x_tr": h_last, "y_tr": self.y_last}, **kwargs} - if sampler is None: - return self._condition(h, *args, **kwargs) - - if not hasattr(sampler, "sample_params"): - return self._condition(h, *args, **kwargs) - - # initialize a list of distributions - distributions = [] - - for _ in range(n_samples): + if sampler is not None and hasattr(sampler, "sample_params"): sampler.sample_params() - distributions.append(self._condition(g, *args, **kwargs)) - - # get the parameter of these distributions - # NOTE: this is not necessarily the most efficienct solution - # since we don't know the memory footprint of - # torch.distributions - mus, sigmas = zip( - *[ - (distribution.loc, distribution.scale) - for distribution in distributions - ] - ) - # concat parameters together - # (n_samples, batch_size, measurement_dimension) - mu = torch.stack(mus).cpu() # distribution no cuda - sigma = torch.stack(sigmas).cpu() + return self._condition(h, *args, **kwargs) - # construct the distribution - distribution = torch.distributions.normal.Normal(loc=mu, scale=sigma) - - # make it mixture - distribution = torch.distributions.mixture_same_family.MixtureSameFamily( - torch.distributions.Categorical(torch.ones(mu.shape[0],)), - torch.distributions.Independent(distribution, 2), - ) - - return distribution