Skip to content

Commit

Permalink
Refactor aggregators call interface to take sequence objects instead …
Browse files Browse the repository at this point in the history
…of generators and optional int paramter, adapt docstrings
  • Loading branch information
schroedk committed May 2, 2024
1 parent 4c95713 commit 0c81e53
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 58 deletions.
62 changes: 25 additions & 37 deletions src/pydvl/influence/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
(chunked in one resp. two dimensions), with support for efficient storage and retrieval
using the Zarr library.
"""
from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import (
Expand Down Expand Up @@ -47,11 +49,10 @@ class SequenceAggregator(Generic[TensorType], ABC):
@abstractmethod
def __call__(
self,
tensor_generator: Generator[TensorType, None, None],
len_generator: Optional[int] = None,
tensor_sequence: LazyChunkSequence,
):
"""
Aggregates tensors from a generator.
Aggregates tensors from a sequence.
Implement this method to define how a sequence of tensors, provided by a
generator, should be combined.
Expand All @@ -61,44 +62,36 @@ def __call__(
class ListAggregator(SequenceAggregator):
def __call__(
self,
tensor_generator: Generator[TensorType, None, None],
len_generator: Optional[int] = None,
tensor_sequence: LazyChunkSequence,
) -> List[TensorType]:
"""
Aggregates tensors from a single-level generator into a list. This method simply
collects each tensor emitted by the generator into a single list.
Args:
tensor_generator: A generator that yields TensorType objects.
len_generator: if the number of elements from the generator is
known, this optional parameter can be used to improve logging
by adding a progressbar.
tensor_sequence: Object wrapping a generator that yields `TensorType`
objects.
Returns:
A list containing all the tensors provided by the tensor_generator.
"""

gen = cast(Iterator[TensorType], tensor_generator)
gen = cast(Iterator[TensorType], tensor_sequence.generator_factory())

if len_generator is not None:
if tensor_sequence.len_generator is not None:
gen = cast(
Iterator[TensorType], tqdm(gen, total=len_generator, desc="Blocks")
Iterator[TensorType],
tqdm(gen, total=tensor_sequence.len_generator, desc="Blocks"),
)

return [t for t in gen]


class NestedSequenceAggregator(Generic[TensorType], ABC):
@abstractmethod
def __call__(
self,
nested_generators_of_tensors: Generator[
Generator[TensorType, None, None], None, None
],
len_outer_generator: Optional[int] = None,
):
def __call__(self, nested_sequence_of_tensors: NestedLazyChunkSequence):
"""
Aggregates tensors from a generator of generators.
Aggregates tensors from a nested sequence of tensors.
Implement this method to specify how tensors, nested in two layers of
generators, should be combined. Useful for complex data structures where tensors
Expand All @@ -109,33 +102,30 @@ def __call__(
class NestedListAggregator(NestedSequenceAggregator):
def __call__(
self,
nested_generators_of_tensors: Generator[
Generator[TensorType, None, None], None, None
],
len_outer_generator: Optional[int] = None,
nested_sequence_of_tensors: NestedLazyChunkSequence,
) -> List[List[TensorType]]:
"""
Aggregates tensors from a nested generator structure into a list of lists.
Each inner generator is converted into a list of tensors, resulting in a nested
list structure.
Args:
nested_generators_of_tensors: A generator of generators, where each inner
generator yields TensorType objects.
len_outer_generator: if the number of elements from the outer generator is
known from the context, this optional parameter can be used to improve
logging by adding a progressbar.
nested_sequence_of_tensors: Object wrapping a generator of generators,
where each inner generator yields TensorType objects.
Returns:
A list of lists, where each inner list contains tensors returned from one
of the inner generators.
"""
outer_gen = cast(Iterator[Iterator[TensorType]], nested_generators_of_tensors)

if len_outer_generator is not None:
outer_gen = cast(
Iterator[Iterator[TensorType]],
nested_sequence_of_tensors.generator_factory(),
)
len_outer_gen = nested_sequence_of_tensors.len_outer_generator
if len_outer_gen is not None:
outer_gen = cast(
Iterator[Iterator[TensorType]],
tqdm(outer_gen, total=len_outer_generator, desc="Row blocks"),
tqdm(outer_gen, total=len_outer_gen, desc="Row blocks"),
)

return [list(tensor_gen) for tensor_gen in outer_gen]
Expand Down Expand Up @@ -186,7 +176,7 @@ def compute(self, aggregator: Optional[SequenceAggregator] = None):
"""
if aggregator is None:
aggregator = ListAggregator()
return aggregator(self.generator_factory(), len_generator=self.len_generator)
return aggregator(self)

@log_duration(log_level=logging.INFO)
def to_zarr(
Expand Down Expand Up @@ -306,9 +296,7 @@ def compute(self, aggregator: Optional[NestedSequenceAggregator] = None):
"""
if aggregator is None:
aggregator = NestedListAggregator()
return aggregator(
self.generator_factory(), len_outer_generator=self.len_outer_generator
)
return aggregator(self)

@log_duration(log_level=logging.INFO)
def to_zarr(
Expand Down
41 changes: 20 additions & 21 deletions src/pydvl/influence/torch/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@
from torch.utils.data import Dataset
from tqdm import tqdm

from ..array import NestedSequenceAggregator, NumpyConverter, SequenceAggregator
from ..array import (
LazyChunkSequence,
NestedLazyChunkSequence,
NestedSequenceAggregator,
NumpyConverter,
SequenceAggregator,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -405,26 +411,23 @@ class TorchCatAggregator(SequenceAggregator[torch.Tensor]):

def __call__(
self,
tensor_generator: Generator[torch.Tensor, None, None],
len_generator: Optional[int] = None,
tensor_sequence: LazyChunkSequence[torch.Tensor],
):
"""
Aggregates tensors from a single-level generator into a single tensor by
concatenating them. This method is a straightforward way to combine a sequence
of tensors into one larger tensor.
Args:
tensor_generator: A generator that yields `torch.Tensor` objects.
len_generator: if the number of elements from the generator is
known, this optional parameter can be used to improve logging
by adding a progressbar.
tensor_sequence: Object wrapping a generator that yields `torch.Tensor`
objects.
Returns:
A single tensor formed by concatenating all tensors from the generator.
The concatenation is performed along the default dimension (0).
"""
t_gen = cast(Iterator[torch.Tensor], tensor_generator)

t_gen = cast(Iterator[torch.Tensor], tensor_sequence.generator_factory())
len_generator = tensor_sequence.len_generator
if len_generator is not None:
t_gen = cast(
Iterator[torch.Tensor], tqdm(t_gen, total=len_generator, desc="Blocks")
Expand All @@ -440,11 +443,7 @@ class NestedTorchCatAggregator(NestedSequenceAggregator[torch.Tensor]):
"""

def __call__(
self,
nested_generators_of_tensors: Generator[
Generator[torch.Tensor, None, None], None, None
],
len_outer_generator: Optional[int] = None,
self, nested_sequence_of_tensors: NestedLazyChunkSequence[torch.Tensor]
):
"""
Aggregates tensors from a nested generator structure into a single tensor by
Expand All @@ -453,20 +452,20 @@ def __call__(
form the final tensor.
Args:
nested_generators_of_tensors: A generator of generators, where each inner
generator yields `torch.Tensor` objects.
len_outer_generator: if the number of elements from the outer generator is
known from the context, this optional parameter can be used to improve
logging by adding a progressbar.
nested_sequence_of_tensors: Object wrapping a generator of generators,
where each inner generator yields `torch.Tensor` objects.
Returns:
A single tensor formed by concatenating all tensors from the nested
generators.
"""

outer_gen = cast(Iterator[Iterator[torch.Tensor]], nested_generators_of_tensors)

outer_gen = cast(
Iterator[Iterator[torch.Tensor]],
nested_sequence_of_tensors.generator_factory(),
)
len_outer_generator = nested_sequence_of_tensors.len_outer_generator
if len_outer_generator is not None:
outer_gen = cast(
Iterator[Iterator[torch.Tensor]],
Expand Down

0 comments on commit 0c81e53

Please sign in to comment.