Skip to content

Commit

Permalink
[ENH] mixture of distributions (#26)
Browse files Browse the repository at this point in the history
Implements mixture of distributions.

Towards #22, and required for
ensemble regressor.

Also adds a default implementation for `ppf` in the `BaseDistribution`,
using the bisection method to invert a `cdf`, if present.
  • Loading branch information
fkiraly authored Aug 25, 2023
1 parent 1146d6e commit b531d7e
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 4 deletions.
8 changes: 7 additions & 1 deletion skpro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@
# copyright: skpro developers, BSD-3-Clause License (see LICENSE file)
# adapted from sktime

__all__ = ["Empirical", "Laplace", "Normal"]
__all__ = [
"Empirical",
"Laplace",
"Mixture",
"Normal",
]

from skpro.distributions.empirical import Empirical
from skpro.distributions.laplace import Laplace
from skpro.distributions.mixture import Mixture
from skpro.distributions.normal import Normal
42 changes: 39 additions & 3 deletions skpro/distributions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class BaseDistribution(BaseObject):
"approx_var_spl": 1000, # sample size used in MC estimates of var
"approx_energy_spl": 1000, # sample size used in MC estimates of energy
"approx_spl": 1000, # sample size used in other MC estimates
"bisect_iter": 1000, # max iters for bisection method in ppf
}

def __init__(self, index=None, columns=None):
Expand Down Expand Up @@ -92,9 +93,12 @@ def _subset_params(self, rowidx, colidx):

subset_param_dict = {}
for param, val in params.items():
arr = np.array(val)
if len(arr.shape) == 0:
subset_param_dict
if val is not None:
arr = np.array(val)
else:
arr = None
# if len(arr.shape) == 0:
# do nothing with arr
if len(arr.shape) >= 1 and rowidx is not None:
arr = arr[rowidx]
if len(arr.shape) >= 2 and colidx is not None:
Expand Down Expand Up @@ -252,6 +256,38 @@ def cdf(self, x):

def ppf(self, p):
"""Quantile function = percent point function = inverse cdf."""
if self._has_implementation_of("cdf"):
from scipy.optimize import bisect

max_iter = self.get_tag("bisect_iter")
approx_method = (
"by using the bisection method (scipy.optimize.bisect) on "
f"the cdf, at {max_iter} maximum iterations"
)
warn(self._method_error_msg("cdf", fill_in=approx_method))

result = pd.DataFrame(index=p.index, columns=p.columns, dtype="float")
for ix in p.index:
for col in p.columns:
d_ix = self.loc[[ix], [col]]
p_ix = p.loc[ix, col]

def opt_fun(x):
"""Optimization function, to find x s.t. cdf(x) = p_ix."""
x = pd.DataFrame(x, index=[ix], columns=[col]) # noqa: B023
return d_ix.cdf(x).values[0][0] - p_ix # noqa: B023

left_bd = -1e6
right_bd = 1e6
while opt_fun(left_bd) > 0:
left_bd *= 10
while opt_fun(right_bd) < 0:
right_bd *= 10
result.loc[ix, col] = bisect(
opt_fun, left_bd, right_bd, maxiter=max_iter
)
return result

raise NotImplementedError(self._method_error_msg("ppf", "error"))

def energy(self, x=None):
Expand Down
197 changes: 197 additions & 0 deletions skpro/distributions/mixture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
# -*- coding: utf-8 -*-
# copyright: skpro developers, BSD-3-Clause License (see LICENSE file)
"""Mixture distribution."""

__author__ = ["fkiraly"]

import numpy as np
import pandas as pd
from skbase.base import BaseMetaObject

from skpro.distributions.base import BaseDistribution


class Mixture(BaseMetaObject, BaseDistribution):
"""Mixture of distributions.
Parameters
----------
distributions : list of tuples (str, BaseDistribution) or BaseDistribution
list of mixture components
weights : list of float, optional, default = None
list of mixture weights, will be normalized to sum to 1
if not provided, uniform mixture is assumed
index : pd.Index, optional, default = inferred from component distributions
columns : pd.Index, optional, default = inferred from component distributions
Example
-------
>>> from skpro.distributions.mixture import Mixture
>>> from skpro.distributions.normal import Normal
>>> n1 = Normal(mu=[[0, 1], [2, 3], [4, 5]], sigma=1)
>>> n2 = Normal(mu=3, sigma=2, index=n1.index, columns=n1.columns)
>>> m = Mixture(distributions=[("n1", n1), ("n2", n2)], weights=[0.3, 0.7])
>>> mixture_sample = m.sample(n_samples=10)
"""

_tags = {
"capabilities:approx": ["pdfnorm", "energy", "ppf"],
"capabilities:exact": ["mean", "var", "pdf", "log_pdf", "cdf"],
"distr:measuretype": "mixed",
"named_object_parameters": "_distributions",
}

def __init__(self, distributions, weights=None, index=None, columns=None):

