Skip to content

Commit

Permalink
add PDFLike protocol to mirror tensorflow probability and distrax Dis…
Browse files Browse the repository at this point in the history
…tributions
  • Loading branch information
pfackeldey committed Mar 20, 2024
1 parent 06206a2 commit 183a511
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 34 deletions.
15 changes: 13 additions & 2 deletions src/evermore/custom_types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING, Any, NamedTuple, Protocol
from typing import TYPE_CHECKING, Any, NamedTuple, Protocol, runtime_checkable

from jaxtyping import Array
from jaxtyping import Array, PRNGKeyArray

if TYPE_CHECKING:
from evermore.modifier import compose
Expand All @@ -13,6 +13,7 @@
"SF",
"AddOrMul",
"ModifierLike",
"PDFLike",
]


Expand Down Expand Up @@ -43,3 +44,13 @@ class ModifierLike(Protocol):
def scale_factor(self, hist: Array) -> SF: ...
def __call__(self, hist: Array) -> Array: ...
def __matmul__(self, other: ModifierLike) -> compose: ...


@runtime_checkable
class PDFLike(Protocol):
"""Mirrors the (relevant) interface of `tfp.distributions.Distribution` & `distrax.Distribution`."""

def log_prob(self, x: Array) -> Array: ...
def prob(self, x: Array) -> Array: ...
def cdf(self, x: Array) -> Array: ...
def sample(self, key: PRNGKeyArray) -> Array: ...
9 changes: 4 additions & 5 deletions src/evermore/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
import jax.numpy as jnp
from jaxtyping import Array, PyTree

from evermore.custom_types import _NoValue
from evermore.custom_types import PDFLike
from evermore.parameter import Parameter
from evermore.pdf import PDF
from evermore.util import _params_map

