Commit 99095a3
Improve OOB notebook and some helper and plotting functions
mdbenito committed Sep 15, 2023
1 parent e760a36 commit 99095a3
Showing 3 changed files with 386 additions and 246 deletions.
440 changes: 200 additions & 240 deletions notebooks/data_oob.ipynb

Large diffs are not rendered by default.

106 changes: 105 additions & 1 deletion notebooks/support/common.py
@@ -1,7 +1,10 @@
import logging
import os
import pickle
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
from functools import wraps
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
@@ -551,3 +554,104 @@ def plot_corrupted_influences_distribution(
axes[idx].set_title(f"Influences for {label=}")
axes[idx].legend()
plt.show()


def filecache(path: Path) -> Callable[[Callable], Callable]:
"""Wraps a function to cache its output on disk.
There is no hashing of the arguments of the function. The decorator merely
checks whether `path` exists and, if so, loads the output from it; if not,
it calls the wrapped function and saves its output to `path`.
Args:
path: File in which to cache the output.
Returns:
A decorator that adds on-disk caching to the function it wraps.
"""

def decorator(fun: Callable) -> Callable:
@wraps(fun)
def wrapper(*args, **kwargs):
try:
with path.open("rb") as fd:
print(f"Found cached file: {path.name}.")
return pickle.load(fd)
except (FileNotFoundError, EOFError, pickle.UnpicklingError):
result = fun(*args, **kwargs)
with path.open("wb") as fd:
pickle.dump(result, fd)
return result

return wrapper

return decorator
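
Because the arguments are never hashed, the cache is keyed only by the file path: calling the wrapped function with different arguments still returns whatever was pickled first. A minimal sketch of that caveat, using a made-up function name and file name:

@filecache(path=Path("expensive_result.pkl"))
def expensive_fit(n_estimators: int) -> float:
    # Stand-in for a slow computation whose result we want to reuse.
    return float(n_estimators)

expensive_fit(100)  # computed and written to expensive_result.pkl
expensive_fit(200)  # loaded from the cache: still the result for 100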


@filecache(path=Path("adult_data.pkl"))
def load_adult_data():
data_url = (
"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
)

column_names = [
"age",
"workclass",
"fnlwgt",
"education",
"education-num",
"marital-status",
"occupation",
"relationship",
"race",
"sex",
"capital-gain",
"capital-loss",
"hours-per-week",
"native-country",
"income",
]

data_types = {
"age": int,
"workclass": "category",
"fnlwgt": int,
"education": "category",
"education-num": int,
"marital-status": "category",
"occupation": "category",
"relationship": "category",
"race": "category",
"sex": "category",
"capital-gain": int,
"capital-loss": int,
"hours-per-week": int,
"native-country": "category",
"income": "category",
}

data_adult = pd.read_csv(
data_url,
names=column_names,
sep=",\s*",
engine="python",
na_values="?",
dtype=data_types,
nrows=2000,
)

# Drop the categorical feature columns, keeping the "income" target
data_adult = data_adult.drop(
columns=[
"workclass",
"education",
"marital-status",
"occupation",
"relationship",
"race",
"sex",
"native-country",
]
)

return data_adult
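
For reference, a minimal usage sketch of the cached loader as it would be called from the notebook (the file name `adult_data.pkl` comes from the decorator above):

df = load_adult_data()  # first call downloads the CSV and pickles the DataFrame
df = load_adult_data()  # second call prints "Found cached file: adult_data.pkl."
print(df.shape)  # expected: 2000 rows, the six numeric columns plus "income"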
86 changes: 81 additions & 5 deletions src/pydvl/reporting/plots.py
@@ -1,14 +1,23 @@
from typing import Any, List, Optional, OrderedDict, Sequence
from functools import partial
from typing import Any, List, Literal, Optional, OrderedDict, Sequence

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
from deprecate import deprecated
from matplotlib.axes import Axes
from numpy.typing import NDArray
from scipy.stats import norm
from scipy.stats import norm, t

from pydvl.value import ValuationResult


@deprecated(
target=None,
deprecated_in="0.7.1",
remove_in="0.9.0",
)
def shaded_mean_std(
data: np.ndarray,
abscissa: Optional[Sequence[Any]] = None,
@@ -21,7 +30,12 @@ def shaded_mean_std(
ax: Optional[Axes] = None,
**kwargs,
) -> Axes:
"""The usual mean \(\pm\) std deviation plot to aggregate runs of experiments.
r"""The usual mean \(\pm\) std deviation plot to aggregate runs of
experiments.
!!! warning "Deprecation notice"
This function is bogus and will be removed in the future in favour of
properly computed confidence intervals.
Args:
data: axis 0 is to be aggregated on (e.g. runs) and axis 1 is the
@@ -59,6 +73,67 @@ def shaded_mean_std(
return ax


def plot_values_ci(
values: ValuationResult,
level: float,
type: Literal["normal", "t", "auto"] = "auto",
mean_color: Optional[str] = "dodgerblue",
shade_color: Optional[str] = "lightblue",
ax: Optional[plt.Axes] = None,
**kwargs,
):
"""Plot values and a confidence interval.
Supported intervals are based on the normal and the t distributions.
Args:
values: The valuation result.
level: The significance level of the interval, e.g. 0.05 for a 95% CI.
type: The type of confidence interval to use. If "auto", uses "normal" if
the minimum number of updates for all indices is greater than 30,
otherwise uses "t".
mean_color: The color of the mean line.
shade_color: The color of the confidence interval.
ax: If passed, axes object into which to insert the figure. Otherwise,
a new figure is created and the axes returned.
**kwargs: Additional arguments to pass to the plot function.
Returns:
The matplotlib axes.
"""

ppfs = {
"normal": norm.ppf,
"t": partial(t.ppf, df=values.counts - 1),
"auto": norm.ppf
if np.min(values.counts) > 30
else partial(t.ppf, df=values.counts - 1),
}

try:
score = ppfs[type](1 - level / 2)
except KeyError:
raise ValueError(
f"Unknown confidence interval type requested: {type}."
) from None

abscissa = np.arange(len(values))
bound = score * values.stderr

if ax is None:
fig, ax = plt.subplots()

ax.fill_between(
abscissa,
values.values - bound,
values.values + bound,
alpha=0.3,
color=shade_color,
)
ax.plot(abscissa, values.values, color=mean_color, **kwargs)
return ax
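
A minimal usage sketch, assuming `result` is a `ValuationResult` produced by a valuation run (e.g. Data-OOB in the notebook above); extra keyword arguments such as `lw` are forwarded to `ax.plot` via `**kwargs`, and the shaded band is values ± q · stderr with q the normal or t quantile at 1 - level/2:

fig, ax = plt.subplots(figsize=(8, 4))
plot_values_ci(result, level=0.05, type="auto", ax=ax, lw=0.5)
ax.set_xlabel("sample index")
ax.set_ylabel("value")
fig.tight_layout()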


def spearman_correlation(vv: List[OrderedDict], num_values: int, pvalue: float):
"""Simple matrix plots with spearman correlation for each pair in vv.
@@ -108,8 +183,9 @@ def plot_shapley(
ylabel: Optional[str] = None,
) -> plt.Axes:
r"""Plots the shapley values, as returned from
[compute_shapley_values][pydvl.value.shapley.common.compute_shapley_values], with error bars
corresponding to an $\alpha$-level confidence interval.
[compute_shapley_values][pydvl.value.shapley.common.compute_shapley_values],
with error bars corresponding to an $\alpha$-level Normal confidence
interval.
Args:
df: dataframe with the shapley values
