Skip to content

Commit

Permalink
Merge pull request #616 from aai-institute/feature/refactor-classwise…
Browse files Browse the repository at this point in the history
…-shapley

Refactor Classwise Shapley
  • Loading branch information
schroedk authored Aug 21, 2024
2 parents 6353dc4 + 37acff9 commit 85025c0
Show file tree
Hide file tree
Showing 19 changed files with 922 additions and 167 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### Added

- Refactor Classwise Shapley valuation with the interfaces and sampler architecture [PR #616](https://github.com/aai-institute/pyDVL/pull/616).
- Refactoring KNN Shapley values with the new sampler architecture [PR #610](https://github.com/aai-institute/pyDVL/pull/610).
- Refactoring MSR Banzhaf semivalues with the new sampler architecture.
[PR #605](https://github.com/aai-institute/pyDVL/pull/605)
Expand Down
168 changes: 71 additions & 97 deletions src/pydvl/valuation/methods/classwise_shapley.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,145 +60,115 @@
from __future__ import annotations

import logging
from typing import Callable, Generator
from typing import Any, TypeVar

import numpy as np
from joblib import Parallel, delayed
from numpy.typing import NDArray

from pydvl.utils.progress import Progress
from pydvl.utils.types import SupervisedModel
from pydvl.valuation.base import Valuation
from pydvl.valuation.dataset import Dataset
from pydvl.valuation.result import ValuationResult
from pydvl.valuation.samplers import IndexSampler, PowersetSampler
from pydvl.valuation.samplers.base import EvaluationStrategy
from pydvl.valuation.samplers.powerset import NoIndexIteration
from pydvl.valuation.samplers.classwise import ClasswiseSampler, get_unique_labels
from pydvl.valuation.scorers.classwise import ClasswiseSupervisedScorer
from pydvl.valuation.stopping import StoppingCriterion
from pydvl.valuation.types import BatchGenerator, IndexSetT
from pydvl.valuation.utility.base import UtilityBase
from pydvl.valuation.utility.classwise import CSSample
from pydvl.valuation.utility.modelutility import ModelUtility
from pydvl.valuation.utility.classwise import ClasswiseModelUtility
from pydvl.valuation.utils import (
ensure_backend_has_generator_return,
make_parallel_flag,
)

__all__ = ["ClasswiseShapley"]
__all__ = ["ClasswiseShapleyValuation"]

logger = logging.getLogger(__name__)

T = TypeVar("T")

def unique_labels(array: NDArray) -> NDArray:
"""Labels of the dataset."""
# Object, String, Unicode, Unsigned integer, Signed integer, boolean
if array.dtype.kind in "OSUiub":
return np.unique(array)
raise ValueError("Dataset must be categorical to have unique labels.")

class ClasswiseShapleyValuation(Valuation):
"""Class to compute Class-wise Shapley values.
class ClasswiseSampler(IndexSampler):
def __init__(
self,
in_class: IndexSampler,
out_of_class: PowersetSampler,
label: int | None = None,
):
super().__init__()
self.in_class = in_class
self.out_of_class = out_of_class
self.label = label

def for_label(self, label: int) -> ClasswiseSampler:
return ClasswiseSampler(self.in_class, self.out_of_class, label)

def from_data(self, data: Dataset) -> Generator[list[CSSample], None, None]:
assert self.label is not None

without_label = np.where(data.y != self.label)[0]
with_label = np.where(data.y == self.label)[0]

# HACK: the outer sampler is over full subsets of T_{-y_i}
self.out_of_class._index_iteration = NoIndexIteration

for ooc_batch in self.out_of_class.generate_batches(without_label):
# NOTE: The inner sampler can be a permutation sampler => we need to
# return batches of the same size as that sampler in order for the
# in_class strategy to work correctly.
for ooc_sample in ooc_batch:
for ic_batch in self.in_class.generate_batches(with_label):
# FIXME? this sends the same out_of_class_subset for all samples
# maybe a few 10s of KB... probably irrelevant
yield [
CSSample(
idx=ic_sample.idx,
label=self.label,
subset=ooc_sample.subset,
in_class_subset=ic_sample.subset,
)
for ic_sample in ic_batch
]

def generate_batches(self, indices: IndexSetT) -> BatchGenerator:
raise AttributeError("Cannot sample from indices directly.")

def make_strategy(
self,
utility: UtilityBase,
coefficient: Callable[[int, int], float] | None = None,
) -> EvaluationStrategy[IndexSampler]:
return self.in_class.make_strategy(utility, coefficient)
It proceeds by sampling independent permutations of the index set
for each label and index sets sampled from the powerset of the complement
(with respect to the currently evaluated label).
Args:
utility: Classwise utility object with model and classwise scoring function.
sampler: Classwise sampling scheme to use.
is_done: Stopping criterion to use.
progress: Whether to show a progress bar.
normalize_values: Whether to normalize values after valuation.
"""

algorithm_name = "Classwise-Shapley"

class ClasswiseShapley(Valuation):
def __init__(
self,
utility: ModelUtility[CSSample, SupervisedModel],
utility: ClasswiseModelUtility,
sampler: ClasswiseSampler,
is_done: StoppingCriterion,
progress: bool = False,
progress: dict[str, Any] | bool = False,
*,
normalize_values: bool = True,
):
super().__init__()
self.utility = utility
self.sampler = sampler
self.labels: NDArray | None = None
if not isinstance(utility.scorer, ClasswiseSupervisedScorer):
raise ValueError("Scorer must be a ClasswiseScorer.")
raise ValueError("scorer must be an instance of ClasswiseSupervisedScorer")
self.scorer: ClasswiseSupervisedScorer = utility.scorer
self.is_done = is_done
self.progress = progress
self.tqdm_args: dict[str, Any] = {
"desc": f"{self.__class__.__name__}: {str(is_done)}"
}
# HACK: parse additional args for the progress bar if any (we probably want
# something better)
if isinstance(progress, bool):
self.tqdm_args.update({"disable": not progress})
else:
self.tqdm_args.update(progress if isinstance(progress, dict) else {})
self.normalize_values = normalize_values

def fit(self, data: Dataset):
self.result = ValuationResult.zeros(
# TODO: automate str representation for all Valuations
algorithm=f"classwise-shapley",
algorithm=f"{self.__class__.__name__}-{self.utility.__class__.__name__}-{self.sampler.__class__.__name__}-{self.is_done}",
indices=data.indices,
data_names=data.data_names,
)
ensure_backend_has_generator_return()

parallel = Parallel(return_as="generator_unordered")

self.utility.training_data = data
self.labels = unique_labels(data.y)

with make_parallel_flag() as flag:
# FIXME, DUH: this loop needs to be in the sampler or we will never converge
for label in self.labels:
sampler = self.sampler.for_label(label)
strategy = sampler.make_strategy(self.utility)
processor = delayed(strategy.process)

sample_generator = self.sampler.from_data(data)
strategy = self.sampler.make_strategy(self.utility)
processor = delayed(strategy.process)

with Parallel(return_as="generator_unordered") as parallel:
with make_parallel_flag() as flag:
delayed_evals = parallel(
processor(batch=list(batch), is_interrupted=flag)
for batch in sampler.generate_batches(data.indices)
for batch in sample_generator
)
for evaluation in Progress(delayed_evals, self.is_done):
self.result.update(evaluation.idx, evaluation.update)

for batch in Progress(delayed_evals, self.is_done, **self.tqdm_args):
for evaluation in batch:
self.result.update(evaluation.idx, evaluation.update)
if self.is_done(self.result):
flag.set()
self.sampler.interrupt()
break

if self.is_done(self.result):
flag.set()
break

if self.normalize_values:
self._normalize()

return self

def _normalize(self) -> ValuationResult:
r"""
Normalize a valuation result specific to classwise Shapley.
Expand All @@ -211,23 +181,27 @@ def _normalize(self) -> ValuationResult:
Returns:
Normalized ValuationResult object.
"""
u = self.utility
if self.result is None:
raise ValueError("You must call fit before calling _normalize()")

assert self.result is not None
assert self.labels is not None
assert u.training_data is not None
if self.utility.training_data is None:
raise ValueError("You should call fit before calling _normalize()")

logger.info("Normalizing valuation result.")
u.model.fit(u.training_data.x, u.training_data.y)
unique_labels = get_unique_labels(self.utility.training_data.y)
self.utility.model.fit(
self.utility.training_data.x, self.utility.training_data.y
)

for idx_label, label in enumerate(self.labels):
self.scorer.label = label
active_elements = u.training_data.y == label
for idx_label, label in enumerate(unique_labels):
active_elements = self.utility.training_data.y == label
indices_label_set = np.where(active_elements)[0]
indices_label_set = u.training_data.indices[indices_label_set]
indices_label_set = self.utility.training_data.indices[indices_label_set]

self.scorer.label = label
in_class_acc, _ = self.scorer.compute_in_and_out_of_class_scores(u.model)
in_class_acc, _ = self.scorer.compute_in_and_out_of_class_scores(
self.utility.model
)

sigma = np.sum(self.result.values[indices_label_set])
if sigma != 0:
Expand Down
2 changes: 2 additions & 0 deletions src/pydvl/valuation/samplers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def fit(self, data: Dataset):
"""
from typing import Union

from .base import *
from .classwise import *
from .msr import *
from .permutation import *
from .powerset import *
Expand Down
5 changes: 2 additions & 3 deletions src/pydvl/valuation/samplers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def generate_batches(self, indices: IndexSetT) -> BatchGenerator:

# create an empty generator if the indices are empty. `generate_batches` is
# a generator function because it has a yield statement later in its body.
# Inside generator functionn, `return` acts like a `break`, which produces an
# Inside generator function, `return` acts like a `break`, which produces an
# empty generator function. See: https://stackoverflow.com/a/13243870
if len(indices) == 0:
return
Expand All @@ -125,8 +125,7 @@ def generate_batches(self, indices: IndexSetT) -> BatchGenerator:
self._n_samples = 0
for batch in chunked(self._generate(indices), self.batch_size):
yield batch
# FIXME, BUG: this could be wrong if the batch is not full. Just use lists
self._n_samples += self.batch_size
self._n_samples += len(batch)
if self._interrupted:
break

Expand Down
Loading

0 comments on commit 85025c0

Please sign in to comment.