From 70c7dd7a251c1a58921ebbacd5e7ced423766189 Mon Sep 17 00:00:00 2001
From: mschrader15
Date: Mon, 25 Sep 2023 11:10:33 -0500
Subject: [PATCH] feat: halfnormal and lognormal

Signed-off-by: mschrader15
---
 pomegranate/distributions/__init__.py       |   2 +
 pomegranate/distributions/halfnormal.py     | 121 +++------------
 pomegranate/distributions/lognormal.py      | 159 ++------------------
 pomegranate/distributions/lognormal_old.pyx | 112 --------------
 4 files changed, 29 insertions(+), 365 deletions(-)
 delete mode 100644 pomegranate/distributions/lognormal_old.pyx

diff --git a/pomegranate/distributions/__init__.py b/pomegranate/distributions/__init__.py
index c73fe65d..0c57f351 100644
--- a/pomegranate/distributions/__init__.py
+++ b/pomegranate/distributions/__init__.py
@@ -11,3 +11,5 @@ from .student_t import StudentT
 from .uniform import Uniform
 from .zero_inflated import ZeroInflated
+from .lognormal import LogNormal
+from .halfnormal import HalfNormal
\ No newline at end of file
diff --git a/pomegranate/distributions/halfnormal.py b/pomegranate/distributions/halfnormal.py
index b3fa0f9d..93322968 100644
--- a/pomegranate/distributions/halfnormal.py
+++ b/pomegranate/distributions/halfnormal.py
@@ -10,6 +10,7 @@
 from .._utils import _check_shapes
 
 from ._distribution import Distribution
+from .normal import Normal
 
 
 # Define some useful constants
@@ -17,9 +18,10 @@
 INF = float("inf")
 SQRT_2_PI = 2.50662827463
 LOG_2_PI = 1.83787706641
+LOG_2 = 0.6931471805599453
 
 
-class HalfNormal(Distribution):
+class HalfNormal(Normal):
	"""A half-normal distribution object.
 
	A half-normal distribution is a distribution over positive real numbers that
@@ -78,20 +80,12 @@ def __init__(
		frozen=False,
		check_data=True,
	):
-		super().__init__(inertia=inertia, frozen=frozen, check_data=check_data)
-		self.name = "HalfNormal"
-
-		self.covs = _check_parameter(_cast_as_parameter(covs), "covs", ndim=(1, 2))
-
-		_check_shapes([self.means, self.covs], ["means", "covs"])
-
-		self.min_cov = _check_parameter(min_cov, "min_cov", min_value=0, ndim=0)
-		self.covariance_type = covariance_type
-
-		self._initialized = covs is not None
-		self.d = self.means.shape[-1] if self._initialized else None
-		self._reset_cache()
+		super().__init__(means=None, covs=covs, min_cov=min_cov,
+			covariance_type=covariance_type, inertia=inertia, frozen=frozen,
+			check_data=check_data)
+		# set the name after Normal.__init__, which assigns "Normal" internally
+		self.name = "HalfNormal"
 
	def _initialize(self, d):
		"""Initialize the probability distribution.
 
@@ -105,20 +99,6 @@ def _initialize(self, d):
		d: int
			The dimensionality the distribution is being initialized to.
		"""
-		if self.covariance_type == "full":
-			self.covs = _cast_as_parameter(
-				torch.zeros(d, d, dtype=self.dtype, device=self.device)
-			)
-		elif self.covariance_type == "diag":
-			self.covs = _cast_as_parameter(
-				torch.zeros(d, dtype=self.dtype, device=self.device)
-			)
-		elif self.covariance_type == "sphere":
-			self.covs = _cast_as_parameter(
-				torch.tensor(0, dtype=self.dtype, device=self.device)
-			)
-
-		self._initialized = True
		super()._initialize(d)
 
	def _reset_cache(self):
		"""Reset the internally stored statistics.
 
		This method is meant to only be called internally. It resets the
		stored statistics used to update the model parameters as well as
		recalculates the cached values meant to speed up log probability
		calculations.
""" - - if self._initialized == False: - return - - self.register_buffer( - "_w_sum", torch.zeros(self.d, dtype=self.dtype, device=self.device) - ) - self.register_buffer( - "_xw_sum", torch.zeros(self.d, dtype=self.dtype, device=self.device) - ) - - if self.covariance_type == "full": - self.register_buffer( - "_xxw_sum", - torch.zeros(self.d, self.d, dtype=self.dtype, device=self.device), - ) - - if self.covs.sum() > 0.0: - chol = torch.linalg.cholesky(self.covs) - _inv_cov = torch.linalg.solve_triangular( - chol, - torch.eye(len(self.covs), dtype=self.dtype, device=self.device), - upper=False, - ).T - _log_det = -0.5 * torch.linalg.slogdet(self.covs)[1] - _theta = _log_det - 0.5 * (self.d * LOG_2_PI) - - self.register_buffer("_inv_cov", _inv_cov) - self.register_buffer("_log_det", _log_det) - self.register_buffer("_theta", _theta) - - elif self.covariance_type in ("diag", "sphere"): - self.register_buffer( - "_xxw_sum", torch.zeros(self.d, dtype=self.dtype, device=self.device) - ) - - if self.covs.sum() > 0.0: - _log_sigma_sqrt_2pi = -torch.log(torch.sqrt(self.covs) * SQRT_2_PI) - _inv_two_sigma = 1.0 / (2 * self.covs) - - self.register_buffer("_log_sigma_sqrt_2pi", _log_sigma_sqrt_2pi) - self.register_buffer("_inv_two_sigma", _inv_two_sigma) - - if torch.any(self.covs < 0): - raise ValueError("Variances must be positive.") + super()._reset_cache() def sample(self, n): """Sample from the probability distribution. @@ -193,10 +129,7 @@ def sample(self, n): X: torch.tensor, shape=(n, self.d) Randomly generated samples. """ - - if self.covariance_type == "diag": - return torch.distributions.HalfNormal(self.covs).sample([n]) - elif self.covariance_type == "full": + if self.covariance_type in ["diag", "full"]: return torch.distributions.HalfNormal(self.covs).sample([n]) def log_probability(self, X): @@ -225,23 +158,15 @@ def log_probability(self, X): """ X = _check_parameter( - _cast_as_tensor(X, dtype=self.means.dtype), + _cast_as_tensor(X, dtype=self.covs.dtype), "X", ndim=2, shape=(-1, self.d), check_parameter=self.check_data, ) + return super().log_probability(X) + LOG_2 + - # if self.covariance_type == 'full': - # # logp = torch.matmul(X, self._inv_cov) - self._inv_cov_dot_mu - # # logp = self.d * LOG_2_PI + torch.sum(logp ** 2, dim=-1) - # # logp = self._log_det - 0.5 * logp - # # return logp - return 0.5 * LOG_2_PI + (X**2 / 2).sum(dim=-1) - - # elif self.covariance_type in ('diag', 'sphere'): - # return torch.sum(self._log_sigma_sqrt_2pi - ((X - self.means) ** 2) - # * self._inv_two_sigma, dim=-1) def summarize(self, X, sample_weight=None): """Extract the sufficient statistics from a batch of data. @@ -263,21 +188,7 @@ def summarize(self, X, sample_weight=None): (-1, self.d) or a vector of shape (-1,). Default is ones. """ - if self.frozen == True: - return - - X, sample_weight = super().summarize(X, sample_weight=sample_weight) - X = _cast_as_tensor(X, dtype=self.means.dtype) - - if self.covariance_type == "full": - self._w_sum += torch.sum(sample_weight, dim=0) - self._xw_sum += torch.sum(X * sample_weight, axis=0) - self._xxw_sum += torch.matmul((X * sample_weight).T, X) - - elif self.covariance_type in ("diag", "sphere"): - self._w_sum[:] = self._w_sum + torch.sum(sample_weight, dim=0) - self._xw_sum[:] = self._xw_sum + torch.sum(X * sample_weight, dim=0) - self._xxw_sum[:] = self._xxw_sum + torch.sum(X**2 * sample_weight, dim=0) + super().summarize(X, sample_weight=sample_weight) def from_summaries(self): """Update the model parameters given the extracted statistics. 
diff --git a/pomegranate/distributions/lognormal.py b/pomegranate/distributions/lognormal.py
index 307f5c76..92673507 100644
--- a/pomegranate/distributions/lognormal.py
+++ b/pomegranate/distributions/lognormal.py
@@ -9,7 +9,7 @@
 from .._utils import _check_parameter
 from .._utils import _check_shapes
 
-from ._distribution import Distribution
+from .normal import Normal
 
 
 # Define some useful constants
@@ -19,7 +19,7 @@
 LOG_2_PI = 1.83787706641
 
 
-class LogNormal(Distribution):
+class LogNormal(Normal):
	"""A lognormal object.
 
	The parameters are the mu and sigma of the normal distribution, which
@@ -71,102 +71,10 @@ class LogNormal(Distribution):
 
	def __init__(self, means=None, covs=None, covariance_type='full',
			min_cov=None, inertia=0.0, frozen=False, check_data=True):
-		super().__init__(inertia=inertia, frozen=frozen, check_data=check_data)
-		self.name = "LogNormal"
-
-		self.means = _check_parameter(_cast_as_parameter(means), "means",
-			ndim=1)
-		self.covs = _check_parameter(_cast_as_parameter(covs), "covs",
-			ndim=(1, 2))
-
-		_check_shapes([self.means, self.covs], ["means", "covs"])
-
-		self.min_cov = _check_parameter(min_cov, "min_cov", min_value=0, ndim=0)
-		self.covariance_type = covariance_type
-
-		self._initialized = (means is not None) and (covs is not None)
-		self.d = self.means.shape[-1] if self._initialized else None
-		self._reset_cache()
-
-	def _initialize(self, d):
-		"""Initialize the probability distribution.
-
-		This method is meant to only be called internally. It initializes the
-		parameters of the distribution and stores its dimensionality. For more
-		complex methods, this function will do more.
-
-
-		Parameters
-		----------
-		d: int
-			The dimensionality the distribution is being initialized to.
-		"""
-
-		self.means = _cast_as_parameter(torch.zeros(d, dtype=self.dtype,
-			device=self.device))
-		if self.covariance_type == 'full':
-			self.covs = _cast_as_parameter(torch.zeros(d, d,
-				dtype=self.dtype, device=self.device))
-		elif self.covariance_type == 'diag':
-			self.covs = _cast_as_parameter(torch.zeros(d, dtype=self.dtype,
-				device=self.device))
-		elif self.covariance_type == 'sphere':
-			self.covs = _cast_as_parameter(torch.tensor(0, dtype=self.dtype,
-				device=self.device))
-
-		self._initialized = True
-		super()._initialize(d)
-
-	def _reset_cache(self):
-		"""Reset the internally stored statistics.
-
-		This method is meant to only be called internally. It resets the
-		stored statistics used to update the model parameters as well as
-		recalculates the cached values meant to speed up log probability
-		calculations.
-		"""
-
-		if self._initialized == False:
-			return
-
-		self.register_buffer("_w_sum", torch.zeros(self.d, dtype=self.dtype,
-			device=self.device))
-		self.register_buffer("_xw_sum", torch.zeros(self.d, dtype=self.dtype,
-			device=self.device))
-
-		if self.covariance_type == 'full':
-			self.register_buffer("_xxw_sum", torch.zeros(self.d, self.d,
-				dtype=self.dtype, device=self.device))
-
-			if self.covs.sum() > 0.0:
-				chol = torch.linalg.cholesky(self.covs)
-				_inv_cov = torch.linalg.solve_triangular(chol, torch.eye(
-					len(self.covs), dtype=self.dtype, device=self.device),
-					upper=False).T
-				_inv_cov_dot_mu = torch.matmul(self.means, _inv_cov)
-				_log_det = -0.5 * torch.linalg.slogdet(self.covs)[1]
-				_theta = _log_det - 0.5 * (self.d * LOG_2_PI)
-
-				self.register_buffer("_inv_cov", _inv_cov)
-				self.register_buffer("_inv_cov_dot_mu", _inv_cov_dot_mu)
-				self.register_buffer("_log_det", _log_det)
-				self.register_buffer("_theta", _theta)
-
-		elif self.covariance_type in ('diag', 'sphere'):
-			self.register_buffer("_xxw_sum", torch.zeros(self.d,
-				dtype=self.dtype, device=self.device))
-
-			if self.covs.sum() > 0.0:
-				_log_sigma_sqrt_2pi = -torch.log(torch.sqrt(self.covs) *
-					SQRT_2_PI)
-				_inv_two_sigma = 1. / (2 * self.covs)
-
-				self.register_buffer("_log_sigma_sqrt_2pi", _log_sigma_sqrt_2pi)
-				self.register_buffer("_inv_two_sigma", _inv_two_sigma)
-
-		if any(self.covs < 0):
-			raise ValueError("Variances must be positive.")
+		super().__init__(means=means, covs=covs, covariance_type=covariance_type,
+			min_cov=min_cov, inertia=inertia, frozen=frozen, check_data=check_data)
+		# set the name after Normal.__init__, which assigns "Normal" internally
+		self.name = "LogNormal"
 
	def sample(self, n):
		"""Sample from the probability distribution.
 
@@ -224,15 +132,9 @@ def log_probability(self, X):
 
		# take the log of X
		x_log = X.log()
 
-		if self.covariance_type == 'full':
-			logp = torch.matmul(x_log, self._inv_cov) - self._inv_cov_dot_mu
-			logp = self.d * LOG_2_PI + torch.sum(logp ** 2, dim=-1)
-			logp = self._log_det - 0.5 * logp
-			return logp
-
-		elif self.covariance_type in ('diag', 'sphere'):
-			return torch.sum(self._log_sigma_sqrt_2pi - ((x_log - self.means) ** 2)
-				* self._inv_two_sigma, dim=-1)
+		# normal density of log(X), minus the Jacobian of the log transform
+		# (matching the -log(x) term in the deleted Cython implementation)
+		return super().log_probability(x_log) - x_log.sum(dim=-1)
 
	def summarize(self, X, sample_weight=None):
		"""Extract the sufficient statistics from a batch of data.
 
@@ -256,48 +158,5 @@ def summarize(self, X, sample_weight=None):
		"""
 
		if self.frozen is True:
			return
-
-		X = _cast_as_tensor(X, dtype=self.means.dtype)
-		X = X.log()
-
-		X, sample_weight = super().summarize(X, sample_weight=sample_weight)
-		X = _cast_as_tensor(X, dtype=self.means.dtype)
-
-		if self.covariance_type == 'full':
-			self._w_sum += torch.sum(sample_weight, dim=0)
-			self._xw_sum += torch.sum(X * sample_weight, axis=0)
-			self._xxw_sum += torch.matmul((X * sample_weight).T, X)
-
-		elif self.covariance_type in ('diag', 'sphere'):
-			self._w_sum[:] = self._w_sum + torch.sum(sample_weight, dim=0)
-			self._xw_sum[:] = self._xw_sum + torch.sum(X * sample_weight, dim=0)
-			self._xxw_sum[:] = self._xxw_sum + torch.sum(X ** 2 *
-				sample_weight, dim=0)
-
-	def from_summaries(self):
-		"""Update the model parameters given the extracted statistics.
-
-		This method uses calculated statistics from calls to the `summarize`
-		method to update the distribution parameters. Hyperparameters for the
-		update are passed in at initialization time.
-
-		Note: Internally, a call to `fit` is just a successive call to the
-		`summarize` method followed by the `from_summaries` method.
-		"""
-
-		if self.frozen == True:
-			return
-
-		means = self._xw_sum / self._w_sum
-
-		if self.covariance_type == 'full':
-			v = self._xw_sum.unsqueeze(0) * self._xw_sum.unsqueeze(1)
-			covs = self._xxw_sum / self._w_sum - v / self._w_sum ** 2.0
-
-		elif self.covariance_type == 'diag':
-			covs = self._xxw_sum / self._w_sum - \
-				self._xw_sum ** 2.0 / self._w_sum ** 2.0
-
-		_update_parameter(self.means, means, self.inertia)
-		_update_parameter(self.covs, covs, self.inertia)
-		self._reset_cache()
+		# cast without a dtype so that a fresh, uninitialized distribution
+		# (self.means is None) can still be fit from raw data
+		X = _cast_as_tensor(X)
+		super().summarize(X.log(), sample_weight=sample_weight)
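A quick equivalence check for this refactor (an illustrative sketch under the same assumptions as the example above): since LogNormal now defers to Normal on log-transformed data, fitting it should recover the parameters of the underlying normal.

    import torch
    from pomegranate.distributions import LogNormal, Normal

    torch.manual_seed(0)
    # draw from a known lognormal: log X ~ N(0.5, 0.8 ** 2)
    X = torch.distributions.LogNormal(0.5, 0.8).sample([5000, 1])

    lognormal = LogNormal(covariance_type='diag')
    lognormal.fit(X)

    # fitting a Normal to log X should give the same estimates
    normal = Normal(covariance_type='diag')
    normal.fit(X.log())

    print(lognormal.means, normal.means)   # both approximately 0.5
    print(lognormal.covs, normal.covs)     # both approximately 0.8 ** 2 = 0.64
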
diff --git a/pomegranate/distributions/lognormal_old.pyx b/pomegranate/distributions/lognormal_old.pyx
deleted file mode 100644
index 7cd4a8e5..00000000
--- a/pomegranate/distributions/lognormal_old.pyx
+++ /dev/null
@@ -1,112 +0,0 @@
-#!python
-#cython: boundscheck=False
-#cython: cdivision=True
-# LogNormalDistribution.pyx
-# Contact: Jacob Schreiber
-
-import numpy
-
-from ..utils cimport _log
-from ..utils cimport isnan
-from ..utils import check_random_state
-
-from libc.math cimport sqrt as csqrt
-
-# Define some useful constants
-DEF NEGINF = float("-inf")
-DEF INF = float("inf")
-DEF SQRT_2_PI = 2.50662827463
-DEF LOG_2_PI = 1.83787706641
-
-cdef class LogNormalDistribution(Distribution):
-	"""A lognormal distribution over non-negative floats.
-
-	The parameters are the mu and sigma of the normal distribution, which
-	is the the exponential of the log normal distribution.
-	"""
-
-	property parameters:
-		def __get__(self):
-			return [self.mu, self.sigma]
-		def __set__(self, parameters):
-			self.mu, self.sigma = parameters
-
-	def __init__(self, double mu, double sigma, double min_std=0.0, frozen=False):
-		self.mu = mu
-		self.sigma = sigma
-		self.summaries = [0, 0, 0]
-		self.name = "LogNormalDistribution"
-		self.frozen = frozen
-		self.min_std = min_std
-
-	def __reduce__(self):
-		"""Serialize distribution for pickling."""
-		return self.__class__, (self.mu, self.sigma, self.frozen)
-
-	cdef void _log_probability(self, double* X, double* log_probability, int n) nogil:
-		cdef int i
-		for i in range(n):
-			if isnan(X[i]):
-				log_probability[i] = 0.
-			else:
-				log_probability[i] = -_log(X[i] * self.sigma * SQRT_2_PI) - 0.5\
-					* ((_log(X[i]) - self.mu) / self.sigma) ** 2
-
-	def sample(self, n=None, random_state=None):
-		random_state = check_random_state(random_state)
-		return random_state.lognormal(self.mu, self.sigma, n)
-
-	cdef double _summarize(self, double* items, double* weights, int n,
-			int column_idx, int d) nogil:
-		"""Cython function to get the MLE estimate for a Gaussian."""
-
-		cdef int i
-		cdef double x_sum = 0.0, x2_sum = 0.0, w_sum = 0.0
-		cdef double item, log_item
-
-		for i in range(n):
-			item = items[i*d + column_idx]
-			if isnan(item):
-				continue
-
-			log_item = _log(item)
-			w_sum += weights[i]
-			x_sum += weights[i] * log_item
-			x2_sum += weights[i] * log_item * log_item
-
-		with gil:
-			self.summaries[0] += w_sum
-			self.summaries[1] += x_sum
-			self.summaries[2] += x2_sum
-
-	def from_summaries(self, inertia=0.0):
-		"""
-		Takes in a series of summaries, represented as a mean, a variance, and
-		a weight, and updates the underlying distribution. Notes on how to do
-		this for a Gaussian distribution were taken from here:
-		http://math.stackexchange.com/questions/453113/how-to-merge-two-gaussians
-		"""
-
-		# If no summaries stored or the summary is frozen, don't do anything.
- if self.summaries[0] == 0 or self.frozen == True: - return - - mu = self.summaries[1] / self.summaries[0] - var = self.summaries[2] / self.summaries[0] - self.summaries[1] ** 2.0 / self.summaries[0] ** 2.0 - - sigma = csqrt(var) - if sigma < self.min_std: - sigma = self.min_std - - self.mu = self.mu*inertia + mu*(1-inertia) - self.sigma = self.sigma*inertia + sigma*(1-inertia) - self.summaries = [0, 0, 0] - - def clear_summaries(self): - """Clear the summary statistics stored in the object.""" - - self.summaries = [0, 0, 0] - - @classmethod - def blank(cls): - return cls(0, 1)
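For reference, the deleted Cython path above and the new torch path implement the same weighted Gaussian MLE on log-transformed data. A self-contained sketch of that update (plain torch; `lognormal_mle` is a hypothetical helper name, not part of either implementation):

    import torch

    def lognormal_mle(x, w):
        # the three running sums kept by the old code:
        # summaries[0] = sum(w), summaries[1] = sum(w * log x),
        # summaries[2] = sum(w * (log x) ** 2)
        log_x = torch.log(x)
        w_sum = w.sum()
        xw_sum = (w * log_x).sum()
        xxw_sum = (w * log_x ** 2).sum()

        # mu = weighted mean of log x; var = E[(log x)^2] - E[log x]^2
        mu = xw_sum / w_sum
        var = xxw_sum / w_sum - (xw_sum / w_sum) ** 2
        return mu, var.sqrt()

    x = torch.distributions.LogNormal(0.5, 0.8).sample([10000])
    w = torch.ones_like(x)
    print(lognormal_mle(x, w))   # approximately (0.5, 0.8)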