diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index e96662168..a5af0406d 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2913,10 +2913,6 @@ def support(self): return self._support def _log_Z(self): - """Computes the logarithm of normalization constant. - - :return: The logarithm of normalization constant. - """ return jnp.where( self.alpha == -1.0, jnp.log(jnp.log(self.low) - jnp.log(self.high)), @@ -2994,3 +2990,73 @@ def variance(self): ) return expectation_x_squared - expectation_x**2 + + +class LowerTruncatedPowerLaw(Distribution): + r"""Lower truncated power law distribution with :math:`\alpha` index. + + :param alpha: index of the power law distribution + :param low: lower bound of the distribution + """ + + arg_constraints = { + "alpha": constraints.less_than(-1.0), + "low": constraints.greater_than(1.0), + } + reparametrized_params = ["alpha", "low"] + pytree_aux_fields = ("_support",) + + def __init__(self, alpha, low, *, validate_args=None): + self.alpha, self.low = promote_shapes(alpha, low) + batch_shape = lax.broadcast_shapes(jnp.shape(alpha), jnp.shape(low)) + self._support = constraints.greater_than(low) + super(LowerTruncatedPowerLaw, self).__init__( + batch_shape=batch_shape, validate_args=validate_args + ) + + @constraints.dependent_property(is_discrete=False, event_dim=0) + def support(self): + return self._support + + @validate_sample + def log_prob(self, value): + return ( + self.alpha * jnp.log(value) + - (self.alpha + 1.0) * jnp.log(self.low) + + jnp.log(-1.0 - self.alpha) + ) + + def cdf(self, value): + return 1.0 - jnp.power(value / self.low, 1.0 + self.alpha) + + def icdf(self, q): + return self.low * jnp.power(1.0 - q, jnp.reciprocal(1.0 + self.alpha)) + + def sample(self, key, sample_shape=()): + assert is_prng_key(key) + u = random.uniform(key, sample_shape + self.batch_shape) + samples = self.icdf(u) + return samples + + @lazy_property + def mean(self): + if self.alpha + 2.0 >= 0: + return jnp.full_like(self.low, jnp.inf) + return ( + (self.alpha + 1.0) + * jnp.reciprocal(self.alpha + 2.0) + * (self.low - jnp.power(self.low, -self.alpha - 1.0)) + ) + + @lazy_property + def variance(self): + if self.alpha + 2.0 >= 0: + return jnp.full_like(self.low, jnp.inf) + if self.alpha + 3.0 >= 0: + return jnp.full_like(self.low, jnp.inf) + return ( + jnp.power(self.low, 2.0) + * (self.alpha + 1.0) + * jnp.power(self.alpha, -2.0) + * jnp.reciprocal(self.alpha + 3.0) + )