Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for linear-time mmd estimator. #475

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions alibi_detect/cd/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,12 +570,16 @@ def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True, return_
'data' contains the drift prediction and optionally the p-value, threshold and MMD metric.
"""
# compute drift scores
p_val, dist, dist_permutations = self.score(x)
drift_pred = int(p_val < self.p_val)
p_val, dist, tmp_v = self.score(x)
if len(np.shape(tmp_v)) > 0:
dist_permutations = tmp_v
# compute distance threshold
idx_threshold = int(self.p_val * len(dist_permutations))
distance_threshold = np.sort(dist_permutations)[::-1][idx_threshold]
else:
distance_threshold = tmp_v
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't seem ideal to have to deal with the differing behaviour of score for the original vs new detectors in predict.

Maybe we could move the distance_threshold computation to score for the original MMD detectors, and then the above would be simplified quite a bit? Draft PR for this here: #489

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes indeed, guess the best thing to do here is to follow your draft PR's template to modify the linear-time detector.

Copy link
Contributor

@ascillitoe ascillitoe Apr 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd say if @arnaudvl and @ojcobb agree with the change in #489 , we should merge that, then it would be a super quick change to this PR.


# compute distance threshold
idx_threshold = int(self.p_val * len(dist_permutations))
distance_threshold = np.sort(dist_permutations)[::-1][idx_threshold]
drift_pred = int(p_val < self.p_val)

# update reference dataset
if isinstance(self.update_x_ref, dict) and self.preprocess_fn is not None and self.preprocess_x_ref:
Expand Down
34 changes: 27 additions & 7 deletions alibi_detect/cd/mmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow

if has_pytorch:
from alibi_detect.cd.pytorch.mmd import MMDDriftTorch
from alibi_detect.cd.pytorch.mmd import MMDDriftTorch, LinearTimeMMDDriftTorch

if has_tensorflow:
from alibi_detect.cd.tensorflow.mmd import MMDDriftTF
from alibi_detect.cd.tensorflow.mmd import MMDDriftTF, LinearTimeMMDDriftTF

logger = logging.getLogger(__name__)

Expand All @@ -18,6 +18,7 @@ def __init__(
x_ref: Union[np.ndarray, list],
backend: str = 'tensorflow',
p_val: float = .05,
estimator: str = 'quad',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would estimator_complexity be more descriptive? (Or at least make clear in the docstring)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added extra description in the docstring.

preprocess_x_ref: bool = True,
update_x_ref: Optional[Dict[str, int]] = None,
preprocess_fn: Optional[Callable] = None,
Expand All @@ -40,6 +41,11 @@ def __init__(
Backend used for the MMD implementation.
p_val
p-value used for the significance of the permutation test.
estimator
Estimator used for the MMD^2 computation {'quad', 'linear'}. 'Quad' is the default and
uses the quadratic u-statistics on each square kernel matrix. 'Linear' uses the linear
time estimator as in Gretton et al. (JMLR 2014, sec 6), and the threshold is computed
using the Gaussian asympotic distribution under null.
preprocess_x_ref
Whether to already preprocess and store the reference data.
update_x_ref
Expand All @@ -56,7 +62,8 @@ def __init__(
configure_kernel_from_x_ref
Whether to already configure the kernel bandwidth from the reference data.
n_permutations
Number of permutations used in the permutation test.
Number of permutations used in the permutation test, only used for the quadratic estimator
(estimator='quad').
device
Device type used. The default None tries to use the GPU and falls back on CPU if needed.
Can be specified by passing either 'cuda', 'gpu' or 'cpu'. Only relevant for 'pytorch' backend.
Expand All @@ -76,7 +83,7 @@ def __init__(

kwargs = locals()
args = [kwargs['x_ref']]
pop_kwargs = ['self', 'x_ref', 'backend', '__class__']
pop_kwargs = ['self', 'x_ref', 'backend', '__class__', 'estimator']
[kwargs.pop(k, None) for k in pop_kwargs]

if kernel is None:
Expand All @@ -88,9 +95,21 @@ def __init__(

if backend == 'tensorflow' and has_tensorflow:
kwargs.pop('device', None)
self._detector = MMDDriftTF(*args, **kwargs) # type: ignore
if estimator == 'quad':
self._detector = MMDDriftTF(*args, **kwargs) # type: ignore
elif estimator == 'linear':
kwargs.pop('n_permutations', None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Best to clarify in the docstrings that n_permutations is not used for the linear estimator.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

self._detector = LinearTimeMMDDriftTF(*args, **kwargs) # type: ignore
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the logic to set self._detector is located here, we should add additional tests to alibi_detect/cd/tests/test_mmd.py to check that the correct subclass is selected conditional on backend and estimator.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, will modify the tests.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simply rewrite the test to go through different backend and estimator options, should do the job.

else:
raise NotImplementedError(f'{estimator} not implemented. Use quad or linear instead.')
else:
self._detector = MMDDriftTorch(*args, **kwargs) # type: ignore
if estimator == 'quad':
self._detector = MMDDriftTorch(*args, **kwargs) # type: ignore
elif estimator == 'linear':
kwargs.pop('n_permutations', None)
self._detector = LinearTimeMMDDriftTorch(*args, **kwargs) # type: ignore
else:
raise NotImplementedError(f'{estimator} not implemented. Use quad or linear instead.')
self.meta = self._detector.meta

def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True, return_distance: bool = True) \
Expand Down Expand Up @@ -128,6 +147,7 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray]:
Returns
-------
p-value obtained from the permutation test, the MMD^2 between the reference and test set
and the MMD^2 values from the permutation test.
and the MMD^2 values from the qudratic permutation test, or the threshold for the given
significance level for the linear time test.
"""
return self._detector.score(x)
144 changes: 139 additions & 5 deletions alibi_detect/cd/pytorch/mmd.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
import numpy as np
import scipy.stats as stats
import torch
from typing import Callable, Dict, Optional, Tuple, Union
from alibi_detect.cd.base import BaseMMDDrift
from alibi_detect.utils.pytorch.distance import mmd2_from_kernel_matrix
from alibi_detect.utils.pytorch.distance import mmd2_from_kernel_matrix, linear_mmd2
from alibi_detect.utils.pytorch.kernels import GaussianRBF

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -118,17 +119,150 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray]:
and the MMD^2 values from the permutation test.
"""
x_ref, x = self.preprocess(x)
n = x.shape[0]
x_ref = torch.from_numpy(x_ref).to(self.device) # type: ignore[assignment]
x = torch.from_numpy(x).to(self.device) # type: ignore[assignment]
# compute kernel matrix, MMD^2 and apply permutation test using the kernel matrix
n = x.shape[0]
kernel_mat = self.kernel_matrix(x_ref, x) # type: ignore[arg-type]
kernel_mat = kernel_mat - torch.diag(kernel_mat.diag()) # zero diagonal
mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False)
mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False) # type: ignore[assignment]
mmd2_permuted = torch.Tensor(
[mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)]
)
[mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False)
for _ in range(self.n_permutations)]
)
if self.device.type == 'cuda':
mmd2, mmd2_permuted = mmd2.cpu(), mmd2_permuted.cpu()
p_val = (mmd2 <= mmd2_permuted).float().mean()
return p_val.numpy().item(), mmd2.numpy().item(), mmd2_permuted.numpy()


class LinearTimeMMDDriftTorch(BaseMMDDrift):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since these new subclasses don't make use of self.n_permutations (set in BaseMMDDrift), shall we set this to None? I had a moment of confusion when updating the tests since self.n_permuations == 100 when estimator == 'linear'.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. The default number of permutations then can be initialised in /cd/mmd.py when estimator is 'quad'.

def __init__(
self,
x_ref: Union[np.ndarray, list],
p_val: float = .05,
preprocess_x_ref: bool = True,
update_x_ref: Optional[Dict[str, int]] = None,
preprocess_fn: Optional[Callable] = None,
kernel: Callable = GaussianRBF,
sigma: Optional[np.ndarray] = None,
configure_kernel_from_x_ref: bool = True,
device: Optional[str] = None,
input_shape: Optional[tuple] = None,
data_type: Optional[str] = None
) -> None:
"""
Maximum Mean Discrepancy (MMD) data drift detector using a linear-time estimator.

