-
Notifications
You must be signed in to change notification settings - Fork 225
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for linear-time mmd estimator. #475
base: master
Are you sure you want to change the base?
Changes from 14 commits
f4d2692
8ef820d
ba29712
eaa2e45
ea97f52
f6b93d6
59110ec
7bacbec
7507276
29cd155
eef8def
678eae0
7952ab3
016f23f
05626ec
20e442c
3b96ad4
95f3de4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,10 +4,10 @@ | |
from alibi_detect.utils.frameworks import has_pytorch, has_tensorflow | ||
|
||
if has_pytorch: | ||
from alibi_detect.cd.pytorch.mmd import MMDDriftTorch | ||
from alibi_detect.cd.pytorch.mmd import MMDDriftTorch, LinearTimeMMDDriftTorch | ||
|
||
if has_tensorflow: | ||
from alibi_detect.cd.tensorflow.mmd import MMDDriftTF | ||
from alibi_detect.cd.tensorflow.mmd import MMDDriftTF, LinearTimeMMDDriftTF | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -18,6 +18,7 @@ def __init__( | |
x_ref: Union[np.ndarray, list], | ||
backend: str = 'tensorflow', | ||
p_val: float = .05, | ||
estimator: str = 'quad', | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added extra description in the docstring. |
||
preprocess_x_ref: bool = True, | ||
update_x_ref: Optional[Dict[str, int]] = None, | ||
preprocess_fn: Optional[Callable] = None, | ||
|
@@ -40,6 +41,11 @@ def __init__( | |
Backend used for the MMD implementation. | ||
p_val | ||
p-value used for the significance of the permutation test. | ||
estimator | ||
Estimator used for the MMD^2 computation {'quad', 'linear'}. 'Quad' is the default and | ||
uses the quadratic u-statistics on each square kernel matrix. 'Linear' uses the linear | ||
time estimator as in Gretton et al. (JMLR 2012, sec 6), and the threshold is computed | ||
using the Gaussian asymptotic distribution under the null. | ||
preprocess_x_ref | ||
Whether to already preprocess and store the reference data. | ||
update_x_ref | ||
|
@@ -56,7 +62,8 @@ def __init__( | |
configure_kernel_from_x_ref | ||
Whether to already configure the kernel bandwidth from the reference data. | ||
n_permutations | ||
Number of permutations used in the permutation test. | ||
Number of permutations used in the permutation test, only used for the quadratic estimator | ||
(estimator='quad'). | ||
device | ||
Device type used. The default None tries to use the GPU and falls back on CPU if needed. | ||
Can be specified by passing either 'cuda', 'gpu' or 'cpu'. Only relevant for 'pytorch' backend. | ||
|
@@ -76,7 +83,7 @@ def __init__( | |
|
||
kwargs = locals() | ||
args = [kwargs['x_ref']] | ||
pop_kwargs = ['self', 'x_ref', 'backend', '__class__'] | ||
pop_kwargs = ['self', 'x_ref', 'backend', '__class__', 'estimator'] | ||
[kwargs.pop(k, None) for k in pop_kwargs] | ||
|
||
if kernel is None: | ||
|
@@ -88,9 +95,21 @@ def __init__( | |
|
||
if backend == 'tensorflow' and has_tensorflow: | ||
kwargs.pop('device', None) | ||
self._detector = MMDDriftTF(*args, **kwargs) # type: ignore | ||
if estimator == 'quad': | ||
self._detector = MMDDriftTF(*args, **kwargs) # type: ignore | ||
elif estimator == 'linear': | ||
kwargs.pop('n_permutations', None) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Best to clarify in the docstrings that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed. |
||
self._detector = LinearTimeMMDDriftTF(*args, **kwargs) # type: ignore | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since the logic to set There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed, will modify the tests. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Simply rewrite the test to go through different |
||
else: | ||
raise NotImplementedError(f'{estimator} not implemented. Use quad or linear instead.') | ||
else: | ||
self._detector = MMDDriftTorch(*args, **kwargs) # type: ignore | ||
if estimator == 'quad': | ||
self._detector = MMDDriftTorch(*args, **kwargs) # type: ignore | ||
elif estimator == 'linear': | ||
kwargs.pop('n_permutations', None) | ||
self._detector = LinearTimeMMDDriftTorch(*args, **kwargs) # type: ignore | ||
else: | ||
raise NotImplementedError(f'{estimator} not implemented. Use quad or linear instead.') | ||
self.meta = self._detector.meta | ||
|
||
def predict(self, x: Union[np.ndarray, list], return_p_val: bool = True, return_distance: bool = True) \ | ||
|
@@ -128,6 +147,7 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray]: | |
Returns | ||
------- | ||
p-value obtained from the permutation test, the MMD^2 between the reference and test set | ||
and the MMD^2 values from the permutation test. | ||
and the MMD^2 values from the quadratic permutation test, or the threshold for the given | ||
significance level for the linear time test. | ||
""" | ||
return self._detector.score(x) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,10 @@ | ||
import logging | ||
import numpy as np | ||
import scipy.stats as stats | ||
import torch | ||
from typing import Callable, Dict, Optional, Tuple, Union | ||
from alibi_detect.cd.base import BaseMMDDrift | ||
from alibi_detect.utils.pytorch.distance import mmd2_from_kernel_matrix | ||
from alibi_detect.utils.pytorch.distance import mmd2_from_kernel_matrix, linear_mmd2 | ||
from alibi_detect.utils.pytorch.kernels import GaussianRBF | ||
|
||
logger = logging.getLogger(__name__) | ||
|
@@ -118,17 +119,150 @@ def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray]: | |
and the MMD^2 values from the permutation test. | ||
""" | ||
x_ref, x = self.preprocess(x) | ||
n = x.shape[0] | ||
x_ref = torch.from_numpy(x_ref).to(self.device) # type: ignore[assignment] | ||
x = torch.from_numpy(x).to(self.device) # type: ignore[assignment] | ||
# compute kernel matrix, MMD^2 and apply permutation test using the kernel matrix | ||
n = x.shape[0] | ||
kernel_mat = self.kernel_matrix(x_ref, x) # type: ignore[arg-type] | ||
kernel_mat = kernel_mat - torch.diag(kernel_mat.diag()) # zero diagonal | ||
mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False) | ||
mmd2 = mmd2_from_kernel_matrix(kernel_mat, n, permute=False, zero_diag=False) # type: ignore[assignment] | ||
mmd2_permuted = torch.Tensor( | ||
[mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) for _ in range(self.n_permutations)] | ||
) | ||
[mmd2_from_kernel_matrix(kernel_mat, n, permute=True, zero_diag=False) | ||
for _ in range(self.n_permutations)] | ||
) | ||
if self.device.type == 'cuda': | ||
mmd2, mmd2_permuted = mmd2.cpu(), mmd2_permuted.cpu() | ||
p_val = (mmd2 <= mmd2_permuted).float().mean() | ||
return p_val.numpy().item(), mmd2.numpy().item(), mmd2_permuted.numpy() | ||
|
||
|
||
class LinearTimeMMDDriftTorch(BaseMMDDrift): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since these new subclasses don't make use of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point. The default number of permutations then can be initialised in |
||
def __init__( | ||
self, | ||
x_ref: Union[np.ndarray, list], | ||
p_val: float = .05, | ||
preprocess_x_ref: bool = True, | ||
update_x_ref: Optional[Dict[str, int]] = None, | ||
preprocess_fn: Optional[Callable] = None, | ||
kernel: Callable = GaussianRBF, | ||
sigma: Optional[np.ndarray] = None, | ||
configure_kernel_from_x_ref: bool = True, | ||
device: Optional[str] = None, | ||
input_shape: Optional[tuple] = None, | ||
data_type: Optional[str] = None | ||
) -> None: | ||
""" | ||
Maximum Mean Discrepancy (MMD) data drift detector using a linear-time estimator. | ||
|
||
Parameters | ||
---------- | ||
x_ref | ||
Data used as reference distribution. | ||
p_val | ||
p-value used for the significance of the test. | ||
preprocess_x_ref | ||
Whether to already preprocess and store the reference data. | ||
update_x_ref | ||
Reference data can optionally be updated to the last n instances seen by the detector | ||
or via reservoir sampling with size n. For the former, the parameter equals {'last': n} while | ||
for reservoir sampling {'reservoir_sampling': n} is passed. | ||
preprocess_fn | ||
Function to preprocess the data before computing the data drift metrics. | ||
kernel | ||
Kernel used for the MMD computation, defaults to Gaussian RBF kernel. | ||
sigma | ||
Optionally set the GaussianRBF kernel bandwidth. Can also pass multiple bandwidth values as an array. | ||
The kernel evaluation is then averaged over those bandwidths. | ||
configure_kernel_from_x_ref | ||
Whether to already configure the kernel bandwidth from the reference data. | ||
device | ||
Device type used. The default None tries to use the GPU and falls back on CPU if needed. | ||
Can be specified by passing either 'cuda', 'gpu' or 'cpu'. | ||
input_shape | ||
Shape of input data. | ||
data_type | ||
Optionally specify the data type (tabular, image or time-series). Added to metadata. | ||
""" | ||
super().__init__( | ||
x_ref=x_ref, | ||
p_val=p_val, | ||
preprocess_x_ref=preprocess_x_ref, | ||
update_x_ref=update_x_ref, | ||
preprocess_fn=preprocess_fn, | ||
sigma=sigma, | ||
configure_kernel_from_x_ref=configure_kernel_from_x_ref, | ||
input_shape=input_shape, | ||
data_type=data_type | ||
) | ||
self.meta.update({'backend': 'pytorch'}) | ||
|
||
# set backend | ||
if device is None or device.lower() in ['gpu', 'cuda']: | ||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
if self.device.type == 'cpu': | ||
print('No GPU detected, fall back on CPU.') | ||
else: | ||
self.device = torch.device('cpu') | ||
|
||
# initialize kernel | ||
sigma = torch.from_numpy(sigma).to(self.device) if isinstance(sigma, # type: ignore[assignment] | ||
np.ndarray) else None | ||
self.kernel = kernel(sigma) if kernel == GaussianRBF else kernel | ||
|
||
# compute kernel matrix for the reference data | ||
if self.infer_sigma or isinstance(sigma, torch.Tensor): | ||
n = self.x_ref.shape[0] | ||
n_hat = int(np.floor(n / 2) * 2) | ||
x = torch.from_numpy(self.x_ref[:n_hat, :]).to(self.device) | ||
self.k_xx = self.kernel(x=x[0::2, :], y=x[1::2, :], | ||
pairwise=False, infer_sigma=self.infer_sigma) | ||
self.infer_sigma = False | ||
else: | ||
self.k_xx, self.infer_sigma = None, True | ||
|
||
def kernel_matrix(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Method is not used I believe? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @arnaudvl The base class requires this method for initialisation, I wonder what would be the preferable solution here? the minimal thing could be to simply leave a pseudo method. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMO we should remove |
||
""" Compute and return full kernel matrix between arrays x and y. """ | ||
k_xy = self.kernel(x, y, self.infer_sigma) | ||
k_xx = self.k_xx if self.k_xx is not None and self.update_x_ref is None else self.kernel(x, x) | ||
k_yy = self.kernel(y, y) | ||
kernel_mat = torch.cat([torch.cat([k_xx, k_xy], 1), torch.cat([k_xy.T, k_yy], 1)], 0) | ||
return kernel_mat | ||
|
||
def score(self, x: Union[np.ndarray, list]) -> Tuple[float, float, np.ndarray]: | ||
""" | ||
Compute the p-value using the maximum mean discrepancy as a distance measure between the | ||
reference data and the data to be tested. x and x_ref are required to have the same size. | ||
The sample size is then specified as the largest even number not exceeding the data size. | ||
|
||
Parameters | ||
---------- | ||
x | ||
Batch of instances. | ||
|
||
Returns | ||
------- | ||
p-value obtained from the null hypothesis, the MMD^2 between the reference and test set | ||
and the MMD^2 threshold for the given significance level. | ||
""" | ||
x_ref, x = self.preprocess(x) | ||
n = x.shape[0] | ||
m = x_ref.shape[0] | ||
if n != m: | ||
raise ValueError('x and x_ref must have the same size.') | ||
n_hat = int(np.floor(n / 2) * 2) | ||
x_ref = torch.from_numpy(x_ref[:n_hat, :]).to(self.device) # type: ignore[assignment] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe there is a case to be made that there is an explicit check such that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems like quite an issue atm. Agree the safest option would be to explicitly check for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently implemented as raise error for |
||
x = torch.from_numpy(x[:n_hat, :]).to(self.device) # type: ignore[assignment] | ||
if self.k_xx is not None and self.update_x_ref is None: | ||
k_xx = self.k_xx | ||
else: | ||
k_xx = self.kernel(x=x_ref[0::2, :], y=x_ref[1::2, :], pairwise=False) | ||
mmd2, var_mmd2 = linear_mmd2(k_xx, x_ref, x, self.kernel) # type: ignore[arg-type] | ||
if self.device.type == 'cuda': | ||
mmd2 = mmd2.cpu() | ||
mmd2 = mmd2.numpy().item() | ||
var_mmd2 = np.clip(var_mmd2.numpy().item(), 1e-8, 1e8) | ||
std_mmd2 = np.sqrt(var_mmd2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can directly use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new version uses |
||
t = mmd2 / (std_mmd2 / np.sqrt(n_hat / 2.)) | ||
p_val = 1 - stats.t.cdf(t, df=(n_hat / 2.) - 1) | ||
distance_threshold = stats.t.ppf(1 - self.p_val, df=(n_hat / 2.) - 1) | ||
return p_val, t, distance_threshold |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from functools import partial | ||
from itertools import product | ||
import numpy as np | ||
import pytest | ||
import torch | ||
import torch.nn as nn | ||
from typing import Callable, List | ||
from alibi_detect.cd.pytorch.mmd import LinearTimeMMDDriftTorch | ||
from alibi_detect.cd.pytorch.preprocess import HiddenOutput, preprocess_drift | ||
|
||
n, n_hidden, n_classes = 500, 10, 5 | ||
|
||
|
||
class MyModel(nn.Module): | ||
def __init__(self, n_features: int): | ||
super().__init__() | ||
self.dense1 = nn.Linear(n_features, 20) | ||
self.dense2 = nn.Linear(20, 2) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x = nn.ReLU()(self.dense1(x)) | ||
return self.dense2(x) | ||
|
||
|
||
# test List[Any] inputs to the detector | ||
def preprocess_list(x: List[np.ndarray]) -> np.ndarray: | ||
return np.concatenate(x, axis=0) | ||
|
||
|
||
n_features = [10] | ||
n_enc = [None, 3] | ||
preprocess = [ | ||
(None, None), | ||
(preprocess_drift, {'model': HiddenOutput, 'layer': -1}), | ||
(preprocess_list, None) | ||
] | ||
update_x_ref = [{'last': 500}, {'reservoir_sampling': 500}, None] | ||
preprocess_x_ref = [True, False] | ||
tests_mmddrift = list(product(n_features, n_enc, preprocess, | ||
update_x_ref, preprocess_x_ref)) | ||
n_tests = len(tests_mmddrift) | ||
|
||
|
||
@pytest.fixture | ||
def mmd_params(request): | ||
return tests_mmddrift[request.param] | ||
|
||
|
||
@pytest.mark.parametrize('mmd_params', list(range(n_tests)), indirect=True) | ||
def test_mmd(mmd_params): | ||
n_features, n_enc, preprocess, update_x_ref, preprocess_x_ref = mmd_params | ||
|
||
np.random.seed(0) | ||
torch.manual_seed(0) | ||
|
||
x_ref = np.random.randn(n * n_features).reshape(n, n_features).astype(np.float32) | ||
preprocess_fn, preprocess_kwargs = preprocess | ||
to_list = False | ||
if hasattr(preprocess_fn, '__name__') and preprocess_fn.__name__ == 'preprocess_list': | ||
if not preprocess_x_ref: | ||
return | ||
to_list = True | ||
x_ref = [_[None, :] for _ in x_ref] | ||
elif isinstance(preprocess_fn, Callable) and 'layer' in list(preprocess_kwargs.keys()) \ | ||
and preprocess_kwargs['model'].__name__ == 'HiddenOutput': | ||
model = MyModel(n_features) | ||
layer = preprocess_kwargs['layer'] | ||
preprocess_fn = partial(preprocess_fn, model=HiddenOutput(model=model, layer=layer)) | ||
else: | ||
preprocess_fn = None | ||
|
||
cd = LinearTimeMMDDriftTorch( | ||
x_ref=x_ref, | ||
p_val=.05, | ||
preprocess_x_ref=preprocess_x_ref if isinstance(preprocess_fn, Callable) else False, | ||
update_x_ref=update_x_ref, | ||
preprocess_fn=preprocess_fn | ||
) | ||
x = x_ref.copy() | ||
preds = cd.predict(x, return_p_val=True) | ||
assert preds['data']['is_drift'] == 0 and preds['data']['p_val'] >= cd.p_val | ||
if isinstance(update_x_ref, dict): | ||
k = list(update_x_ref.keys())[0] | ||
assert cd.n == len(x) + len(x_ref) | ||
assert cd.x_ref.shape[0] == min(update_x_ref[k], len(x) + len(x_ref)) | ||
|
||
x_h1 = np.random.randn(n * n_features).reshape(n, n_features).astype(np.float32) | ||
if to_list: | ||
x_h1 = [_[None, :] for _ in x_h1] | ||
preds = cd.predict(x_h1, return_p_val=True) | ||
if preds['data']['is_drift'] == 1: | ||
assert preds['data']['p_val'] < preds['data']['threshold'] == cd.p_val | ||
assert preds['data']['distance'] > preds['data']['distance_threshold'] | ||
else: | ||
assert preds['data']['p_val'] >= preds['data']['threshold'] == cd.p_val | ||
assert preds['data']['distance'] <= preds['data']['distance_threshold'] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It doesn't seem ideal to have to deal with the differing behaviour of `score` for the original vs new detectors in `predict`.
Maybe we could move the `distance_threshold` computation to `score` for the original MMD detectors, and then the above would be simplified quite a bit? Draft PR for this here: #489
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes indeed, guess the best thing to do here is to follow your draft PR's template to modify the linear-time detector.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd say if @arnaudvl and @ojcobb agree with the change in #489 , we should merge that, then it would be a super quick change to this PR.