self.distributions = distributions
self.weights = weights
self.index = index
self.columns = columns

self._distributions = self._coerce_to_named_object_tuples(distributions)
n_dists = len(self._distributions)

if weights is None:
self._weights = np.ones(n_dists) / n_dists
else:
self._weights = np.array(weights) / np.sum(weights)

if index is None:
index = self._distributions[0][1].index

if columns is None:
columns = self._distributions[0][1].columns

super().__init__(index=index, columns=columns)

def _iloc(self, rowidx=None, colidx=None):

dists = self._distributions
weights = self.weights

dists_subset = [(x[0], x[1].iloc[rowidx, colidx]) for x in dists]

index_subset = dists_subset[0][1].index
columns_subset = dists_subset[0][1].columns

return Mixture(
distributions=dists_subset,
weights=weights,
index=index_subset,
columns=columns_subset,
)

def mean(self):
r"""Return expected value of the distribution.
Let :math:`X` be a random variable with the distribution of `self`.
Returns the expectation :math:`\mathbb{E}[X]`
Returns
-------
pd.DataFrame with same rows, columns as `self`
expected value of distribution (entry-wise)
"""
return self._average("mean")

def var(self):
r"""Return element/entry-wise variance of the distribution.
Let :math:`X` be a random variable with the distribution of `self`.
Returns :math:`\mathbb{V}[X] = \mathbb{E}\left(X - \mathbb{E}[X]\right)^2`
Returns
-------
pd.DataFrame with same rows, columns as `self`
variance of distribution (entry-wise)
"""
weights = self._weights
var_mean = self._average("var")
mixture_mean = self._average("mean")

means = [d.mean() for _, d in self._distributions]
mean_var = [(m - mixture_mean) ** 2 for m in means]
var_mean_var = self._average_df(mean_var, weights=weights)

return var_mean + var_mean_var

def _average(self, method, x=None, weights=None):
"""Average a method over the mixture components."""
if x is None:
args = ()
else:
args = (x,)

vals = [getattr(d, method)(*args) for _, d in self._distributions]

return self._average_df(vals, weights=weights)

def _average_df(self, df_list, weights=None):
"""Average a list of `pd.DataFrame` objects, with weights."""
if weights is None and hasattr(self, "_weights"):
weights = self._weights
elif weights is None:
weights = np.ones(len(df_list)) / len(df_list)

n_df = len(df_list)
df_weighted = [df * w for df, w in zip(df_list, weights)]
df_concat = pd.concat(df_weighted, axis=1, keys=range(n_df))
df_res = df_concat.groupby(level=-1, axis=1).sum()
return df_res

def pdf(self, x):
"""Probability density function."""
return self._average("pdf", x)

def cdf(self, x):
"""Cumulative distribution function."""
return self._average("cdf", x)

def sample(self, n_samples=None):
"""Sample from the distribution.
Parameters
----------
n_samples : int, optional, default = None
Returns
-------
if `n_samples` is `None`:
returns a sample that contains a single sample from `self`,
in `pd.DataFrame` mtype format convention, with `index` and `columns` as `self`
if n_samples is `int`:
returns a `pd.DataFrame` that contains `n_samples` i.i.d. samples from `self`,
in `pd-multiindex` mtype format convention, with same `columns` as `self`,
and `MultiIndex` that is product of `RangeIndex(n_samples)` and `self.index`
"""
if n_samples is None:
N = 1
else:
N = n_samples

n_dist = len(self._distributions)
selector = np.random.choice(n_dist, size=N, p=self._weights)

samples = [self._distributions[i][1].sample() for i in selector]

if n_samples is None:
return samples[0]
else:
return pd.concat(samples, axis=0, keys=range(N))

@classmethod
def get_test_params(cls, parameter_set="default"):
"""Return testing parameter settings for the estimator."""
from skpro.distributions.normal import Normal

index = pd.RangeIndex(3)
columns = pd.Index(["a", "b"])
normal1 = Normal(mu=0, sigma=1, index=index, columns=columns)
normal2 = Normal(mu=[[0, 1], [2, 3], [4, 5]], sigma=1, columns=columns)

dists = [("normal1", normal1), ("normal2", normal2)]

params1 = {"distributions": dists}
params2 = {"distributions": dists, "weights": [0.3, 0.7]}
return [params1, params2]
3 changes: 3 additions & 0 deletions skpro/tests/test_all_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,11 @@ class PackageConfig:
"approx_var_spl", # int, sample size used in MC estimates of var
"approx_energy_spl", # int, sample size used in MC estimates of energy
"approx_spl", # int, sample size used in other MC estimates
"bisect_iter", # max iters for bisection method in ppf
"scitype:y_pred", # str, expected input type for y_pred in performance metric
"lower_is_better", # bool, whether lower (True) or higher (False) is better
# BaseMetaObject reserved tags
"named_object_parameters", # name of component list attribute for meta-objects
]


Expand Down

0 comments on commit b531d7e

Please sign in to comment.