diff --git a/pinot/net.py b/pinot/net.py index 3c25bf08..89b714d4 100644 --- a/pinot/net.py +++ b/pinot/net.py @@ -237,49 +237,12 @@ def condition(self, g, *args, **kwargs): if 'sampler' in kwargs: sampler = kwargs.pop('sampler') - if 'n_samples' in kwargs: - n_samples = kwargs.pop('n_samples') - if self.has_exact_gp is True: h_last = self.representation(self.g_last) kwargs = {**{"x_tr": h_last, "y_tr": self.y_last}, **kwargs} - if sampler is None: - return self._condition(h, *args, **kwargs) - - if not hasattr(sampler, "sample_params"): - return self._condition(h, *args, **kwargs) - - # initialize a list of distributions - distributions = [] - - for _ in range(n_samples): + if sampler is not None and hasattr(sampler, "sample_params"): sampler.sample_params() - distributions.append(self._condition(g, *args, **kwargs)) - - # get the parameter of these distributions - # NOTE: this is not necessarily the most efficienct solution - # since we don't know the memory footprint of - # torch.distributions - mus, sigmas = zip( - *[ - (distribution.loc, distribution.scale) - for distribution in distributions - ] - ) - # concat parameters together - # (n_samples, batch_size, measurement_dimension) - mu = torch.stack(mus).cpu() # distribution no cuda - sigma = torch.stack(sigmas).cpu() + return self._condition(h, *args, **kwargs) - # construct the distribution - distribution = torch.distributions.normal.Normal(loc=mu, scale=sigma) - - # make it mixture - distribution = torch.distributions.mixture_same_family.MixtureSameFamily( - torch.distributions.Categorical(torch.ones(mu.shape[0],)), - torch.distributions.Independent(distribution, 2), - ) - - return distribution diff --git a/pinot/regressors/biophysical_regressor.py b/pinot/regressors/biophysical_regressor.py index 850556ab..0bfb714a 100644 --- a/pinot/regressors/biophysical_regressor.py +++ b/pinot/regressors/biophysical_regressor.py @@ -6,11 +6,12 @@ import abc import math import numpy as np +from 
pinot.regressors.base_regressor import BaseRegressor # ============================================================================= # MODULE CLASSES # ============================================================================= -class BiophysicalRegressor(torch.nn.Module): +class BiophysicalRegressor(BaseRegressor): r""" Biophysically inspired model Parameters @@ -23,95 +24,103 @@ class BiophysicalRegressor(torch.nn.Module): """ - def __init__(self, base_regressor=None, *args, **kwargs): + def __init__(self, base_regressor_class=None, *args, **kwargs): super(BiophysicalRegressor, self).__init__() - self.base_regressor = base_regressor + # get the base regressor + self.base_regressor_class = base_regressor_class + self.base_regressor = base_regressor_class( + *args, **kwargs + ) + + # initialize measurement parameter self.log_sigma_measurement = torch.nn.Parameter(torch.zeros(1)) - def g(self, func_value=None, test_ligand_concentration=1e-3): - return 1 / (1 + torch.exp(-func_value) / test_ligand_concentration) + def _get_measurement(self, delta_g, concentration=1e-3): + """ Translate :math:`\Delta G` to percentage inhibition. + + Parameters + ---------- + delta_g : torch.Tensor, shape=(number_of_graphs, 1) + Binding free energy. + + concentration : torch.Tensor, shape=(,) or (1, number_of_concentrations) + Concentration of ligand. + + Returns + ------- + measurement : torch.Tensor, shape=(number_of_graphs, 1) + or (number_of_graphs, number_of_concentrations) + Percentage of inhibition. + + """ + return 1.0 / (1.0 + torch.exp(-delta_g) / concentration) + + def _condition_delta_g(self, x_te, *args, **kwargs): + """ Returns predictive distribution of binding free energy. """ + return self.base_regressor.condition(x_te, *args, **kwargs) + + def _condition_measurement(self, x_te=None, concentration=1e-3, delta_g_sample=None): + """ Returns predictive distribution of percentage of inhibition. 
""" + # sample delta g if not specified + if delta_g_sample is None and x_te is not None: + delta_g_sample = self._condition_delta_g(x_te).rsample() - def condition( - self, h=None, test_ligand_concentration=1e-3, *args, **kwargs - ): - distribution_base_regressor = self.base_regressor.condition( - h, *args, **kwargs - ) - # we sample from the latent f to push things through the likelihood - # Note: if we do this, - # in order to get good estimates of LLK - # we may need to draw multiple samples - f_sample = distribution_base_regressor.rsample() - mu_m = self.g( - func_value=f_sample, - test_ligand_concentration=test_ligand_concentration, - ) - sigma_m = torch.exp(self.log_sigma_measurement) distribution_measurement = torch.distributions.normal.Normal( - loc=mu_m, scale=sigma_m + loc=self._get_measurement(delta_g_sample, concentration=concentration), + scale=self.log_sigma_measurement.exp() ) - # import pdb; pdb.set_trace() - return distribution_measurement - def loss( - self, h=None, y=None, test_ligand_concentration=None, *args, **kwargs + def _condition_ic_50( + self, + x_te=None, + delta_g_sample=None, + concentration_low=0.0, + concentration_high=1.0, + number_of_concentrations=1024, ): - # import pdb; pdb.set_trace() - distribution_measurement = self.condition( - h=h, - test_ligand_concentration=test_ligand_concentration, - *args, - **kwargs - ) - loss_measurement = -distribution_measurement.log_prob(y).sum() - # import pdb; pdb.set_trace() - return loss_measurement + """ Returns predictive distribution of ic50 """ + # sample delta g if not specified + if delta_g_sample is None and x_te is not None: + delta_g_sample = self._condition_delta_g(x_te).rsample() - def marginal_sample( - self, h=None, n_samples=100, test_ligand_concentration=1e-3, **kwargs - ): - distribution_base_regressor = self.base_regressor.condition( - h, **kwargs + # get the possible array of concentration + concentration = torch.linspace( + start=concentration_low, + end=concentration_high, 
+ steps=number_of_concentrations) + + distribution_measurement = torch.distributions.normal.Normal( + loc=self._get_measurement(delta_g_sample, concentration=concentration), + scale=self.log_sigma_measurement.exp(), ) - samples_measurement = [] - for ns in range(n_samples): - f_sample = distribution_base_regressor.rsample() - mu_m = self.g( - func_value=f_sample, - test_ligand_concentration=test_ligand_concentration, - ) - sigma_m = torch.exp(self.log_sigma_measurement) - distribution_measurement = torch.distributions.normal.Normal( - loc=mu_m, scale=sigma_m - ) - samples_measurement.append(distribution_measurement.sample()) - return samples_measurement - - def marginal_loss( - self, - h=None, - y=None, - test_ligand_concentration=1e-3, - n_samples=10, - **kwargs + + # (number_of_graphs, number_of_concentrations) + measurement_sample = distribution_measurement.rsample() + measurement_sample_sorted = measurement_sample.sort(dimension=1) + + + + + + + + def condition( + self, + *args, + output="measurement", + **kwargs, ): - """ - sample n_samples often from loss in order to get a better approximation - """ - distribution_base_regressor = self.base_regressor.condition( - h, **kwargs - ) - marginal_loss_measurement = 0 - for ns in range(n_samples): - f_sample = distribution_base_regressor.rsample() - mu_m = self.g( - func_value=f_sample, - test_ligand_concentration=test_ligand_concentration, - ) - sigma_m = torch.exp(self.log_sigma_measurement) - distribution_measurement = torch.distributions.normal.Normal( - loc=mu_m, scale=sigma_m - ) - marginal_loss_measurement += -distribution_measurement.log_prob(y) - marginal_loss_measurement /= n_samples - return marginal_loss_measurement + """ Public method for predictive distribution construction. 
""" + if output == "measurement": + return self._condition_measurement(*args, **kwargs) + + elif output == "delta_g": + return self._condition_delta_g(*args, **kwargs) + + elif output == "ic50": + return self._condition_ic50(*args, **kwargs) + + else: + raise RuntimeError('We only support condition measurement and delta g') + + diff --git a/pinot/regressors/tests/test_biophysical_regressor.py b/pinot/regressors/tests/test_biophysical_regressor.py index d9b5ee60..90b2973b 100644 --- a/pinot/regressors/tests/test_biophysical_regressor.py +++ b/pinot/regressors/tests/test_biophysical_regressor.py @@ -14,7 +14,8 @@ def test_init(): def net(): import pinot regressor = pinot.regressors.biophysical_regressor.BiophysicalRegressor( - base_regressor=pinot.regressors.NeuralNetworkRegressor(32), + base_regressor_class=pinot.regressors.NeuralNetworkRegressor, + in_features=32, ) representation = pinot.representation.Sequential(