Parameters
----------
x_ref
Data used as reference distribution.
p_val
p-value used for the significance of the permutation test.
preprocess_x_ref
Whether to already preprocess and store the reference data.
update_x_ref
Reference data can optionally be updated to the last n instances seen by the detector
or via reservoir sampling with size n. For the former, the parameter equals {'last': n} while
for reservoir sampling {'reservoir_sampling': n} is passed.
preprocess_fn
Function to preprocess the data before computing the data drift metrics.
kernel
Kernel used for the MMD computation, defaults to Gaussian RBF kernel.
sigma
Optionally set the GaussianRBF kernel bandwidth. Can also pass multiple bandwidth values as an array.
The kernel evaluation is then averaged over those bandwidths.
configure_kernel_from_x_ref
Whether to already configure the kernel bandwidth from the reference data.
device
Device type used. The default None tries to use the GPU and falls back on CPU if needed.
Can be specified by passing either 'cuda', 'gpu' or 'cpu'.
input_shape
Shape of input data.
data_type
Optionally specify the data type (tabular, image or time-series). Added to metadata.
"""
super().__init__(
x_ref=x_ref,
p_val=p_val,
preprocess_x_ref=preprocess_x_ref,
update_x_ref=update_x_ref,
preprocess_fn=preprocess_fn,
sigma=sigma,
configure_kernel_from_x_ref=configure_kernel_from_x_ref,
input_shape=input_shape,
data_type=data_type
)
self.meta.update({'backend': 'pytorch'})

# set backend
if device is None or device.lower() in ['gpu', 'cuda']:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if self.device.type == 'cpu':
print('No GPU detected, fall back on CPU.')
else:
self.device = torch.device('cpu')

# initialize kernel
sigma = torch.from_numpy(sigma).to(self.device) if isinstance(sigma, # type: ignore[assignment]
np.ndarray) else None
self.kernel = kernel(sigma) if kernel == GaussianRBF else kernel

# compute kernel matrix for the reference data
if self.infer_sigma or isinstance(sigma, torch.Tensor):
n = self.x_ref.shape[0]
n_hat = int(np.floor(n / 2) * 2)
x = torch.from_numpy(self.x_ref[:n_hat, :]).to(self.device)
self.k_xx = self.kernel(x=x[0::2, :], y=x[1::2, :],
pairwise=False, infer_sigma=self.infer_sigma)
self.infer_sigma = False
else:
self.k_xx, self.infer_sigma = None, True

def kernel_matrix(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Method is not used I believe?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@arnaudvl The base class requires this method for initialisation, I wonder what would be the preferable solution here? the minimal thing could be to simply leave a pseudo method.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO we should remove kernel_matrix from BaseMMDDrift, so that it is no longer an abstractmethod. I don't think it makes sense to have it as an abstract method if not all subclasses use/need it.

""" Compute and return full kernel matrix between arrays x and y. """
k_xy = self.kernel(x, y, self.infer_sigma)
k_xx = self.k_xx if self.k_xx is not None and self.update_x_ref is None else self.kernel(x, x)
k_yy = self.kernel(y, y)
kernel_mat = torch.cat([torch.cat([k_xx, k_xy], 1), torch.cat([k_xy.T, k_yy], 1)], 0)
return kernel_mat

def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray]:
"""
Compute the p-value using the maximum mean discrepancy as a distance measure between the
reference data and the data to be tested. x and x_ref are required to have the same size.
The sample size is then specified as the maximal even number below the data size.

