Skip to content

Commit

Permalink
Fix comments and typos.
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus Semmler committed Sep 1, 2023
1 parent c28ce5e commit 6dee93d
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 46 deletions.
28 changes: 18 additions & 10 deletions src/pydvl/utils/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,43 @@

import inspect
from functools import partial
from typing import Callable, Set, Tuple
from typing import Callable, Set, Tuple, Union

__all__ = ["unroll_partial_fn_args"]
__all__ = ["get_free_args_fn"]


def unroll_partial_fn_args(fun: Callable) -> Set[str]:
def get_free_args_fn(fun: Union[Callable, partial]) -> Set[str]:
"""
Unroll a function that was set by functools.partial.
Accept a function or partial definition and return the set of arguments that are
free. An argument is free if it is not set by the partial and is a parameter of the
function.
Args:
fun: Either or a function to unroll.
fun: A partial or a function to unroll.
Returns:
A tuple of the unrolled function and a set of the parameters that were set by
functools.partial.
A set of arguments that were set by the partial.
"""
args_set_by_partial: Set[str] = set()

def _rec_unroll_partial_function(g: Callable):
def _rec_unroll_partial_function(g: Union[Callable, partial]) -> Callable:
"""
Store arguments and recursively call itself if the function is a partial. In the
end, return the original function.
end, return the initial wrapped function.
Args:
g: A partial or a function to unroll.
Returns:
Initial wrapped function.
"""
nonlocal args_set_by_partial

if isinstance(g, partial):
args_set_by_partial.update(g.keywords.keys())
args_set_by_partial.update(g.args)
return _rec_unroll_partial_function(g.keywords["fun"])
inner_fn = g.keywords["fn"] if "fn" in g.keywords else g.func
return _rec_unroll_partial_function(inner_fn)
else:
return g