__all__ = [
Expand All @@ -25,9 +24,9 @@ def __dir__():
def get_logpdf_constraints(module: PyTree) -> PyTree:
def _constraint(param: Parameter) -> Array:
constraint = param.constraint
if constraint is not _NoValue:
constraint = cast(PDF, constraint)
return constraint.logpdf(param.value)
if isinstance(constraint, PDFLike):
constraint = cast(PDFLike, constraint)
return constraint.log_prob(param.value)
return jnp.array([0.0])

# constraints from pdfs
Expand Down
8 changes: 4 additions & 4 deletions src/evermore/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import jax.numpy as jnp
from jaxtyping import Array, ArrayLike

from evermore.custom_types import Sentinel, _NoValue
from evermore.pdf import PDF, Flat, Normal, Poisson
from evermore.custom_types import PDFLike, Sentinel, _NoValue
from evermore.pdf import Flat, Normal, Poisson
from evermore.util import as1darray

if TYPE_CHECKING:
Expand All @@ -30,14 +30,14 @@ class Parameter(eqx.Module):
value: Array = eqx.field(converter=as1darray)
lower: Array = eqx.field(converter=as1darray)
upper: Array = eqx.field(converter=as1darray)
constraint: PDF | Sentinel = eqx.field(static=True)
constraint: PDFLike | Sentinel = eqx.field(static=True)

def __init__(
self,
value: ArrayLike,
lower: ArrayLike,
upper: ArrayLike,
constraint: PDF | Sentinel = _NoValue,
constraint: PDFLike | Sentinel = _NoValue,
) -> None:
self.value = as1darray(value)
self.lower = jnp.broadcast_to(as1darray(lower), self.value.shape)
Expand Down
16 changes: 8 additions & 8 deletions src/evermore/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ def __dir__():

class PDF(eqx.Module):
@abstractmethod
def logpdf(self, x: Array) -> Array: ...
def log_prob(self, x: Array) -> Array: ...

@abstractmethod
def pdf(self, x: Array) -> Array: ...
def prob(self, x: Array) -> Array: ...

@abstractmethod
def cdf(self, x: Array) -> Array: ...
Expand All @@ -34,10 +34,10 @@ def sample(self, key: PRNGKeyArray) -> Array: ...


class Flat(PDF):
def logpdf(self, x: Array) -> Array:
def log_prob(self, x: Array) -> Array:
return jnp.zeros_like(x)

def pdf(self, x: Array) -> Array:
def prob(self, x: Array) -> Array:
return jnp.ones_like(x)

def cdf(self, x: Array) -> Array:
Expand All @@ -56,14 +56,14 @@ class Normal(PDF):
mean: Array
width: Array

def logpdf(self, x: Array) -> Array:
def log_prob(self, x: Array) -> Array:
logpdf_max = jax.scipy.stats.norm.logpdf(
self.mean, loc=self.mean, scale=self.width
)
unnormalized = jax.scipy.stats.norm.logpdf(x, loc=self.mean, scale=self.width)
return unnormalized - logpdf_max

def pdf(self, x: Array) -> Array:
def prob(self, x: Array) -> Array:
return jax.scipy.stats.norm.pdf(x, loc=self.mean, scale=self.width)

def cdf(self, x: Array) -> Array:
Expand All @@ -77,12 +77,12 @@ def sample(self, key: PRNGKeyArray) -> Array:
class Poisson(PDF):
lamb: Array

def logpdf(self, x: Array) -> Array:
def log_prob(self, x: Array) -> Array:
logpdf_max = jax.scipy.stats.poisson.logpmf(self.lamb, mu=self.lamb)
unnormalized = jax.scipy.stats.poisson.logpmf((x + 1) * self.lamb, mu=self.lamb)
return unnormalized - logpdf_max

def pdf(self, x: Array) -> Array:
def prob(self, x: Array) -> Array:
return jax.scipy.stats.poisson.pmf((x + 1) * self.lamb, mu=self.lamb)

def cdf(self, x: Array) -> Array:
Expand Down
9 changes: 4 additions & 5 deletions src/evermore/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import jax
from jaxtyping import Array, PRNGKeyArray, PyTree

from evermore.custom_types import _NoValue
from evermore.pdf import PDF
from evermore.custom_types import PDFLike
from evermore.util import is_parameter


Expand All @@ -22,12 +21,12 @@ def toy_module(module: eqx.Module, key: PRNGKeyArray) -> PyTree[Callable]:
keys_tree = jax.tree_util.tree_unflatten(params_structure, keys)

def _sample(param: Parameter, key: Parameter) -> Array:
if param.constraint is _NoValue:
msg = f"Parameter {param} has no constraint pdf, can't sample from it."
if not isinstance(param.constraint, PDFLike):
msg = f"Parameter {param} has no sampling method, can't sample from it."
raise RuntimeError(msg)

pdf = param.constraint
pdf = cast(PDF, pdf)
pdf = cast(PDFLike, pdf)

# sample new value from the constraint pdf
sampled_param_value = pdf.sample(key.value)
Expand Down
20 changes: 10 additions & 10 deletions tests/test_pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,26 @@
def test_flat():
pdf = Flat()

assert pdf.pdf(jnp.array(1.0)) == jnp.array(1.0)
assert pdf.pdf(jnp.array(2.0)) == jnp.array(1.0)
assert pdf.pdf(jnp.array(3.0)) == jnp.array(1.0)
assert pdf.prob(jnp.array(1.0)) == jnp.array(1.0)
assert pdf.prob(jnp.array(2.0)) == jnp.array(1.0)
assert pdf.prob(jnp.array(3.0)) == jnp.array(1.0)

assert pdf.logpdf(jnp.array(1.0)) == jnp.array(0.0)
assert pdf.logpdf(jnp.array(2.0)) == jnp.array(0.0)
assert pdf.logpdf(jnp.array(3.0)) == jnp.array(0.0)
assert pdf.log_prob(jnp.array(1.0)) == jnp.array(0.0)
assert pdf.log_prob(jnp.array(2.0)) == jnp.array(0.0)
assert pdf.log_prob(jnp.array(3.0)) == jnp.array(0.0)


def test_Normal():
pdf = Normal(mean=jnp.array(0.0), width=jnp.array(1.0))

assert pdf.pdf(jnp.array(0.0)) == pytest.approx(1.0 / jnp.sqrt(2 * jnp.pi))
assert pdf.logpdf(jnp.array(0.0)) == pytest.approx(0.0)
assert pdf.prob(jnp.array(0.0)) == pytest.approx(1.0 / jnp.sqrt(2 * jnp.pi))
assert pdf.log_prob(jnp.array(0.0)) == pytest.approx(0.0)
assert pdf.cdf(jnp.array(0.0)) == pytest.approx(0.5)


def test_poisson():
pdf = Poisson(lamb=jnp.array(10))

assert pdf.pdf(jnp.array(0)) == pytest.approx(0.12510978)
assert pdf.logpdf(jnp.array(-0.5)) == pytest.approx(-1.196003)
assert pdf.prob(jnp.array(0)) == pytest.approx(0.12510978)
assert pdf.log_prob(jnp.array(-0.5)) == pytest.approx(-1.196003)
assert pdf.cdf(jnp.array(0)) == pytest.approx(0.5830412)

0 comments on commit 183a511

Please sign in to comment.