implementation of LowerTruncatedPowerLaw
Qazalbash committed May 30, 2024
1 parent febd260 commit dbb0d62
Showing 1 changed file with 70 additions and 4 deletions.
74 changes: 70 additions & 4 deletions numpyro/distributions/continuous.py
@@ -2913,10 +2913,6 @@ def support(self):
        return self._support

    def _log_Z(self):
        """Computes the logarithm of normalization constant.
        :return: The logarithm of normalization constant.
        """
        return jnp.where(
            self.alpha == -1.0,
            jnp.log(jnp.log(self.high) - jnp.log(self.low)),
@@ -2994,3 +2990,73 @@ def variance(self):
        )

        return expectation_x_squared - expectation_x**2


class LowerTruncatedPowerLaw(Distribution):
    r"""Lower truncated power law distribution with index :math:`\alpha`.

    The probability density function is

    .. math::
        f(x; \alpha, x_{\mathrm{low}}) =
        \frac{-(\alpha + 1)}{x_{\mathrm{low}}}
        \left(\frac{x}{x_{\mathrm{low}}}\right)^{\alpha},
        \qquad x \geq x_{\mathrm{low}},

    which is normalizable only for :math:`\alpha < -1`.

    :param alpha: index of the power law distribution
    :param low: lower bound of the distribution
    """

    arg_constraints = {
        "alpha": constraints.less_than(-1.0),
        "low": constraints.greater_than(1.0),
    }
    reparametrized_params = ["alpha", "low"]
    pytree_aux_fields = ("_support",)

    def __init__(self, alpha, low, *, validate_args=None):
        self.alpha, self.low = promote_shapes(alpha, low)
        batch_shape = lax.broadcast_shapes(jnp.shape(alpha), jnp.shape(low))
        self._support = constraints.greater_than(low)
        super(LowerTruncatedPowerLaw, self).__init__(
            batch_shape=batch_shape, validate_args=validate_args
        )

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):
        return self._support

    @validate_sample
    def log_prob(self, value):
        return (
            self.alpha * jnp.log(value)
            - (self.alpha + 1.0) * jnp.log(self.low)
            + jnp.log(-1.0 - self.alpha)
        )

    def cdf(self, value):
        return 1.0 - jnp.power(value / self.low, 1.0 + self.alpha)

    def icdf(self, q):
        return self.low * jnp.power(1.0 - q, jnp.reciprocal(1.0 + self.alpha))

    def sample(self, key, sample_shape=()):
        assert is_prng_key(key)
        u = random.uniform(key, sample_shape + self.batch_shape)
        return self.icdf(u)

    @lazy_property
    def mean(self):
        # E[X] = low * (alpha + 1) / (alpha + 2); the mean diverges for alpha >= -2.
        return jnp.where(
            self.alpha >= -2.0,
            jnp.inf,
            self.low * (self.alpha + 1.0) * jnp.reciprocal(self.alpha + 2.0),
        )

    @lazy_property
    def variance(self):
        # Var[X] = low**2 * (alpha + 1) / ((alpha + 2)**2 * (alpha + 3));
        # the variance diverges for alpha >= -3.
        return jnp.where(
            self.alpha >= -3.0,
            jnp.inf,
            jnp.power(self.low, 2.0)
            * (self.alpha + 1.0)
            * jnp.power(self.alpha + 2.0, -2.0)
            * jnp.reciprocal(self.alpha + 3.0),
        )
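
A minimal usage sketch, not part of the commit, follows. It assumes the new class is re-exported as numpyro.distributions.LowerTruncatedPowerLaw once the change is merged (in this diff it only lives in numpyro/distributions/continuous.py), and it checks the draws against the closed-form mean low * (alpha + 1) / (alpha + 2) implied by the density above.

import jax.numpy as jnp
from jax import random

# Assumed import path; the class is defined in numpyro/distributions/continuous.py.
from numpyro.distributions import LowerTruncatedPowerLaw

# p(x) ~ x**alpha on [low, inf); alpha < -1 makes the density normalizable.
dist = LowerTruncatedPowerLaw(alpha=-2.5, low=2.0)

# Draw samples via the inverse-CDF transform implemented in `sample`.
key = random.PRNGKey(0)
samples = dist.sample(key, sample_shape=(10_000,))

# The empirical mean should approach low * (alpha + 1) / (alpha + 2) = 6.0 here.
print(samples.mean(), dist.mean)

# Log-density at a few points above the truncation bound.
print(dist.log_prob(jnp.array([2.0, 3.0, 10.0])))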
