implementation of DoublyTruncatedPowerLaw
Qazalbash committed May 30, 2024
1 parent e8216d7 commit febd260
Showing 2 changed files with 117 additions and 0 deletions.
2 changes: 2 additions & 0 deletions numpyro/distributions/__init__.py
@@ -20,6 +20,7 @@
Cauchy,
Chi2,
Dirichlet,
DoublyTruncatedPowerLaw,
EulerMaruyama,
Exponential,
Gamma,
@@ -129,6 +130,7 @@
"Chi2",
"Delta",
"Dirichlet",
"DoublyTruncatedPowerLaw",
"DirichletMultinomial",
"DiscreteUniform",
"Distribution",
115 changes: 115 additions & 0 deletions numpyro/distributions/continuous.py
@@ -2879,3 +2879,118 @@ def infer_shapes(
batch_shape = lax.broadcast_shapes(concentration, matrix[:-2])
event_shape = matrix[-2:]
return batch_shape, event_shape


class DoublyTruncatedPowerLaw(Distribution):
r"""Power law distribution with :math:`\alpha` index, and lower and upper bounds.
:param alpha: index of the power law distribution
:param low: lower bound of the distribution
:param high: upper bound of the distribution
"""

arg_constraints = {
"alpha": constraints.real,
"low": constraints.positive,
"high": constraints.positive,
}
reparametrized_params = ["alpha", "low", "high"]
pytree_aux_fields = ("_support", "_logZ")

def __init__(self, alpha, low, high, *, validate_args=None):
self.alpha, self.low, self.high = promote_shapes(alpha, low, high)
self._support = constraints.interval(low, high)
batch_shape = lax.broadcast_shapes(
jnp.shape(alpha), jnp.shape(low), jnp.shape(high)
)
super(DoublyTruncatedPowerLaw, self).__init__(
batch_shape=batch_shape, validate_args=validate_args
)
self._logZ = self._log_Z()

@constraints.dependent_property(is_discrete=False, event_dim=0)
def support(self):
return self._support

def _log_Z(self):
"""Computes the logarithm of normalization constant.
:return: The logarithm of normalization constant.
"""
return jnp.where(
self.alpha == -1.0,
            jnp.log(jnp.log(self.high) - jnp.log(self.low)),
jnp.log(
jnp.abs(
jnp.power(self.low, 1.0 + self.alpha)
- jnp.power(self.high, 1.0 + self.alpha)
)
)
- jnp.log(jnp.abs(1.0 + self.alpha)),
)

@validate_sample
def log_prob(self, value):
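        # log p(x) = alpha * log(x) - log(Z) for x in [low, high]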
return self.alpha * jnp.log(value) - self._logZ

def cdf(self, value):
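        # F(x) = (x^(1 + alpha) - low^(1 + alpha)) / (high^(1 + alpha) - low^(1 + alpha)),
        # reducing to a ratio of log differences when alpha == -1.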
return jnp.where(
self.alpha == -1.0,
(jnp.log(value) - jnp.log(self.low))
/ (jnp.log(self.high) - jnp.log(self.low)),
(jnp.power(value, 1.0 + self.alpha) - jnp.power(self.low, 1.0 + self.alpha))
/ (
jnp.power(self.high, 1.0 + self.alpha)
- jnp.power(self.low, 1.0 + self.alpha)
),
)

def icdf(self, q):
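        # Quantile function: the closed-form inverse of `cdf`, solved for x.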
return jnp.where(
self.alpha == -1.0,
jnp.exp(jnp.log(self.low) + q * (jnp.log(self.high) - jnp.log(self.low))),
jnp.power(
jnp.power(self.low, 1.0 + self.alpha)
+ q
* (
jnp.power(self.high, 1.0 + self.alpha)
- jnp.power(self.low, 1.0 + self.alpha)
),
jnp.reciprocal(1.0 + self.alpha),
),
)

def sample(self, key, sample_shape=()):
assert is_prng_key(key)
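        # Inverse-transform sampling: push uniform draws through the quantile function.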
u = random.uniform(key, sample_shape + self.batch_shape)
samples = self.icdf(u)
return samples

@lazy_property
def mean(self):
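        # E[X] = int_low^high x^(1 + alpha) dx / Z; the alpha == -2 branch is the
        # limit where the antiderivative becomes logarithmic.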
Z = jnp.exp(self._logZ)
return jnp.where(
self.alpha == -2.0,
(jnp.log(self.high) - jnp.log(self.low)) / Z,
(
jnp.power(self.high, 2.0 + self.alpha)
- jnp.power(self.low, 2.0 + self.alpha)
)
/ ((2.0 + self.alpha) * Z),
)

@lazy_property
def variance(self):
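        # Var[X] = E[X^2] - E[X]^2, with E[X^2] = int_low^high x^(2 + alpha) dx / Z;
        # the alpha == -3 branch is the logarithmic limit.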
Z = jnp.exp(self._logZ)
expectation_x = self.mean
expectation_x_squared = jnp.where(
self.alpha == -3.0,
(jnp.log(self.high) - jnp.log(self.low)) / Z,
(
jnp.power(self.high, 3.0 + self.alpha)
- jnp.power(self.low, 3.0 + self.alpha)
)
/ ((3.0 + self.alpha) * Z),
)

return expectation_x_squared - expectation_x**2
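
A minimal usage sketch (not part of the commit), assuming only the public API added here plus standard JAX PRNG handling; the parameter values are illustrative:

import jax.random as random

import numpyro.distributions as dist

# Power law p(x) proportional to x**(-2.5), truncated to [1, 10].
d = dist.DoublyTruncatedPowerLaw(alpha=-2.5, low=1.0, high=10.0)

key = random.PRNGKey(0)
samples = d.sample(key, sample_shape=(1000,))  # inverse-transform draws

log_density = d.log_prob(samples)
print(samples.mean(), d.mean)  # empirical vs. analytic mean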