Expand Down
2 changes: 0 additions & 2 deletions src/pydvl/utils/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,6 @@ def random_powerset(
ValueError: if the element sampling probability is not in [0,1]
"""
if not isinstance(s, np.ndarray):
raise TypeError("Set must be an NDArray")
if q < 0 or q > 1:
raise ValueError("Element sampling probability must be in [0,1]")

Expand Down
37 changes: 9 additions & 28 deletions src/pydvl/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@

import functools
from abc import ABCMeta
from copy import deepcopy
from typing import Any, Callable, Optional, Protocol, Tuple, TypeVar, Union, cast
from typing import Any, Callable, Optional, Protocol, TypeVar, Union, cast

from numpy.random import Generator, SeedSequence
from numpy.typing import NDArray

from pydvl.utils.functional import unroll_partial_fn_args
from pydvl.utils.functional import get_free_args_fn

__all__ = ["SupervisedModel", "MapFunction", "ReduceFunction", "NoPublicConstructor"]

Expand Down Expand Up @@ -46,13 +45,14 @@ def score(self, x: NDArray, y: NDArray) -> float:
pass


def call_fun_remove_arg(*args, fun: Callable, arg: str, **kwargs):
def call_fun_remove_arg(*args, fn: Callable, arg: str, **kwargs):
"""
Calls the given function with the given arguments, but removes the given argument.
Calls the given function with the given arguments. In the process it removes the
specified keyword argument from the keyword arguments.
Args:
args: Positional arguments to pass to the function.
fun: The function to call.
fn: The function to call.
arg: The name of the argument to remove.
kwargs: Keyword arguments to pass to the function.
Expand All @@ -64,7 +64,7 @@ def call_fun_remove_arg(*args, fun: Callable, arg: str, **kwargs):
except KeyError:
pass

return fun(*args, **kwargs)
return fn(*args, **kwargs)


def maybe_add_argument(fun: Callable, new_arg: str) -> Callable:
Expand All @@ -83,10 +83,10 @@ def maybe_add_argument(fun: Callable, new_arg: str) -> Callable:
Returns:
A new function accepting one more keyword argument.
"""
if new_arg in unroll_partial_fn_args(fun):
if new_arg in get_free_args_fn(fun):
return fun

return functools.partial(call_fun_remove_arg, fun=fun, arg=new_arg)
return functools.partial(call_fun_remove_arg, fn=fun, arg=new_arg)


class NoPublicConstructor(ABCMeta):
Expand Down Expand Up @@ -137,22 +137,3 @@ def ensure_seed_sequence(
return cast(SeedSequence, seed.bit_generator.seed_seq) # type: ignore
else:
return SeedSequence(seed)


def call_fn_multiple_seeds(
fn: Callable, *args, seeds: Tuple[Seed, ...], **kwargs
) -> Tuple:
"""
Execute a function multiple times with different seeds. It copies the arguments
and keyword arguments before passing them to the function.
Args:
fn: The function to execute.
args: The arguments to pass to the function.
seeds: The seeds to use.
kwargs: The keyword arguments to pass to the function.
Returns:
A tuple of the results of the function.
"""
return tuple(fn(*deepcopy(args), **deepcopy(kwargs), seed=seed) for seed in seeds)
3 changes: 1 addition & 2 deletions src/pydvl/value/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,8 +695,7 @@ def from_random(
("efficiency" property of Shapley values).
kwargs: Additional options to pass to the constructor of
[ValuationResult][pydvl.value.result.ValuationResult]. Use to override status, names, etc.
seed: ither an instance of a numpy random number generator or a seed for
it.
Returns:
A valuation result with its status set to
[Status.Converged][pydvl.utils.status.Status] by default.
Expand Down
2 changes: 1 addition & 1 deletion src/pydvl/value/shapley/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def compute_shapley_values(
others require specific subtypes.
n_jobs: Number of parallel jobs (available only to some methods)
seed: Either an instance of a numpy random number generator or a seed
for it.
for it.
mode: Choose which shapley algorithm to use. See
[ShapleyMode][pydvl.value.shapley.ShapleyMode] for a list of allowed value.
Expand Down
4 changes: 2 additions & 2 deletions src/pydvl/value/shapley/gt.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def _group_testing_shapley(
progress: Whether to display progress bars for each job.
job_id: id to use for reporting progress (e.g. to place progres bars)
seed: Either an instance of a numpy random number generator or a seed
for it.
for it.
Returns:
"""
Expand Down Expand Up @@ -203,7 +203,7 @@ def group_testing_shapley(
address, number of cpus, etc.
progress: Whether to display progress bars for each job.
seed: Either an instance of a numpy random number generator or a seed
for it.
for it.
options: Additional options to pass to
[cvxpy.Problem.solve()](https://www.cvxpy.org/tutorial/advanced/index.html#solve-method-options).
E.g. to change the solver (which defaults to `cvxpy.SCS`) pass
Expand Down
3 changes: 2 additions & 1 deletion tests/value/shapley/test_montecarlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
)
from pydvl.utils.numeric import num_samples_permutation_hoeffding
from pydvl.utils.score import Scorer, squashed_r2
from pydvl.utils.types import Seed, call_fn_multiple_seeds
from pydvl.utils.types import Seed
from pydvl.value import compute_shapley_values
from pydvl.value.shapley import ShapleyMode
from pydvl.value.shapley.naive import combinatorial_exact_shapley
from pydvl.value.stopping import MaxChecks, MaxUpdates

from .. import check_rank_correlation, check_total_value, check_values
from ..conftest import polynomial_dataset
from ..utils import call_fn_multiple_seeds

log = logging.getLogger(__name__)

Expand Down
25 changes: 25 additions & 0 deletions tests/value/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

from copy import deepcopy
from typing import Callable, Tuple

from pydvl.utils.types import Seed


def call_fn_multiple_seeds(
fn: Callable, *args, seeds: Tuple[Seed, ...], **kwargs
) -> Tuple:
"""
Execute a function multiple times with different seeds. It copies the arguments
and keyword arguments before passing them to the function.
Args:
fn: The function to execute.
args: The arguments to pass to the function.
seeds: The seeds to use.
kwargs: The keyword arguments to pass to the function.
Returns:
A tuple of the results of the function.
"""
return tuple(fn(*deepcopy(args), **deepcopy(kwargs), seed=seed) for seed in seeds)

0 comments on commit 6dee93d

Please sign in to comment.