feat: halfnormal and lognormal
Signed-off-by: mschrader15 <[email protected]>
mschrader15 committed Sep 25, 2023
1 parent 8b62c66 commit 70c7dd7
Showing 4 changed files with 29 additions and 365 deletions.
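The substance of the change: HalfNormal and LogNormal now subclass Normal and delegate initialization, caching, and sufficient statistics to it instead of carrying duplicated copies of that machinery. A minimal usage sketch of the resulting API (illustrative data; this assumes pomegranate's 1.x torch interface, where distributions expose fit and log_probability):

    import torch
    from pomegranate.distributions import HalfNormal, LogNormal

    X = torch.rand(1000, 2) + 0.1           # strictly positive toy data

    half = HalfNormal(covariance_type="diag")
    half.fit(X)                             # zero-mean by construction; fits covs
    print(half.log_probability(X[:3]))

    lognorm = LogNormal(covariance_type="diag")
    lognorm.fit(X)                          # fits the normal parameters of log(X)
    print(lognorm.log_probability(X[:3]))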
2 changes: 2 additions & 0 deletions pomegranate/distributions/__init__.py
@@ -11,3 +11,5 @@
 from .student_t import StudentT
 from .uniform import Uniform
 from .zero_inflated import ZeroInflated
+from .lognormal import LogNormal
+from .halfnormal import HalfNormal
121 changes: 18 additions & 103 deletions pomegranate/distributions/halfnormal.py
@@ -10,16 +10,18 @@
 from .._utils import _check_shapes
 
 from ._distribution import Distribution
+from .normal import Normal
 
 
 # Define some useful constants
 NEGINF = float("-inf")
 INF = float("inf")
 SQRT_2_PI = 2.50662827463
 LOG_2_PI = 1.83787706641
+LOG_2 = 0.6931471805599453
 
 
-class HalfNormal(Distribution):
+class HalfNormal(Normal):
     """A half-normal distribution object.
 
     A half-normal distribution is a distribution over positive real numbers that
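The hard-coded constants above are the usual normal-density helpers, with log 2 newly added for the half-normal renormalization. A quick standard-library check of the values:

    import math

    assert math.isclose(2.50662827463, math.sqrt(2 * math.pi))   # SQRT_2_PI
    assert math.isclose(1.83787706641, math.log(2 * math.pi))    # LOG_2_PI
    assert math.isclose(0.6931471805599453, math.log(2))         # LOG_2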
@@ -78,20 +80,12 @@ def __init__(
         frozen=False,
         check_data=True,
     ):
-        super().__init__(inertia=inertia, frozen=frozen, check_data=check_data)
-        self.name = "HalfNormal"
-
-        self.covs = _check_parameter(_cast_as_parameter(covs), "covs", ndim=(1, 2))
-
-        _check_shapes([self.means, self.covs], ["means", "covs"])
-
-        self.min_cov = _check_parameter(min_cov, "min_cov", min_value=0, ndim=0)
-        self.covariance_type = covariance_type
-
-        self._initialized = covs is not None
-        self.d = self.means.shape[-1] if self._initialized else None
-        self._reset_cache()
-
+        self.name = "HalfNormal"
+        super().__init__(means=None, covs=covs, min_cov=min_cov,
+            covariance_type=covariance_type, inertia=inertia, frozen=frozen,
+            check_data=check_data)
 
     def _initialize(self, d):
         """Initialize the probability distribution.
@@ -105,20 +99,6 @@ def _initialize(self, d):
         d: int
             The dimensionality the distribution is being initialized to.
         """
-        if self.covariance_type == "full":
-            self.covs = _cast_as_parameter(
-                torch.zeros(d, d, dtype=self.dtype, device=self.device)
-            )
-        elif self.covariance_type == "diag":
-            self.covs = _cast_as_parameter(
-                torch.zeros(d, dtype=self.dtype, device=self.device)
-            )
-        elif self.covariance_type == "sphere":
-            self.covs = _cast_as_parameter(
-                torch.tensor(0, dtype=self.dtype, device=self.device)
-            )
-
-        self._initialized = True
         super()._initialize(d)
 
     def _reset_cache(self):
@@ -129,51 +109,7 @@
         recalculates the cached values meant to speed up log probability
         calculations.
         """
-
-        if self._initialized == False:
-            return
-
-        self.register_buffer(
-            "_w_sum", torch.zeros(self.d, dtype=self.dtype, device=self.device)
-        )
-        self.register_buffer(
-            "_xw_sum", torch.zeros(self.d, dtype=self.dtype, device=self.device)
-        )
-
-        if self.covariance_type == "full":
-            self.register_buffer(
-                "_xxw_sum",
-                torch.zeros(self.d, self.d, dtype=self.dtype, device=self.device),
-            )
-
-            if self.covs.sum() > 0.0:
-                chol = torch.linalg.cholesky(self.covs)
-                _inv_cov = torch.linalg.solve_triangular(
-                    chol,
-                    torch.eye(len(self.covs), dtype=self.dtype, device=self.device),
-                    upper=False,
-                ).T
-                _log_det = -0.5 * torch.linalg.slogdet(self.covs)[1]
-                _theta = _log_det - 0.5 * (self.d * LOG_2_PI)
-
-                self.register_buffer("_inv_cov", _inv_cov)
-                self.register_buffer("_log_det", _log_det)
-                self.register_buffer("_theta", _theta)
-
-        elif self.covariance_type in ("diag", "sphere"):
-            self.register_buffer(
-                "_xxw_sum", torch.zeros(self.d, dtype=self.dtype, device=self.device)
-            )
-
-            if self.covs.sum() > 0.0:
-                _log_sigma_sqrt_2pi = -torch.log(torch.sqrt(self.covs) * SQRT_2_PI)
-                _inv_two_sigma = 1.0 / (2 * self.covs)
-
-                self.register_buffer("_log_sigma_sqrt_2pi", _log_sigma_sqrt_2pi)
-                self.register_buffer("_inv_two_sigma", _inv_two_sigma)
-
-        if torch.any(self.covs < 0):
-            raise ValueError("Variances must be positive.")
+        super()._reset_cache()
 
     def sample(self, n):
         """Sample from the probability distribution.
@@ -193,10 +129,7 @@
         X: torch.tensor, shape=(n, self.d)
             Randomly generated samples.
         """
-
-        if self.covariance_type == "diag":
-            return torch.distributions.HalfNormal(self.covs).sample([n])
-        elif self.covariance_type == "full":
+        if self.covariance_type in ["diag", "full"]:
             return torch.distributions.HalfNormal(self.covs).sample([n])
 
     def log_probability(self, X):
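Worth noting when reading sample() above: torch.distributions.HalfNormal is parameterized by its scale and broadcasts over it. A standalone torch-only example:

    import torch

    scale = torch.tensor([1.0, 2.0])        # one scale per dimension
    X = torch.distributions.HalfNormal(scale).sample([5])
    assert X.shape == (5, 2)
    assert (X >= 0).all()                   # support is the nonnegative reals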
@@ -225,23 +158,15 @@
         """
 
         X = _check_parameter(
-            _cast_as_tensor(X, dtype=self.means.dtype),
+            _cast_as_tensor(X, dtype=self.covs.dtype),
             "X",
             ndim=2,
             shape=(-1, self.d),
             check_parameter=self.check_data,
         )
-
-        # if self.covariance_type == 'full':
-        # # logp = torch.matmul(X, self._inv_cov) - self._inv_cov_dot_mu
-        # # logp = self.d * LOG_2_PI + torch.sum(logp ** 2, dim=-1)
-        # # logp = self._log_det - 0.5 * logp
-        # # return logp
-        return 0.5 * LOG_2_PI + (X**2 / 2).sum(dim=-1)
-
-        # elif self.covariance_type in ('diag', 'sphere'):
-        # return torch.sum(self._log_sigma_sqrt_2pi - ((X - self.means) ** 2)
-        # * self._inv_two_sigma, dim=-1)
+        return super().log_probability(X) + LOG_2
 
     def summarize(self, X, sample_weight=None):
         """Extract the sufficient statistics from a batch of data.
@@ -263,21 +188,7 @@
             (-1, self.d) or a vector of shape (-1,). Default is ones.
         """
 
-        if self.frozen == True:
-            return
-
-        X, sample_weight = super().summarize(X, sample_weight=sample_weight)
-        X = _cast_as_tensor(X, dtype=self.means.dtype)
-
-        if self.covariance_type == "full":
-            self._w_sum += torch.sum(sample_weight, dim=0)
-            self._xw_sum += torch.sum(X * sample_weight, axis=0)
-            self._xxw_sum += torch.matmul((X * sample_weight).T, X)
-
-        elif self.covariance_type in ("diag", "sphere"):
-            self._w_sum[:] = self._w_sum + torch.sum(sample_weight, dim=0)
-            self._xw_sum[:] = self._xw_sum + torch.sum(X * sample_weight, dim=0)
-            self._xxw_sum[:] = self._xxw_sum + torch.sum(X**2 * sample_weight, dim=0)
+        super().summarize(X, sample_weight=sample_weight)
 
     def from_summaries(self):
         """Update the model parameters given the extracted statistics.
@@ -293,6 +204,9 @@
         if self.frozen == True:
             return
 
+        # the means are always zero for a half normal distribution
+        means = torch.zeros(self.d, dtype=self.covs.dtype)
+
         if self.covariance_type == "full":
             v = self._xw_sum.unsqueeze(0) * self._xw_sum.unsqueeze(1)
             covs = self._xxw_sum / self._w_sum - v / self._w_sum**2.0
@@ -305,4 +219,5 @@
             covs = covs.mean(dim=-1)
 
         _update_parameter(self.covs, covs, self.inertia)
+        _update_parameter(self.means, means, self.inertia)
         self._reset_cache()
159 changes: 9 additions & 150 deletions pomegranate/distributions/lognormal.py
@@ -9,7 +9,7 @@
 from .._utils import _check_parameter
 from .._utils import _check_shapes
 
-from ._distribution import Distribution
+from .normal import Normal
 
 
 # Define some useful constants
@@ -19,7 +19,7 @@
 LOG_2_PI = 1.83787706641
 
 
-class LogNormal(Distribution):
+class LogNormal(Normal):
     """A lognormal object.
 
     The parameters are the mu and sigma of the normal distribution, which
@@ -71,102 +71,10 @@ class LogNormal(Distribution):
 
     def __init__(self, means=None, covs=None, covariance_type='full',
             min_cov=None, inertia=0.0, frozen=False, check_data=True):
-        super().__init__(inertia=inertia, frozen=frozen, check_data=check_data)
-        self.name = "LogNormal"
-
-        self.means = _check_parameter(_cast_as_parameter(means), "means",
-            ndim=1)
-        self.covs = _check_parameter(_cast_as_parameter(covs), "covs",
-            ndim=(1, 2))
-
-        _check_shapes([self.means, self.covs], ["means", "covs"])
-
-        self.min_cov = _check_parameter(min_cov, "min_cov", min_value=0, ndim=0)
-        self.covariance_type = covariance_type
-
-        self._initialized = (means is not None) and (covs is not None)
-        self.d = self.means.shape[-1] if self._initialized else None
-        self._reset_cache()
-
-    def _initialize(self, d):
-        """Initialize the probability distribution.
-
-        This method is meant to only be called internally. It initializes the
-        parameters of the distribution and stores its dimensionality. For more
-        complex methods, this function will do more.
-
-        Parameters
-        ----------
-        d: int
-            The dimensionality the distribution is being initialized to.
-        """
-
-        self.means = _cast_as_parameter(torch.zeros(d, dtype=self.dtype,
-            device=self.device))
-
-        if self.covariance_type == 'full':
-            self.covs = _cast_as_parameter(torch.zeros(d, d,
-                dtype=self.dtype, device=self.device))
-        elif self.covariance_type == 'diag':
-            self.covs = _cast_as_parameter(torch.zeros(d, dtype=self.dtype,
-                device=self.device))
-        elif self.covariance_type == 'sphere':
-            self.covs = _cast_as_parameter(torch.tensor(0, dtype=self.dtype,
-                device=self.device))
-
-        self._initialized = True
-        super()._initialize(d)
-
-    def _reset_cache(self):
-        """Reset the internally stored statistics.
-
-        This method is meant to only be called internally. It resets the
-        stored statistics used to update the model parameters as well as
-        recalculates the cached values meant to speed up log probability
-        calculations.
-        """
-
-        if self._initialized == False:
-            return
-
-        self.register_buffer("_w_sum", torch.zeros(self.d, dtype=self.dtype,
-            device=self.device))
-        self.register_buffer("_xw_sum", torch.zeros(self.d, dtype=self.dtype,
-            device=self.device))
-
-        if self.covariance_type == 'full':
-            self.register_buffer("_xxw_sum", torch.zeros(self.d, self.d,
-                dtype=self.dtype, device=self.device))
-
-            if self.covs.sum() > 0.0:
-                chol = torch.linalg.cholesky(self.covs)
-                _inv_cov = torch.linalg.solve_triangular(chol, torch.eye(
-                    len(self.covs), dtype=self.dtype, device=self.device),
-                    upper=False).T
-                _inv_cov_dot_mu = torch.matmul(self.means, _inv_cov)
-                _log_det = -0.5 * torch.linalg.slogdet(self.covs)[1]
-                _theta = _log_det - 0.5 * (self.d * LOG_2_PI)
-
-                self.register_buffer("_inv_cov", _inv_cov)
-                self.register_buffer("_inv_cov_dot_mu", _inv_cov_dot_mu)
-                self.register_buffer("_log_det", _log_det)
-                self.register_buffer("_theta", _theta)
-
-        elif self.covariance_type in ('diag', 'sphere'):
-            self.register_buffer("_xxw_sum", torch.zeros(self.d,
-                dtype=self.dtype, device=self.device))
-
-            if self.covs.sum() > 0.0:
-                _log_sigma_sqrt_2pi = -torch.log(torch.sqrt(self.covs) *
-                    SQRT_2_PI)
-                _inv_two_sigma = 1. / (2 * self.covs)
-
-                self.register_buffer("_log_sigma_sqrt_2pi", _log_sigma_sqrt_2pi)
-                self.register_buffer("_inv_two_sigma", _inv_two_sigma)
-
-        if any(self.covs < 0):
-            raise ValueError("Variances must be positive.")
+        self.name = "LogNormal"
+        super().__init__(means=means, covs=covs, covariance_type=covariance_type,
+            min_cov=min_cov, inertia=inertia, frozen=frozen, check_data=check_data)
 
     def sample(self, n):
         """Sample from the probability distribution.
@@ -224,15 +132,9 @@ def log_probability(self, X):
         # take the log of X
         x_log = X.log()
 
-        if self.covariance_type == 'full':
-            logp = torch.matmul(x_log, self._inv_cov) - self._inv_cov_dot_mu
-            logp = self.d * LOG_2_PI + torch.sum(logp ** 2, dim=-1)
-            logp = self._log_det - 0.5 * logp
-            return logp
-
-        elif self.covariance_type in ('diag', 'sphere'):
-            return torch.sum(self._log_sigma_sqrt_2pi - ((x_log - self.means) ** 2)
-                * self._inv_two_sigma, dim=-1)
+        return super().log_probability(
+            x_log
+        )
 
     def summarize(self, X, sample_weight=None):
         """Extract the sufficient statistics from a batch of data.
@@ -256,48 +158,5 @@ def summarize(self, X, sample_weight=None):
 
         if self.frozen is True:
             return
-
-        X = _cast_as_tensor(X, dtype=self.means.dtype)
-        X = X.log()
-
-        X, sample_weight = super().summarize(X, sample_weight=sample_weight)
-        X = _cast_as_tensor(X, dtype=self.means.dtype)
-
-        if self.covariance_type == 'full':
-            self._w_sum += torch.sum(sample_weight, dim=0)
-            self._xw_sum += torch.sum(X * sample_weight, axis=0)
-            self._xxw_sum += torch.matmul((X * sample_weight).T, X)
-
-        elif self.covariance_type in ('diag', 'sphere'):
-            self._w_sum[:] = self._w_sum + torch.sum(sample_weight, dim=0)
-            self._xw_sum[:] = self._xw_sum + torch.sum(X * sample_weight, dim=0)
-            self._xxw_sum[:] = self._xxw_sum + torch.sum(X ** 2 *
-                sample_weight, dim=0)
-
-    def from_summaries(self):
-        """Update the model parameters given the extracted statistics.
-
-        This method uses calculated statistics from calls to the `summarize`
-        method to update the distribution parameters. Hyperparameters for the
-        update are passed in at initialization time.
-
-        Note: Internally, a call to `fit` is just a successive call to the
-        `summarize` method followed by the `from_summaries` method.
-        """
-
-        if self.frozen == True:
-            return
-
-        means = self._xw_sum / self._w_sum
-
-        if self.covariance_type == 'full':
-            v = self._xw_sum.unsqueeze(0) * self._xw_sum.unsqueeze(1)
-            covs = self._xxw_sum / self._w_sum - v / self._w_sum ** 2.0
-
-        elif self.covariance_type == 'diag':
-            covs = self._xxw_sum / self._w_sum - \
-                self._xw_sum ** 2.0 / self._w_sum ** 2.0
-
-        _update_parameter(self.means, means, self.inertia)
-        _update_parameter(self.covs, covs, self.inertia)
-        self._reset_cache()
+        super().summarize(X.log(), sample_weight=sample_weight)
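Because summarize() feeds log(X) into Normal's sufficient statistics, the parameters fitted here are the mean and variance of log X, not of X itself. A small torch-only illustration (illustrative values):

    import torch

    torch.manual_seed(0)
    X = torch.distributions.LogNormal(0.5, 0.25).sample((100_000,))

    print(X.log().mean())    # close to mu = 0.5
    print(X.log().var())     # close to sigma ** 2 = 0.0625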