From 525480aa0b6535e8964ac03d42fc814905232518 Mon Sep 17 00:00:00 2001
From: tttc3 <97948946+tttc3@users.noreply.github.com>
Date: Sat, 24 Feb 2024 15:01:19 +0000
Subject: [PATCH] Updated linting rules and ensured compliance
---
README.md | 10 +++++-----
docs/quickstart.md | 26 ++++++++++++-------------
mccube/_custom_types.py | 24 +++++++++++------------
mccube/_formulae.py | 19 +++++++++---------
mccube/_kernels/__init__.py | 4 ++--
mccube/_kernels/base.py | 4 ++--
mccube/_kernels/random.py | 7 +++++--
mccube/_kernels/stratified.py | 8 ++++++--
mccube/_kernels/tree.py | 26 +++++++++++++++++++------
mccube/_metrics.py | 6 +++++-
mccube/_path.py | 2 +-
mccube/_regions.py | 8 ++++----
mccube/_solvers.py | 14 ++++++++------
mccube/_term.py | 11 ++++++++---
mccube/_utils.py | 6 ++++--
mkdocs.yaml | 3 +--
pyproject.toml | 36 ++++++++++++++++++-----------------
tests/conftest.py | 4 ++--
tests/test_formulae.py | 28 ++++++++++++++++++---------
tests/test_kernels.py | 13 +++++++------
tests/test_metrics.py | 1 -
tests/test_path.py | 15 ++++++++-------
tests/test_solvers.py | 14 +++++++++-----
tests/test_term.py | 4 ++++
tests/test_utils.py | 8 ++++++++
25 files changed, 181 insertions(+), 120 deletions(-)
diff --git a/README.md b/README.md
index 473bdd3..3ef0a16 100644
--- a/README.md
+++ b/README.md
@@ -25,18 +25,18 @@ pip install mccube
```
Requires Python 3.9+, Diffrax 0.5.0+, and Equinox 0.11.3+.
-By default, a CPU only version of JAX will be installed. To make use of other JAX/XLA
+By default, a CPU only version of JAX will be installed. To make use of other JAX/XLA
compatible accelerators (GPUs/TPUs) please follow [these installation instructions](https://github.com/google/jax#pip-installation-gpu-cuda-installed-via-pip-easier).
-Windows support for JAX is currently experimental; WSL2 is the recommended approach for
+Windows support for JAX is currently experimental; WSL2 is the recommended approach for
using JAX on Windows.
## Documentation
Available at [https://mccube.readthedocs.io/](https://mccube.readthedocs.io/).
## What is Markov chain cubature?
-MCC is an approach to constructing a [Cubature on Wiener Space](https://www.jstor.org/stable/4143098)
-which does not suffer from exponential scaling in time (particle count explosion),
-thanks to the utilization of (partitioned) recombination in the (approximate) cubature
+MCC is an approach to constructing a [Cubature on Wiener Space](https://www.jstor.org/stable/4143098)
+which does not suffer from exponential scaling in time (particle count explosion),
+thanks to the utilization of (partitioned) recombination in the (approximate) cubature
kernel.
### Example
diff --git a/docs/quickstart.md b/docs/quickstart.md
index 884254e..c9e81a9 100644
--- a/docs/quickstart.md
+++ b/docs/quickstart.md
@@ -39,7 +39,7 @@ $$
where the step size $h=t_{i+1}-t_{i}$ is constant, and each $\Delta W_{i}$ is an idependant sample from a (potentially multi-variate) Gaussian variable with mean zero and diagonal covariance $h$.
### ULA in Diffrax
-It is very easy to implement the ULA in Diffrax, as demonstrated in the below example, which generates 512 independant Markov chains by simulating the SDE via the Euler-Maruyama method, performing the standard unadjusted Langevin algorithm for a single initial condition ($Y_0$ is a d-dimensional vector with all elements equal to one).
+It is very easy to implement the ULA in Diffrax, as demonstrated in the below example, which generates 512 independant Markov chains by simulating the SDE via the Euler-Maruyama method, performing the standard unadjusted Langevin algorithm for a single initial condition ($Y_0$ is a d-dimensional vector with all elements equal to one).
It is important to note that while the computation of the 512 chains is performed in parallel, this is not a "Parallel MCMC" method as each path is independant (unlike in MCC).
@@ -85,7 +85,7 @@ solver = diffrax.Euler()
sol = diffrax.diffeqsolve(
terms,
solver,
- t0,
+ t0,
t1,
dt0,
y0,
@@ -113,7 +113,7 @@ evaluate_method(particles, "Diffrax ULA")
### Adjusted Langevin Algorithm in Blackjax
ULA does not strictly obey the [detailed balance](https://en.wikipedia.org/wiki/Detailed_balance) properties required for a unique ergodic stationary distribution to exist, and as such, is unlikely to be used in practice.
-A more realistic scenario would be the use of the [Blackjax](https://github.com/blackjax-devs/blackjax) package and one of its more advanced samplers. For example, the [Metropolis-Adjusted Langevin Algorithm (MALA)](https://en.wikipedia.org/wiki/Metropolis-adjusted_Langevin_algorithm), see demonstration below, which adjusts the ULA to ensure the detailed balance properties are enforced.
+A more realistic scenario would be the use of the [Blackjax](https://github.com/blackjax-devs/blackjax) package and one of its more advanced samplers. For example, the [Metropolis-Adjusted Langevin Algorithm (MALA)](https://en.wikipedia.org/wiki/Metropolis-adjusted_Langevin_algorithm), see demonstration below, which adjusts the ULA to ensure the detailed balance properties are enforced.
```python
import blackjax
@@ -148,7 +148,7 @@ evaluate_method(particles, "Blackjax MALA")
```
## Markov Chain Cubature
-Markov chain cubature allows us to take a fundamentally different approach to the problem of solving the Langevin SDE (equivalently obtaining samples from $f$).
+Markov chain cubature allows us to take a fundamentally different approach to the problem of solving the Langevin SDE (equivalently obtaining samples from $f$).
Rather than atempting to obtain (potenially $n$) independant pathwise solutions to the SDE, with MCC, one attempts to find a set of $n$ time-evolving dependant particles which at any point in time attempt to weakly solve the SDE (that is solve the SDE in law/distribution).
The crucial difference here is that paths traced by these particles need not coincide with any pathwise solutions of the SDE. The only requirement is that the distribution of these particles be identical to the distribution of all the infinitely many pathwise solutions.
@@ -172,9 +172,9 @@ Now returning to the example in the [README](/#example) reproduced below:
```python
from mccube import (
GaussianRegion,
- Hadamard,
+ Hadamard,
LocalLinearCubaturePath,
- MCCSolver,
+ MCCSolver,
MCCTerm,
MonteCarloKernel,
PartitioningRecombinationKernel,
@@ -195,7 +195,7 @@ solver = MCCSolver(diffrax.Euler(), kernel)
sol = diffrax.diffeqsolve(
terms,
solver,
- t0,
+ t0,
t1,
dt0,
y0,
@@ -221,7 +221,7 @@ solver = MCCSolver(diffrax.Euler(), kernel)
sol = diffrax.diffeqsolve(
terms,
solver,
- t0,
+ t0,
t1,
dt0,
y0,
@@ -235,10 +235,10 @@ evaluate_method(particles, "MCCube ULA | Partitioned MC Kernel")
```
### Weighted MCC
-The above examples treat all particles as having equal mass/weight. That is to say, one can consider the particles as representing a discrete measure
+The above examples treat all particles as having equal mass/weight. That is to say, one can consider the particles as representing a discrete measure
$$\mu = \sum_{i=1}^n \lambda_i \delta_{x_i},$$
where each $\lambda_i$ is a probability weight/mass and each $x_{i}$ is a particle.
-In the above examples, the guassian cubature assigns equal weight to each proposal particle (update path), and the recombination kernels are weight invariant. However, in some cases the gaussian cubature will assign unequall weights, and the recombination kernel will be weight dependant.
+In the above examples, the guassian cubature assigns equal weight to each proposal particle (update path), and the recombination kernels are weight invariant. However, in some cases the gaussian cubature will assign unequall weights, and the recombination kernel will be weight dependant.
To utilise these weights in MCCube is relatively simple, requiring only a few minor modifications to the prior example.
@@ -253,7 +253,7 @@ solver = MCCSolver(diffrax.Euler(), kernel, weighted=True)
sol = diffrax.diffeqsolve(
terms,
solver,
- t0,
+ t0,
t1,
dt0,
y0_weighted,
@@ -305,9 +305,9 @@ evaluate_method(state.particles, "SVGD")
Like with MCC, the performance in this case is worse than ULA/MALA, again highlighting the importance of selecting appropriate kernels and (for SVGD) optimizers. Interested readers are encouraged to play around with the above examples and to identify parameterisations which yield enhanced performance.
## Next Steps
-Equiped with the above knowledge, it should be possible to start experimenting with MCCube.
+Equiped with the above knowledge, it should be possible to start experimenting with MCCube.
API documentation can be found [here](api/), and please feel free to submit an issue if there are any tutorials or guides you would like to see added to the documentation.
!!! tip
- To get the most out of this package it is helpful to be familiar with all the bells and whistles of [Diffrax](https://github.com/patrick-kidger/diffrax).
+ To get the most out of this package it is helpful to be familiar with all the bells and whistles of [Diffrax](https://github.com/patrick-kidger/diffrax).
diff --git a/mccube/_custom_types.py b/mccube/_custom_types.py
index 436ae84..0d4f7cd 100644
--- a/mccube/_custom_types.py
+++ b/mccube/_custom_types.py
@@ -1,22 +1,21 @@
-"""Defines custom types that are used throughout the package. The following symbols
+"""Defines custom types that are used throughout the package. The following symbols
are used in the definitions of the custom types:
- **d**: the dimensionality of the particles.
- **n**: the number of particles.
- **n_hat**: the number of recombined particles.
-- **k**: the number of versions of a particles (resulting from the same number of
+- **k**: the number of versions of a particles (resulting from the same number of
cubature paths/points).
- **m**: the number of partitions of the particles.
"""
-
from typing import Any, TYPE_CHECKING
-import numpy as np
-import numpy.typing as npt
from jaxtyping import Array, ArrayLike, Bool, Float, Int, PyTree, Shaped
-# These are identical to the definitions in diffrax.
-if TYPE_CHECKING:
+if TYPE_CHECKING: # pragma: no cover
+ import numpy as np
+ import numpy.typing as npt
+
BoolScalarLike = bool | Array | npt.NDArray[np.bool_]
FloatScalarLike = float | Array | npt.NDArray[np.float_]
IntScalarLike = int | Array | npt.NDArray[np.int_]
@@ -33,16 +32,16 @@
"""A PyTree where each leaf is an array of `n` particles of dimension `d`."""
PartitionedParticles = PyTree[Shaped[Array, "?m ?n_div_m d"], "P"]
-"""A [`Particles`][mccube._custom_types.Particles] PyTree where each leaf has been
+"""A [`Particles`][mccube._custom_types.Particles] PyTree where each leaf has been
reorganised into `m` equally sized partitions of `n/m` particles of dimension `d`."""
RecombinedParticles = PyTree[Shaped[Array, "?n_hat d"], "P"]
-"""A [`Particles`][mccube._custom_types.Particles] PyTree where each leaf has been
+"""A [`Particles`][mccube._custom_types.Particles] PyTree where each leaf has been
recombined/compressed into `n_hat < n` particles of dimension `d`."""
UnpackedParticles = PyTree[Shaped[Array, "?n d-1"], "P"]
-"""A [`Particles`][mccube._custom_types.Particles] PyTree of `n` particles of dimension
-`d-1`, which have been unpacked from a PyTree of `n` particles of dimension `d`, where
+"""A [`Particles`][mccube._custom_types.Particles] PyTree of `n` particles of dimension
+`d-1`, which have been unpacked from a PyTree of `n` particles of dimension `d`, where
the `d`-th dimension represents the particle [`Weights`][mccube._custom_types.Weights]."""
PackedParticles = PyTree[Shaped[Array, "?n d+1"], "P"]
@@ -52,7 +51,8 @@
`n` weights."""
Weights = PyTree[Shaped[Array, "?*n"] | None, "P"]
-"""A PyTree where each leaf is an array of `n` [`Weights`][mccube._custom_types.Weights] or [`None`][]."""
+"""A PyTree where each leaf is an array of `n` [`Weights`][mccube._custom_types.Weights]
+or [`None`][]."""
Args = PyTree[Any]
"""A PyTree of auxillary arguments."""
diff --git a/mccube/_formulae.py b/mccube/_formulae.py
index 542ef51..7571d92 100644
--- a/mccube/_formulae.py
+++ b/mccube/_formulae.py
@@ -5,7 +5,7 @@
import itertools
from collections.abc import Callable, Collection
from functools import cached_property
-from typing import Generic, Literal, TypeVar
+from typing import cast, Generic, Literal, TypeVar
import equinox as eqx
import jax
@@ -27,6 +27,7 @@
from ._regions import AbstractRegion, GaussianRegion
from ._utils import all_subclasses
+ω = cast(Callable, ω)
_Region = TypeVar("_Region", bound=AbstractRegion)
@@ -77,25 +78,22 @@ def point_count(self):
@abc.abstractmethod
def weights(self) -> CubatureWeightsTree:
r"""A PyTree of Cubature weights $\lambda_j \in \mathbb{R}_+$ for the measure $\mu$."""
- ...
@cached_property
@abc.abstractmethod
def points(self) -> CubaturePointsTree:
r"""A PyTree of Cubature points $x_j \in \Omega$ for the measure $\mu$."""
- ...
@cached_property
@abc.abstractmethod
def point_count(self) -> IntScalarLike:
r"""Cubature point count $k$."""
- ...
@cached_property
def stacked_weights(self) -> CubatureWeights:
"""[`weights`][mccube.AbstractCubature.weights] stacked into a single vector."""
weight_array = jtu.tree_map(lambda x: np.ones(x.shape[0]), self.points)
- return np.hstack((ω(self.weights) * ω(weight_array)).ω) # pyright: ignore
+ return np.hstack((ω(self.weights) * ω(weight_array)).ω)
@cached_property
def stacked_points(self) -> CubaturePoints:
@@ -111,7 +109,7 @@ def __call__(
Computes the cubature formula $Q[f] = \sum_{j=1}^{k} \lambda_j f(x_j)$.
Args:
- integrand: the jax transformable function to integrate.
+ integrand: the JAX transformable function to integrate.
Returns:
Approximated integral and stacked weighted evaluations of $f$ at each
@@ -459,9 +457,8 @@ def __check_init__(self):
d = self.region.dimension
minimum_valid_dimension = 3
if d < minimum_valid_dimension:
- raise ValueError(
- f"StroudSecrest63_53 is only valid for regions with d > 2; got d={d}"
- )
+ msg = f"StroudSecrest63_53 is only valid for regions with d > 2; got d={d}"
+ raise ValueError(msg)
@cached_property
def weights(self) -> CubatureWeightsTree:
@@ -540,7 +537,8 @@ def _generate_point_permutations(
ValueError: invalid mode specified.
"""
if mode not in _modes:
- raise ValueError(f"Mode must be one of {_modes}, got {mode}.")
+ msg = f"Mode must be one of {_modes}, got {mode}."
+ raise ValueError(msg)
_point = np.atleast_1d(point)
point_dim = np.shape(_point)[-1]
@@ -577,6 +575,7 @@ def _generate_point_permutations(
def search_cubature_registry(
region: AbstractRegion,
degree: int | None = None,
+ *,
sparse_only: bool = False,
minimal_only: bool = False,
searchable_formulae: Collection[type[AbstractCubature]] = builtin_cubature_registry,
diff --git a/mccube/_kernels/__init__.py b/mccube/_kernels/__init__.py
index 90c0ddd..0b6bcff 100644
--- a/mccube/_kernels/__init__.py
+++ b/mccube/_kernels/__init__.py
@@ -1,5 +1,5 @@
-r"""This module provides the tools for constructing classes of Markov transition kernels
-which obey certain desired properties, along with a library of useful kernels for Markov
+r"""This module provides the tools for constructing classes of Markov transition kernels
+which obey certain desired properties, along with a library of useful kernels for Markov
Chain Cubature.
"""
diff --git a/mccube/_kernels/base.py b/mccube/_kernels/base.py
index 7845710..c00d4c6 100644
--- a/mccube/_kernels/base.py
+++ b/mccube/_kernels/base.py
@@ -48,7 +48,6 @@ def __call__(
A PyTree of transformed particles $p(t)$ with the same PyTree structure and
dimension as the input particles, $p(t_0)$.
"""
- ...
class AbstractPartitioningKernel(AbstractKernel):
@@ -65,7 +64,7 @@ class AbstractPartitioningKernel(AbstractKernel):
partition_count: PyTree[int, "Particles"] | None
@override
- def __call__(
+ def __call__( # pragma: no cover
self,
t: RealScalarLike,
particles: Particles,
@@ -159,6 +158,7 @@ def __call__(
_vmap_recombination_kernel = eqx.filter_vmap(
self.recombination_kernel, in_axes=(None, 0, None, None)
)
+ # TODO: This needs reconsideration.
_vmap_tree_compatible_recombination_kernels = jtu.tree_map(
lambda c: eqx.tree_at(
lambda k: k._fun.recombination_count, _vmap_recombination_kernel, c
diff --git a/mccube/_kernels/random.py b/mccube/_kernels/random.py
index 52cb6c4..d37904c 100644
--- a/mccube/_kernels/random.py
+++ b/mccube/_kernels/random.py
@@ -5,7 +5,7 @@
import jax.random as jr
import jax.tree_util as jtu
from diffrax._misc import force_bitcast_convert_type, split_by_tree
-from jaxtyping import PRNGKeyArray
+from jaxtyping import Array, PRNGKeyArray, Shaped
from .._custom_types import (
Args,
@@ -53,11 +53,14 @@ def __call__(
args: Args,
weighted: bool = False,
) -> RecombinedParticles | PartitionedParticles:
+ del args
_t = force_bitcast_convert_type(t, jnp.int32)
key = jr.fold_in(self.key, _t)
keys = split_by_tree(key, particles)
- def _choice(key, p, count):
+ def _choice(
+ key: PRNGKeyArray, p: Shaped[Array, "n d"], count: int | tuple[int, ...]
+ ) -> Shaped[Array, "n_hat d"] | Shaped[Array, "m n_div_m d"]:
_, weights = unpack_particles(p, weighted)
if weighted:
weights = self.weighting_function(weights)
diff --git a/mccube/_kernels/stratified.py b/mccube/_kernels/stratified.py
index 72c9eec..ec0fbb5 100644
--- a/mccube/_kernels/stratified.py
+++ b/mccube/_kernels/stratified.py
@@ -2,7 +2,7 @@
import jax.numpy as jnp
import jax.tree_util as jtu
-from jaxtyping import ArrayLike, Shaped
+from jaxtyping import Array, ArrayLike, Shaped
from .._custom_types import Args, Particles, PartitionedParticles, RealScalarLike
from .._metrics import center_of_mass
@@ -33,7 +33,11 @@ def __call__(
args: Args,
weighted: bool = False,
) -> PartitionedParticles:
- def _stratified_partitioning(_p, _count):
+ del t, args
+
+ def _stratified_partitioning(
+ _p: Shaped[Array, "n d"], _count: int
+ ) -> Shaped[Array, "m n_div_m d"]:
if self.norm is None:
return _p.reshape(_count, -1, _p.shape[-1])
_p_unpacked, _w = unpack_particles(_p, weighted)
diff --git a/mccube/_kernels/tree.py b/mccube/_kernels/tree.py
index 418712a..7a33aec 100644
--- a/mccube/_kernels/tree.py
+++ b/mccube/_kernels/tree.py
@@ -5,10 +5,18 @@
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
+from jaxtyping import Shaped
+from optimistix._solver.nelder_mead import ArrayLike
from sklearn.metrics import DistanceMetric
from sklearn.neighbors import BallTree, KDTree
-from .._custom_types import Args, Particles, PartitionedParticles, RealScalarLike
+from .._custom_types import (
+ Args,
+ IntScalarLike,
+ Particles,
+ PartitionedParticles,
+ RealScalarLike,
+)
from .._utils import unpack_particles
from .base import AbstractPartitioningKernel
@@ -41,20 +49,26 @@ def __call__(
args: Args,
weighted: bool = False,
) -> PartitionedParticles:
- def _tree_fn(p, leaf_size):
+ del t, args
+
+ def _tree_fn(
+ p: Shaped[ArrayLike, "n d"], leaf_size: IntScalarLike
+ ) -> Shaped[ArrayLike, "n_div_m idx"]:
t = trees[self.mode](p, leaf_size, self.metric, **self.metric_kwargs)
indices = t.get_arrays()[1].reshape(-1, leaf_size)
return indices.astype(np.int32)
- def _tree_partitioning(p, count):
- leaf_size = p.shape[0] // count
+ def _tree_partitioning(
+ p: Shaped[ArrayLike, "n d"], count: IntScalarLike
+ ) -> Shaped[ArrayLike, "m n_div_m d"]:
+ leaf_size = jnp.shape(p)[0] // count
shape = (count, leaf_size)
dtype = jnp.int32
result_shape_dtype = jax.ShapeDtypeStruct(shape, dtype)
p_unpacked, _ = unpack_particles(p, weighted)
- indices = jax.pure_callback( # type: ignore
+ indices = jax.pure_callback( # pyright: ignore[reportPrivateImportUsage]
_tree_fn, result_shape_dtype, p_unpacked, leaf_size
)
- return p[indices]
+ return jnp.asarray(p)[indices]
return jtu.tree_map(_tree_partitioning, particles, self.partition_count)
diff --git a/mccube/_metrics.py b/mccube/_metrics.py
index 87febbc..d56c32e 100644
--- a/mccube/_metrics.py
+++ b/mccube/_metrics.py
@@ -1,4 +1,6 @@
"""Defines helpful metrics and dissimilarity measures."""
+from collections.abc import Callable
+
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
@@ -61,7 +63,9 @@ def squared_euclidean_metric(
def pairwise_metric(
- xs: Particles, ys: Particles, metric=euclidean_metric
+ xs: Particles,
+ ys: Particles,
+ metric: Callable[[ArrayLike, ArrayLike], RealScalarLike] = euclidean_metric,
) -> PyTree[Shaped[ArrayLike, "?n ?n"], "Particles"]:
"""Pairwise metric between two PyTrees of `n` vectors of dimension `d`.
diff --git a/mccube/_path.py b/mccube/_path.py
index d781013..19df923 100644
--- a/mccube/_path.py
+++ b/mccube/_path.py
@@ -71,6 +71,6 @@ def evaluate(
return points
@property
- def weights(self) -> CubatureWeights: # type: ignore
+ def weights(self) -> CubatureWeights:
"""Vector of cubature weights associated with the cubature control paths."""
return self.gaussian_cubature.stacked_weights
diff --git a/mccube/_regions.py b/mccube/_regions.py
index e547f47..2a33f54 100644
--- a/mccube/_regions.py
+++ b/mccube/_regions.py
@@ -1,5 +1,5 @@
-"""Defines the integration regions (measure spaces) against which [`AbstractCubatures`][mccube.AbstractCubature]
-can be defined."""
+"""Defines the integration regions (measure spaces) against which
+[`AbstractCubatures`][mccube.AbstractCubature] can be defined."""
import abc
@@ -19,10 +19,10 @@ class AbstractRegion(eqx.Module):
dimension: int
- @abc.abstractproperty
+ @property
+ @abc.abstractmethod
def volume(self) -> float:
r"""Measure $\mu$ of the entirety of $\Omega$, denoted by $V$."""
- ...
class GaussianRegion(AbstractRegion):
diff --git a/mccube/_solvers.py b/mccube/_solvers.py
index 1a8d39f..fd41acf 100644
--- a/mccube/_solvers.py
+++ b/mccube/_solvers.py
@@ -83,7 +83,7 @@ class MCCSolver(AbstractWrappedSolver[_SolverState]):
the number of cubature vectors/paths.**
"""
- solver: AbstractSolver[_SolverState] # type: ignore
+ solver: AbstractSolver[_SolverState]
recombination_kernel: AbstractRecombinationKernel
n_substeps: int = 1
weighted: bool = False
@@ -100,7 +100,8 @@ def __check_init__(self):
stacklevel=1,
)
if self.n_substeps < 1:
- raise ValueError(f"n_substeps must be at least one; got {self.n_substeps}")
+ msg = f"n_substeps must be at least one; got {self.n_substeps}"
+ raise ValueError(msg)
def init(
self,
@@ -109,7 +110,7 @@ def init(
t1: RealScalarLike,
y0: Particles,
args: Args,
- ):
+ ) -> _SolverState:
return self.solver.init(terms, t0, t1, y0, args)
def step(
@@ -128,7 +129,8 @@ def step(
_y0 = _y0[:, None, :]
_t0 = t0
_t1 = t0 + dt_substep
-
+ _sol = tuple()
+ # Shape of `y` changes on each itteration, so can't use JAX loop primitives.
for _ in range(self.n_substeps):
_t0 = _t1
_t1 = _t0 + dt_substep
@@ -139,13 +141,13 @@ def step(
cde_cubature_weights = terms.term.cde.control.weights[..., None]
weights = weights[None, ...] * cde_cubature_weights
weights = weights.reshape(-1)
- y1 = _sol[0] # type: ignore
+ y1 = _sol[0]
y1_packed = pack_particles(y1, weights)
y1_hat = self.recombination_kernel(t0, y1_packed, args, self.weighted)
# Used to renormalize the weights post recombination.
y1_res = pack_particles(*unpack_particles(y1_hat, weighted=self.weighted))
dense_info = {"y0": y0, "y1": y1_hat}
- return (y1_res, _sol[1], dense_info, *_sol[3:]) # type: ignore
+ return (y1_res, _sol[1], dense_info) + _sol[3:]
def func(
self, terms: PyTree[AbstractTerm], t0: RealScalarLike, y0: Particles, args: Args
diff --git a/mccube/_term.py b/mccube/_term.py
index 4629c5f..36985d3 100644
--- a/mccube/_term.py
+++ b/mccube/_term.py
@@ -2,11 +2,14 @@
See [`diffrax.AbstractTerm`][] for further information on the terms API.
"""
+from collections.abc import Callable
+from typing import cast
+
import equinox as eqx
import jax.tree_util as jtu
from diffrax import AbstractTerm, ODETerm, WeaklyDiagonalControlTerm
from equinox.internal import ω
-from jaxtyping import ArrayLike, PyTree
+from jaxtyping import Array, ArrayLike, PyTree
from ._custom_types import (
Args,
@@ -15,8 +18,10 @@
RealScalarLike,
)
+ω = cast(Callable, ω)
+
-def _tree_flatten(tree: PyTree[ArrayLike]):
+def _tree_flatten(tree: PyTree[Array]) -> PyTree[Array]:
return jtu.tree_map(
lambda x: x.reshape(-1, x.shape[-1]), tree, is_leaf=eqx.is_array
)
@@ -74,4 +79,4 @@ def prod(
) -> PyTree[ArrayLike, "Particles"]:
ode_prod = self.ode.prod(vf[0], control[0])
cde_prod = self.cde.prod(vf[1], control[1])
- return (ω(ode_prod)[:, None, ...] + ω(cde_prod)[None, ...]).ω # type: ignore
+ return (ω(ode_prod)[:, None, ...] + ω(cde_prod)[None, ...]).ω
diff --git a/mccube/_utils.py b/mccube/_utils.py
index 81bef9a..3fba0d5 100644
--- a/mccube/_utils.py
+++ b/mccube/_utils.py
@@ -5,7 +5,7 @@
from ._custom_types import PackedParticles, Particles, UnpackedParticles, Weights
-def nop(*args, **kwargs) -> None:
+def nop(*args, **kwargs) -> None: # noqa: ANN002, ANN003
"""Callable which accepts any arguments and does nothing.
Example:
@@ -14,6 +14,7 @@ def nop(*args, **kwargs) -> None:
# None
```
"""
+ del args, kwargs
def pack_particles(
@@ -102,4 +103,5 @@ def all_subclasses(cls: type) -> set[type]:
def requires_weighing(is_weighted: bool) -> None:
if not is_weighted:
- raise ValueError("Kernel requires `weighted=True`; got {`weighted=False`}.")
+ msg = "Kernel requires `weighted=True`; got {`weighted=False`}."
+ raise ValueError(msg)
diff --git a/mkdocs.yaml b/mkdocs.yaml
index 7c10841..204cdf4 100644
--- a/mkdocs.yaml
+++ b/mkdocs.yaml
@@ -44,7 +44,7 @@ extra:
- icon: fontawesome/solid/person-running
link: https://datasig.ac.uk
-extra_javascript:
+extra_javascript:
- _static/mathjax.js
- https://polyfill.io/v3/polyfill.min.js?features=es6
- https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
@@ -122,4 +122,3 @@ nav:
-
mccube._utils: "api/_utils.md"
-
mccube._metrics: "api/_metrics.md"
- Contributing: "CONTRIBUTING.md"
-
diff --git a/pyproject.toml b/pyproject.toml
index cd4fe62..1f2f643 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,8 +3,8 @@ name = "mccube"
version = "0.0.3"
description = "Markov chain cubature via JAX."
readme = "README.md"
-license = {file = "LICENSE"}
-authors = [{name = "The MCCube team", email = "T.Coxon2@lboro.ac.uk"}]
+license = { file = "LICENSE" }
+authors = [{ name = "The MCCube team", email = "T.Coxon2@lboro.ac.uk" }]
keywords = [
"sampling",
"probability",
@@ -24,14 +24,14 @@ classifiers = [
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Mathematics",
]
-requires-python = ">=3.9"
+requires-python = ">=3.10"
dependencies = [
"diffrax>=0.5.0",
"equinox>=0.11.2",
"jax>=0.4.23",
"jaxtyping>=0.2.25",
"scikit-learn>=1.3.2",
- "typing_extensions>=4.9.0"
+ "typing_extensions>=4.9.0",
]
[project.urls]
@@ -40,16 +40,9 @@ documentation = "https://mccube.readthedocs.io"
repository = "https://github.com/tttc3/mccube"
[project.optional-dependencies]
-test = [
- "beartype",
- "pytest",
- "jaxlib"
-]
+test = ["beartype", "pytest", "pytest-cov", "jaxlib"]
-dev = [
- "mccube[test]",
- "pre-commit",
-]
+dev = ["mccube[test]", "pre-commit"]
[build-system]
requires = ["hatchling"]
@@ -59,7 +52,10 @@ build-backend = "hatchling.build"
include = ["mccube/*"]
[tool.pytest.ini_options]
-addopts = "--jaxtyping-packages=mccube,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On))"
+addopts = """
+ --jaxtyping-packages=mccube,beartype.beartype(conf=beartype.BeartypeConf(strategy=beartype.BeartypeStrategy.On)) \
+ --cov=mccube --cov-report term-missing
+ """
[tool.jupytext]
formats = "ipynb, md"
@@ -69,8 +65,11 @@ extend-include = ["*.ipynb"]
[tool.ruff.lint]
fixable = ["I001", "F401"]
-ignore = ["F722"]
-select = ["I001", "PT"]
+ignore = ["ANN101", "ANN204", "ARG005", "F722", "E501"]
+select = ["ANN", "ARG", "B", "E", "EM", "F", "S", "I001", "PT", "W"]
+
+[tool.ruff.lint.per-file-ignores]
+"tests/*" = ["ANN", "S101"]
[tool.ruff.lint.isort]
combine-as-imports = true
@@ -78,5 +77,8 @@ extra-standard-library = ["typing_extensions"]
order-by-type = false
[tool.pyright]
-typeCheckingMode = "standard"
+reportIncompatibleVariableOverride = false # Incompatible with eqx.AbstractVar
+reportUnnecessaryTypeIgnoreComment = true
include = ["mccube", "tests"]
+venvPath = "/Users/tc/miniforge3/envs/"
+venv = "/Users/tc/miniforge3/envs/mccube"
diff --git a/tests/conftest.py b/tests/conftest.py
index 14a4e80..ac63b78 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3 +1,3 @@
-import jax.config
+import jax
-jax.config.update("jax_enable_x64", True) # type: ignore
+jax.config.update("jax_enable_x64", True)
diff --git a/tests/test_formulae.py b/tests/test_formulae.py
index 3efbfe0..dc326e8 100644
--- a/tests/test_formulae.py
+++ b/tests/test_formulae.py
@@ -1,5 +1,6 @@
import itertools
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
+from typing import cast
import equinox as eqx
import jax
@@ -16,6 +17,8 @@
builtin_cubature_registry,
)
+ω = cast(Callable, ω)
+
def _base_formulae_tests(f):
# Check points, weights, and point_count are consistent.
@@ -26,7 +29,7 @@ def _base_formulae_tests(f):
assert f.stacked_points.shape[0] == f.stacked_weights.shape[0]
stacked_points = np.vstack(f.points)
weight_array = jtu.tree_map(lambda x: np.ones(x.shape[0]), f.points)
- stacked_weights = np.hstack((ω(f.weights) * ω(weight_array)).ω) # pyright: ignore
+ stacked_weights = np.hstack((ω(f.weights) * ω(weight_array)).ω)
assert eqx.tree_equal(f.stacked_points, stacked_points)
assert eqx.tree_equal(f.stacked_weights, stacked_weights)
@@ -81,14 +84,20 @@ def test_search_cubature_registry(degree, sparse_only, minimal_only, expected_re
}
region = mccube.GaussianRegion(5)
result = mccube.search_cubature_registry(
- region, degree, sparse_only, minimal_only, test_registry
+ region,
+ degree,
+ sparse_only=sparse_only,
+ minimal_only=minimal_only,
+ searchable_formulae=test_registry,
)
_expected_result = jtu.tree_map(lambda f: f(region), expected_result)
assert eqx.tree_equal(result, _expected_result)
def test_search_cubature_registry_default():
- mccube.search_cubature_registry(mccube.GaussianRegion(5), 3, False, True)
+ mccube.search_cubature_registry(
+ mccube.GaussianRegion(5), 3, sparse_only=False, minimal_only=True
+ )
def test_points_permutations():
@@ -153,9 +162,10 @@ def test_gaussian_cubature(formula, degree, test_region_dims):
for monomial in _monomial_generator(f.region.dimension, degree):
coeffs = np.zeros((degree + 1, dim))
coeffs[monomial, np.arange(0, dim)] = 1
- integrand = lambda x: np.prod( # noqa: E731
- polyval(x, coeffs, tensor=False)
- )
+
+ def integrand(x, coeffs=coeffs):
+ return np.prod(polyval(x, coeffs, tensor=False))
+
test_integral = _generalized_hermite_monomial_integral(monomial)
trial_integral = f(integrand)[0]
assert eqx.tree_equal(test_integral, trial_integral, rtol=1e-5, atol=1e-8)
@@ -180,12 +190,12 @@ def _monomial_generator(
def _generalized_hermite_monomial_integral(
monomial: tuple[int, ...], alpha: float = 1 / 2, normalized: bool = True
):
- if any((monomial**ω % 2).ω):
+ if any((ω(monomial) % 2).ω):
return jax.numpy.array(0.0)
normalization_constant = 1.0
if normalized:
dim = len(monomial)
normalization_constant = (alpha ** (-1 / 2) * gamma(1 / 2)) ** dim
- exponent = ((monomial**ω + 1) / 2).ω
+ exponent = ((ω(monomial) + 1) / 2).ω
integral = np.prod(jtu.tree_map(lambda m: alpha**-m * gamma(m), exponent))
return integral / normalization_constant
diff --git a/tests/test_kernels.py b/tests/test_kernels.py
index b0da4bd..6007b1f 100644
--- a/tests/test_kernels.py
+++ b/tests/test_kernels.py
@@ -12,7 +12,6 @@
RealScalarLike,
RecombinedParticles,
)
-from mccube._kernels.random import MonteCarloKernel
# _kernels/base.py
@@ -25,6 +24,7 @@ def __call__(
args: Args,
weighted: bool = False,
) -> PartitionedParticles:
+ del t, args, weighted
return jtu.tree_map(
lambda p, c: p.reshape(-1, c, p.shape[-1]),
particles,
@@ -39,6 +39,7 @@ def __call__(
args: Args,
weighted: bool = False,
) -> RecombinedParticles:
+ del t, args, weighted
return jtu.tree_map(lambda p, c: p[:c], particles, self.recombination_count)
y0 = jnp.array([[2.0, 4.0, 6.0, 8.0]]).T
@@ -72,7 +73,7 @@ def test_monte_carlo_partitioning_kernel():
y0 = jnp.array([[1.0, 0.01], [2.0, 1.0], [3.0, 100.0], [4.0, 10000.0]])
key = jr.key(42)
- mc_kernel = MonteCarloKernel(None, key=key)
+ mc_kernel = mccube.MonteCarloKernel(None, key=key)
kernel = mccube.MonteCarloPartitioningKernel(n_parts, mc_kernel)
values = kernel(0.0, y0, ...)
assert values.shape == (n_parts, y0.shape[0] // n_parts, y0.shape[-1])
@@ -81,7 +82,7 @@ def test_monte_carlo_partitioning_kernel():
)
key = jr.key(42)
- mc_kernel = MonteCarloKernel(None, weighting_function=lambda x: x, key=key)
+ mc_kernel = mccube.MonteCarloKernel(None, weighting_function=lambda x: x, key=key)
kernel = mccube.MonteCarloPartitioningKernel(n_parts, mc_kernel)
values = kernel(0.0, y0, ..., weighted=True)
assert values.shape == (n_parts, y0.shape[0] // n_parts, y0.shape[-1])
@@ -103,9 +104,9 @@ def test_stratified_partitioning_kernel():
# fmt: off
expected_partitioning = jnp.array(
[
- [[-1.0], [1.0]],
- [[-2.0], [2.0]],
- [[-3.0], [3.0]],
+ [[-1.0], [1.0]],
+ [[-2.0], [2.0]],
+ [[-3.0], [3.0]],
[[-4.0], [4.0]]
]
)
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index e5c20dc..5015628 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -40,5 +40,4 @@ def test_pairwise_metric():
expected_dist = jnp.array(
[[0.0, 5.0, 13.0], [5.0, 0.0, jnp.sqrt(68)], [13.0, jnp.sqrt(68), 0.0]]
)
- print(expected_dist)
assert eqx.tree_equal(dist, expected_dist)
diff --git a/tests/test_path.py b/tests/test_path.py
index 86213f7..9a0ee54 100644
--- a/tests/test_path.py
+++ b/tests/test_path.py
@@ -13,11 +13,12 @@ def test_local_linear_cubature_path(formula, dimension):
f = formula(mccube.GaussianRegion(dimension))
path = mccube.LocalLinearCubaturePath(f)
p_t0 = path.evaluate(0.0)
- p_t1 = path.evaluate(1.0)
- p_dt = path.evaluate(0.0, 1.0)
- p_dt2 = path.evaluate(0.0, 0.5)
- weights = path.weights
+ p_t1 = jnp.asarray(path.evaluate(1.0))
+ p_dt = jnp.asarray(path.evaluate(0.0, 1.0))
+ p_dt2 = jnp.asarray(path.evaluate(0.0, 0.5))
+ weights = jnp.asarray(path.weights)
+
assert eqx.tree_equal(jnp.abs(p_t0), jnp.zeros(f.stacked_points.shape))
- assert eqx.tree_equal(p_t1, p_dt, jnp.astype(f.stacked_points, p_t1.dtype)) # type: ignore
- assert eqx.tree_equal(p_dt2, jnp.astype(jnp.sqrt(0.5) * p_dt, p_dt2.dtype)) # type: ignore
- assert eqx.tree_equal(weights, jnp.astype(f.stacked_weights, weights.dtype)) # type: ignore
+ assert eqx.tree_equal(p_t1, p_dt, jnp.astype(f.stacked_points, p_t1.dtype))
+ assert eqx.tree_equal(p_dt2, jnp.astype(jnp.sqrt(0.5) * p_dt, p_dt2.dtype))
+ assert eqx.tree_equal(weights, jnp.astype(f.stacked_weights, weights.dtype))
diff --git a/tests/test_solvers.py b/tests/test_solvers.py
index dfcbc98..a644fa2 100644
--- a/tests/test_solvers.py
+++ b/tests/test_solvers.py
@@ -83,7 +83,8 @@ def test_MCCSolver_ula(formula):
sol = diffeqsolve(
terms, solver, t0, t1, dt0, y0, saveat=SaveAt(dense=True, t1=True)
)
- assert sol.ys.shape == (1, k, d) # type: ignore
+ assert sol.ys is not None
+ assert sol.ys.shape == (1, k, d)
assert sol.evaluate(t1).shape == (k, d)
n_substeps = 2
@@ -96,7 +97,8 @@ def test_MCCSolver_ula(formula):
sol2 = diffeqsolve(
terms, solver2, t0, t1, dt0, y0, saveat=SaveAt(ts=ts, dense=True)
)
- assert sol2.ys.shape == (ts.shape[0], k, d) # type: ignore
+ assert sol2.ys is not None
+ assert sol2.ys.shape == (ts.shape[0], k, d)
assert sol2.evaluate(t0 + dt0).shape == (k, d)
# Test Weighted Particles
@@ -116,6 +118,8 @@ def test_MCCSolver_ula(formula):
y0_weighted,
saveat=SaveAt(t1=True),
)
- assert sol_weighted.ys.shape == (1, k, d + 1) # type: ignore
- particles, weights = mccube.unpack_particles(sol_weighted.ys[0], weighted=True) # type: ignore
- assert eqx.tree_equal(weights.sum(), jnp.array(1.0), rtol=1e-5, atol=1e-8) # type: ignore
+ assert sol_weighted.ys is not None
+ assert sol_weighted.ys.shape == (1, k, d + 1)
+ particles, weights = mccube.unpack_particles(sol_weighted.ys[0], weighted=True)
+ assert weights is not None
+ assert eqx.tree_equal(weights.sum(), jnp.array(1.0), rtol=1e-5, atol=1e-8)
diff --git a/tests/test_term.py b/tests/test_term.py
index cfbc67a..3f23ab2 100644
--- a/tests/test_term.py
+++ b/tests/test_term.py
@@ -7,9 +7,11 @@
def test_mcc_term():
def ode_vector_field(t, y, args):
+ del t, args
return {"y": -y["y"]}
def cde_vector_field(t, y, args):
+ del t, y, args
return {"y": 1.0}
class Control(diffrax.AbstractPath):
@@ -17,9 +19,11 @@ class Control(diffrax.AbstractPath):
t1 = 1
def evaluate(self, t0, t1=None, left=True):
+ del t0, t1, left
return {"y": jnp.ones((8, 2))}
def derivative(self, t, left=True):
+ del t, left
return {"y": jnp.zeros((8, 2))}
control = Control()
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 16f1039..a98d0bd 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,7 +1,9 @@
import equinox as eqx
import jax.numpy as jnp
+import pytest
import mccube
+from mccube._utils import requires_weighing
def test_pack_pacticles():
@@ -23,3 +25,9 @@ def test_unpack_particles():
weights = x[:, -1]
assert eqx.tree_equal((x, None), mccube.unpack_particles(x, False))
assert eqx.tree_equal((x[:, :-1], weights), mccube.unpack_particles(x, True))
+
+
+def test_requires_weighing():
+ requires_weighing(is_weighted=True)
+ with pytest.raises(ValueError, match="Kernel requires `weighted=True`"):
+ requires_weighing(is_weighted=False)