Parameters
----------
x
Batch of instances.

Returns
-------
p-value obtained from the null hypothesis, the MMD^2 between the reference and test set
and the MMD^2 threshold for the given significance level.
"""
x_ref, x = self.preprocess(x)
n = x.shape[0]
m = x_ref.shape[0]
if n != m:
raise ValueError('x and x_ref must have the same size.')
n_hat = int(np.floor(n / 2) * 2)
x_ref = torch.from_numpy(x_ref[:n_hat, :]).to(self.device) # type: ignore[assignment]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe there is a case to be made that there is an explicit check such that n == m and if not, an error is raised. The reason is that silently some unexpected behaviour can occur by only selecting the first n_hat reference/test instances. If say the reference data is ordered and contains samples from classes 1,2 and 3, then only choosing :n_hat could ignore all samples from class 3 and not form an i.i.d. sample anymore. So my preference would be explicit behaviour around this (raising errors) or if we allow this (which I am not in favour of now) then randomly sample n_hat instances from x_ref and x. Good to have some opinions @jklaise @ascillitoe @ojcobb

Copy link
Contributor

@ascillitoe ascillitoe Apr 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like quite an issue atm. Agree the safest option would be to explicitly check for n == m and raise an error. Otherwise, we could check, and randomly subsample if n != m, with a warning raised to inform the user we are doing this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently implemented as raise error for n!=m. Guess the subsampling should be implemented on a stand-alone part, so that it can be used with other detectors?

x = torch.from_numpy(x[:n_hat, :]).to(self.device) # type: ignore[assignment]
if self.k_xx is not None and self.update_x_ref is None:
k_xx = self.k_xx
else:
k_xx = self.kernel(x=x_ref[0::2, :], y=x_ref[1::2, :], pairwise=False)
mmd2, var_mmd2 = linear_mmd2(k_xx, x_ref, x, self.kernel) # type: ignore[arg-type]
if self.device.type == 'cuda':
mmd2 = mmd2.cpu()
mmd2 = mmd2.numpy().item()
var_mmd2 = np.clip(var_mmd2.numpy().item(), 1e-8, 1e8)
std_mmd2 = np.sqrt(var_mmd2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can directly use torch.std(...) in linear_mmd2? This would remove the few additional lines of code here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new version uses np.sqrt(np.clip(var_mmd2, 1e-8, 1e-8)) for numeric stability.

t = mmd2 / (std_mmd2 / np.sqrt(n_hat / 2.))
p_val = 1 - stats.t.cdf(t, df=(n_hat / 2.) - 1)
distance_threshold = stats.t.ppf(1 - self.p_val, df=(n_hat / 2.) - 1)
return p_val, t, distance_threshold
96 changes: 96 additions & 0 deletions alibi_detect/cd/pytorch/tests/test_linear_time_mmd_pt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from functools import partial
from itertools import product
import numpy as np
import pytest
import torch
import torch.nn as nn
from typing import Callable, List
from alibi_detect.cd.pytorch.mmd import LinearTimeMMDDriftTorch
from alibi_detect.cd.pytorch.preprocess import HiddenOutput, preprocess_drift

n, n_hidden, n_classes = 500, 10, 5


class MyModel(nn.Module):
def __init__(self, n_features: int):
super().__init__()
self.dense1 = nn.Linear(n_features, 20)
self.dense2 = nn.Linear(20, 2)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = nn.ReLU()(self.dense1(x))
return self.dense2(x)


# test List[Any] inputs to the detector
def preprocess_list(x: List[np.ndarray]) -> np.ndarray:
return np.concatenate(x, axis=0)


n_features = [10]
n_enc = [None, 3]
preprocess = [
(None, None),
(preprocess_drift, {'model': HiddenOutput, 'layer': -1}),
(preprocess_list, None)
]
update_x_ref = [{'last': 500}, {'reservoir_sampling': 500}, None]
preprocess_x_ref = [True, False]
tests_mmddrift = list(product(n_features, n_enc, preprocess,
update_x_ref, preprocess_x_ref))
n_tests = len(tests_mmddrift)


@pytest.fixture
def mmd_params(request):
return tests_mmddrift[request.param]


@pytest.mark.parametrize('mmd_params', list(range(n_tests)), indirect=True)
def test_mmd(mmd_params):
n_features, n_enc, preprocess, update_x_ref, preprocess_x_ref = mmd_params

np.random.seed(0)
torch.manual_seed(0)

x_ref = np.random.randn(n * n_features).reshape(n, n_features).astype(np.float32)
preprocess_fn, preprocess_kwargs = preprocess
to_list = False
if hasattr(preprocess_fn, '__name__') and preprocess_fn.__name__ == 'preprocess_list':
if not preprocess_x_ref:
return
to_list = True
x_ref = [_[None, :] for _ in x_ref]
elif isinstance(preprocess_fn, Callable) and 'layer' in list(preprocess_kwargs.keys()) \
and preprocess_kwargs['model'].__name__ == 'HiddenOutput':
model = MyModel(n_features)
layer = preprocess_kwargs['layer']
preprocess_fn = partial(preprocess_fn, model=HiddenOutput(model=model, layer=layer))
else:
preprocess_fn = None

cd = LinearTimeMMDDriftTorch(
x_ref=x_ref,
p_val=.05,
preprocess_x_ref=preprocess_x_ref if isinstance(preprocess_fn, Callable) else False,
update_x_ref=update_x_ref,
preprocess_fn=preprocess_fn
)
x = x_ref.copy()
preds = cd.predict(x, return_p_val=True)
assert preds['data']['is_drift'] == 0 and preds['data']['p_val'] >= cd.p_val
if isinstance(update_x_ref, dict):
k = list(update_x_ref.keys())[0]
assert cd.n == len(x) + len(x_ref)
assert cd.x_ref.shape[0] == min(update_x_ref[k], len(x) + len(x_ref))

x_h1 = np.random.randn(n * n_features).reshape(n, n_features).astype(np.float32)
if to_list:
x_h1 = [_[None, :] for _ in x_h1]
preds = cd.predict(x_h1, return_p_val=True)
if preds['data']['is_drift'] == 1:
assert preds['data']['p_val'] < preds['data']['threshold'] == cd.p_val
assert preds['data']['distance'] > preds['data']['distance_threshold']
else:
assert preds['data']['p_val'] >= preds['data']['threshold'] == cd.p_val
assert preds['data']['distance'] <= preds['data']['distance_threshold']
Loading