Skip to content

Commit

Permalink
Updated linting rules and ensured compliance
Browse files Browse the repository at this point in the history
  • Loading branch information
tttc3 committed Feb 24, 2024
1 parent fa5f493 commit 525480a
Show file tree
Hide file tree
Showing 25 changed files with 181 additions and 120 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,18 @@ pip install mccube
```
Requires Python 3.9+, Diffrax 0.5.0+, and Equinox 0.11.3+.

By default, a CPU only version of JAX will be installed. To make use of other JAX/XLA
By default, a CPU only version of JAX will be installed. To make use of other JAX/XLA
compatible accelerators (GPUs/TPUs) please follow [these installation instructions](https://github.com/google/jax#pip-installation-gpu-cuda-installed-via-pip-easier).
Windows support for JAX is currently experimental; WSL2 is the recommended approach for
Windows support for JAX is currently experimental; WSL2 is the recommended approach for
using JAX on Windows.

## Documentation
Available at [https://mccube.readthedocs.io/](https://mccube.readthedocs.io/).

## What is Markov chain cubature?
MCC is an approach to constructing a [Cubature on Wiener Space](https://www.jstor.org/stable/4143098)
which does not suffer from exponential scaling in time (particle count explosion),
thanks to the utilization of (partitioned) recombination in the (approximate) cubature
MCC is an approach to constructing a [Cubature on Wiener Space](https://www.jstor.org/stable/4143098)
which does not suffer from exponential scaling in time (particle count explosion),
thanks to the utilization of (partitioned) recombination in the (approximate) cubature
kernel.

### Example
Expand Down
26 changes: 13 additions & 13 deletions docs/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ $$
where the step size $h=t_{i+1}-t_{i}$ is constant, and each $\Delta W_{i}$ is an idependant sample from a (potentially multi-variate) Gaussian variable with mean zero and diagonal covariance $h$.

### ULA in Diffrax
It is very easy to implement the ULA in Diffrax, as demonstrated in the below example, which generates 512 independant Markov chains by simulating the SDE via the Euler-Maruyama method, performing the standard unadjusted Langevin algorithm for a single initial condition ($Y_0$ is a d-dimensional vector with all elements equal to one).
It is very easy to implement the ULA in Diffrax, as demonstrated in the below example, which generates 512 independant Markov chains by simulating the SDE via the Euler-Maruyama method, performing the standard unadjusted Langevin algorithm for a single initial condition ($Y_0$ is a d-dimensional vector with all elements equal to one).

It is important to note that while the computation of the 512 chains is performed in parallel, this is not a "Parallel MCMC" method as each path is independant (unlike in MCC).

Expand Down Expand Up @@ -85,7 +85,7 @@ solver = diffrax.Euler()
sol = diffrax.diffeqsolve(
terms,
solver,
t0,
t0,
t1,
dt0,
y0,
Expand Down Expand Up @@ -113,7 +113,7 @@ evaluate_method(particles, "Diffrax ULA")
### Adjusted Langevin Algorithm in Blackjax
ULA does not strictly obey the [detailed balance](https://en.wikipedia.org/wiki/Detailed_balance) properties required for a unique ergodic stationary distribution to exist, and as such, is unlikely to be used in practice.

A more realistic scenario would be the use of the [Blackjax](https://github.com/blackjax-devs/blackjax) package and one of its more advanced samplers. For example, the [Metropolis-Adjusted Langevin Algorithm (MALA)](https://en.wikipedia.org/wiki/Metropolis-adjusted_Langevin_algorithm), see demonstration below, which adjusts the ULA to ensure the detailed balance properties are enforced.
A more realistic scenario would be the use of the [Blackjax](https://github.com/blackjax-devs/blackjax) package and one of its more advanced samplers. For example, the [Metropolis-Adjusted Langevin Algorithm (MALA)](https://en.wikipedia.org/wiki/Metropolis-adjusted_Langevin_algorithm), see demonstration below, which adjusts the ULA to ensure the detailed balance properties are enforced.

```python
import blackjax
Expand Down Expand Up @@ -148,7 +148,7 @@ evaluate_method(particles, "Blackjax MALA")
```

## Markov Chain Cubature
Markov chain cubature allows us to take a fundamentally different approach to the problem of solving the Langevin SDE (equivalently obtaining samples from $f$).
Markov chain cubature allows us to take a fundamentally different approach to the problem of solving the Langevin SDE (equivalently obtaining samples from $f$).
Rather than atempting to obtain (potenially $n$) independant pathwise solutions to the SDE, with MCC, one attempts to find a set of $n$ time-evolving dependant particles which at any point in time attempt to weakly solve the SDE (that is solve the SDE in law/distribution).

The crucial difference here is that paths traced by these particles need not coincide with any pathwise solutions of the SDE. The only requirement is that the distribution of these particles be identical to the distribution of all the infinitely many pathwise solutions.
Expand All @@ -172,9 +172,9 @@ Now returning to the example in the [README](/#example) reproduced below:
```python
from mccube import (
GaussianRegion,
Hadamard,
Hadamard,
LocalLinearCubaturePath,
MCCSolver,
MCCSolver,
MCCTerm,
MonteCarloKernel,
PartitioningRecombinationKernel,
Expand All @@ -195,7 +195,7 @@ solver = MCCSolver(diffrax.Euler(), kernel)
sol = diffrax.diffeqsolve(
terms,
solver,
t0,
t0,
t1,
dt0,
y0,
Expand All @@ -221,7 +221,7 @@ solver = MCCSolver(diffrax.Euler(), kernel)
sol = diffrax.diffeqsolve(
terms,
solver,
t0,
t0,
t1,
dt0,
y0,
Expand All @@ -235,10 +235,10 @@ evaluate_method(particles, "MCCube ULA | Partitioned MC Kernel")
```

### Weighted MCC
The above examples treat all particles as having equal mass/weight. That is to say, one can consider the particles as representing a discrete measure
The above examples treat all particles as having equal mass/weight. That is to say, one can consider the particles as representing a discrete measure
$$\mu = \sum_{i=1}^n \lambda_i \delta_{x_i},$$
where each $\lambda_i$ is a probability weight/mass and each $x_{i}$ is a particle.
In the above examples, the guassian cubature assigns equal weight to each proposal particle (update path), and the recombination kernels are weight invariant. However, in some cases the gaussian cubature will assign unequall weights, and the recombination kernel will be weight dependant.
In the above examples, the guassian cubature assigns equal weight to each proposal particle (update path), and the recombination kernels are weight invariant. However, in some cases the gaussian cubature will assign unequall weights, and the recombination kernel will be weight dependant.

To utilise these weights in MCCube is relatively simple, requiring only a few minor modifications to the prior example.

Expand All @@ -253,7 +253,7 @@ solver = MCCSolver(diffrax.Euler(), kernel, weighted=True)
sol = diffrax.diffeqsolve(
terms,
solver,
t0,
t0,
t1,
dt0,
y0_weighted,
Expand Down Expand Up @@ -305,9 +305,9 @@ evaluate_method(state.particles, "SVGD")
Like with MCC, the performance in this case is worse than ULA/MALA, again highlighting the importance of selecting appropriate kernels and (for SVGD) optimizers. Interested readers are encouraged to play around with the above examples and to identify parameterisations which yield enhanced performance.

## Next Steps
Equiped with the above knowledge, it should be possible to start experimenting with MCCube.
Equiped with the above knowledge, it should be possible to start experimenting with MCCube.

API documentation can be found [here](api/), and please feel free to submit an issue if there are any tutorials or guides you would like to see added to the documentation.

!!! tip
To get the most out of this package it is helpful to be familiar with all the bells and whistles of [Diffrax](https://github.com/patrick-kidger/diffrax).
To get the most out of this package it is helpful to be familiar with all the bells and whistles of [Diffrax](https://github.com/patrick-kidger/diffrax).
24 changes: 12 additions & 12 deletions mccube/_custom_types.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
"""Defines custom types that are used throughout the package. The following symbols
"""Defines custom types that are used throughout the package. The following symbols
are used in the definitions of the custom types:
- **d**: the dimensionality of the particles.
- **n**: the number of particles.
- **n_hat**: the number of recombined particles.
- **k**: the number of versions of a particles (resulting from the same number of
- **k**: the number of versions of a particles (resulting from the same number of
cubature paths/points).
- **m**: the number of partitions of the particles.
"""

from typing import Any, TYPE_CHECKING

import numpy as np
import numpy.typing as npt
from jaxtyping import Array, ArrayLike, Bool, Float, Int, PyTree, Shaped

# These are identical to the definitions in diffrax.
if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: no cover
import numpy as np
import numpy.typing as npt

BoolScalarLike = bool | Array | npt.NDArray[np.bool_]
FloatScalarLike = float | Array | npt.NDArray[np.float_]
IntScalarLike = int | Array | npt.NDArray[np.int_]
Expand All @@ -33,16 +32,16 @@
"""A PyTree where each leaf is an array of `n` particles of dimension `d`."""

PartitionedParticles = PyTree[Shaped[Array, "?m ?n_div_m d"], "P"]
"""A [`Particles`][mccube._custom_types.Particles] PyTree where each leaf has been
"""A [`Particles`][mccube._custom_types.Particles] PyTree where each leaf has been
reorganised into `m` equally sized partitions of `n/m` particles of dimension `d`."""

RecombinedParticles = PyTree[Shaped[Array, "?n_hat d"], "P"]
"""A [`Particles`][mccube._custom_types.Particles] PyTree where each leaf has been
"""A [`Particles`][mccube._custom_types.Particles] PyTree where each leaf has been
recombined/compressed into `n_hat < n` particles of dimension `d`."""

UnpackedParticles = PyTree[Shaped[Array, "?n d-1"], "P"]
"""A [`Particles`][mccube._custom_types.Particles] PyTree of `n` particles of dimension
`d-1`, which have been unpacked from a PyTree of `n` particles of dimension `d`, where
"""A [`Particles`][mccube._custom_types.Particles] PyTree of `n` particles of dimension
`d-1`, which have been unpacked from a PyTree of `n` particles of dimension `d`, where
the `d`-th dimension represents the particle [`Weights`][mccube._custom_types.Weights]."""

PackedParticles = PyTree[Shaped[Array, "?n d+1"], "P"]
Expand All @@ -52,7 +51,8 @@
`n` weights."""

Weights = PyTree[Shaped[Array, "?*n"] | None, "P"]
"""A PyTree where each leaf is an array of `n` [`Weights`][mccube._custom_types.Weights] or [`None`][]."""
"""A PyTree where each leaf is an array of `n` [`Weights`][mccube._custom_types.Weights]
or [`None`][]."""

Args = PyTree[Any]
"""A PyTree of auxillary arguments."""
Expand Down
19 changes: 9 additions & 10 deletions mccube/_formulae.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import itertools
from collections.abc import Callable, Collection
from functools import cached_property
from typing import Generic, Literal, TypeVar
from typing import cast, Generic, Literal, TypeVar

import equinox as eqx
import jax
Expand All @@ -27,6 +27,7 @@
from ._regions import AbstractRegion, GaussianRegion
from ._utils import all_subclasses

ω = cast(Callable, ω)
_Region = TypeVar("_Region", bound=AbstractRegion)


Expand Down Expand Up @@ -77,25 +78,22 @@ def point_count(self):
@abc.abstractmethod
def weights(self) -> CubatureWeightsTree:
r"""A PyTree of Cubature weights $\lambda_j \in \mathbb{R}_+$ for the measure $\mu$."""
...

@cached_property
@abc.abstractmethod
def points(self) -> CubaturePointsTree:
r"""A PyTree of Cubature points $x_j \in \Omega$ for the measure $\mu$."""
...

@cached_property
@abc.abstractmethod
def point_count(self) -> IntScalarLike:
r"""Cubature point count $k$."""
...

@cached_property
def stacked_weights(self) -> CubatureWeights:
"""[`weights`][mccube.AbstractCubature.weights] stacked into a single vector."""
weight_array = jtu.tree_map(lambda x: np.ones(x.shape[0]), self.points)
return np.hstack((ω(self.weights) * ω(weight_array)).ω) # pyright: ignore
return np.hstack((ω(self.weights) * ω(weight_array)).ω)

@cached_property
def stacked_points(self) -> CubaturePoints:
Expand All @@ -111,7 +109,7 @@ def __call__(
Computes the cubature formula $Q[f] = \sum_{j=1}^{k} \lambda_j f(x_j)$.
Args:
integrand: the jax transformable function to integrate.
integrand: the JAX transformable function to integrate.
Returns:
Approximated integral and stacked weighted evaluations of $f$ at each
Expand Down Expand Up @@ -459,9 +457,8 @@ def __check_init__(self):
d = self.region.dimension
minimum_valid_dimension = 3
if d < minimum_valid_dimension:
raise ValueError(
f"StroudSecrest63_53 is only valid for regions with d > 2; got d={d}"
)
msg = f"StroudSecrest63_53 is only valid for regions with d > 2; got d={d}"
raise ValueError(msg)

@cached_property
def weights(self) -> CubatureWeightsTree:
Expand Down Expand Up @@ -540,7 +537,8 @@ def _generate_point_permutations(
ValueError: invalid mode specified.
"""
if mode not in _modes:
raise ValueError(f"Mode must be one of {_modes}, got {mode}.")
msg = f"Mode must be one of {_modes}, got {mode}."
raise ValueError(msg)

_point = np.atleast_1d(point)
point_dim = np.shape(_point)[-1]
Expand Down Expand Up @@ -577,6 +575,7 @@ def _generate_point_permutations(
def search_cubature_registry(
region: AbstractRegion,
degree: int | None = None,
*,
sparse_only: bool = False,
minimal_only: bool = False,
searchable_formulae: Collection[type[AbstractCubature]] = builtin_cubature_registry,
Expand Down
4 changes: 2 additions & 2 deletions mccube/_kernels/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
r"""This module provides the tools for constructing classes of Markov transition kernels
which obey certain desired properties, along with a library of useful kernels for Markov
r"""This module provides the tools for constructing classes of Markov transition kernels
which obey certain desired properties, along with a library of useful kernels for Markov
Chain Cubature.
"""

Expand Down
4 changes: 2 additions & 2 deletions mccube/_kernels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def __call__(
A PyTree of transformed particles $p(t)$ with the same PyTree structure and
dimension as the input particles, $p(t_0)$.
"""
...


class AbstractPartitioningKernel(AbstractKernel):
Expand All @@ -65,7 +64,7 @@ class AbstractPartitioningKernel(AbstractKernel):
partition_count: PyTree[int, "Particles"] | None

@override
def __call__(
def __call__( # pragma: no cover
self,
t: RealScalarLike,
particles: Particles,
Expand Down Expand Up @@ -159,6 +158,7 @@ def __call__(
_vmap_recombination_kernel = eqx.filter_vmap(
self.recombination_kernel, in_axes=(None, 0, None, None)
)
# TODO: This needs reconsideration.
_vmap_tree_compatible_recombination_kernels = jtu.tree_map(
lambda c: eqx.tree_at(
lambda k: k._fun.recombination_count, _vmap_recombination_kernel, c
Expand Down
7 changes: 5 additions & 2 deletions mccube/_kernels/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jax.random as jr
import jax.tree_util as jtu
from diffrax._misc import force_bitcast_convert_type, split_by_tree
from jaxtyping import PRNGKeyArray
from jaxtyping import Array, PRNGKeyArray, Shaped

from .._custom_types import (
Args,
Expand Down Expand Up @@ -53,11 +53,14 @@ def __call__(
args: Args,
weighted: bool = False,
) -> RecombinedParticles | PartitionedParticles:
del args
_t = force_bitcast_convert_type(t, jnp.int32)
key = jr.fold_in(self.key, _t)
keys = split_by_tree(key, particles)

def _choice(key, p, count):
def _choice(
key: PRNGKeyArray, p: Shaped[Array, "n d"], count: int | tuple[int, ...]
) -> Shaped[Array, "n_hat d"] | Shaped[Array, "m n_div_m d"]:
_, weights = unpack_particles(p, weighted)
if weighted:
weights = self.weighting_function(weights)
Expand Down
8 changes: 6 additions & 2 deletions mccube/_kernels/stratified.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import ArrayLike, Shaped
from jaxtyping import Array, ArrayLike, Shaped

from .._custom_types import Args, Particles, PartitionedParticles, RealScalarLike
from .._metrics import center_of_mass
Expand Down Expand Up @@ -33,7 +33,11 @@ def __call__(
args: Args,
weighted: bool = False,
) -> PartitionedParticles:
def _stratified_partitioning(_p, _count):
del t, args

def _stratified_partitioning(
_p: Shaped[Array, "n d"], _count: int
) -> Shaped[Array, "m n_div_m d"]:
if self.norm is None:
return _p.reshape(_count, -1, _p.shape[-1])
_p_unpacked, _w = unpack_particles(_p, weighted)
Expand Down
26 changes: 20 additions & 6 deletions mccube/_kernels/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,18 @@
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from jaxtyping import Shaped
from optimistix._solver.nelder_mead import ArrayLike
from sklearn.metrics import DistanceMetric
from sklearn.neighbors import BallTree, KDTree

from .._custom_types import Args, Particles, PartitionedParticles, RealScalarLike
from .._custom_types import (
Args,
IntScalarLike,
Particles,
PartitionedParticles,
RealScalarLike,
)
from .._utils import unpack_particles
from .base import AbstractPartitioningKernel

Expand Down Expand Up @@ -41,20 +49,26 @@ def __call__(
args: Args,
weighted: bool = False,
) -> PartitionedParticles:
def _tree_fn(p, leaf_size):
del t, args

def _tree_fn(
p: Shaped[ArrayLike, "n d"], leaf_size: IntScalarLike
) -> Shaped[ArrayLike, "n_div_m idx"]:
t = trees[self.mode](p, leaf_size, self.metric, **self.metric_kwargs)
indices = t.get_arrays()[1].reshape(-1, leaf_size)
return indices.astype(np.int32)

def _tree_partitioning(p, count):
leaf_size = p.shape[0] // count
def _tree_partitioning(
p: Shaped[ArrayLike, "n d"], count: IntScalarLike
) -> Shaped[ArrayLike, "m n_div_m d"]:
leaf_size = jnp.shape(p)[0] // count
shape = (count, leaf_size)
dtype = jnp.int32
result_shape_dtype = jax.ShapeDtypeStruct(shape, dtype)
p_unpacked, _ = unpack_particles(p, weighted)
indices = jax.pure_callback( # type: ignore
indices = jax.pure_callback( # pyright: ignore[reportPrivateImportUsage]
_tree_fn, result_shape_dtype, p_unpacked, leaf_size
)
return p[indices]
return jnp.asarray(p)[indices]

return jtu.tree_map(_tree_partitioning, particles, self.partition_count)
Loading

0 comments on commit 525480a

Please sign in to comment.