diff --git a/src/pydvl/influence/array.py b/src/pydvl/influence/array.py index c60e11693..7e71050f9 100644 --- a/src/pydvl/influence/array.py +++ b/src/pydvl/influence/array.py @@ -6,6 +6,8 @@ (chunked in one resp. two dimensions), with support for efficient storage and retrieval using the Zarr library. """ +from __future__ import annotations + import logging from abc import ABC, abstractmethod from typing import ( @@ -47,11 +49,10 @@ class SequenceAggregator(Generic[TensorType], ABC): @abstractmethod def __call__( self, - tensor_generator: Generator[TensorType, None, None], - len_generator: Optional[int] = None, + tensor_sequence: LazyChunkSequence, ): """ - Aggregates tensors from a generator. + Aggregates tensors from a sequence. Implement this method to define how a sequence of tensors, provided by a generator, should be combined. @@ -61,28 +62,26 @@ def __call__( class ListAggregator(SequenceAggregator): def __call__( self, - tensor_generator: Generator[TensorType, None, None], - len_generator: Optional[int] = None, + tensor_sequence: LazyChunkSequence, ) -> List[TensorType]: """ Aggregates tensors from a single-level generator into a list. This method simply collects each tensor emitted by the generator into a single list. Args: - tensor_generator: A generator that yields TensorType objects. - len_generator: if the number of elements from the generator is - known, this optional parameter can be used to improve logging - by adding a progressbar. + tensor_sequence: Object wrapping a generator that yields `TensorType` + objects. Returns: A list containing all the tensors provided by the tensor_generator. """ - gen = cast(Iterator[TensorType], tensor_generator) + gen = cast(Iterator[TensorType], tensor_sequence.generator_factory()) - if len_generator is not None: + if tensor_sequence.len_generator is not None: gen = cast( - Iterator[TensorType], tqdm(gen, total=len_generator, desc="Blocks") + Iterator[TensorType], + tqdm(gen, total=tensor_sequence.len_generator, desc="Blocks"), ) return [t for t in gen] @@ -90,15 +89,9 @@ def __call__( class NestedSequenceAggregator(Generic[TensorType], ABC): @abstractmethod - def __call__( - self, - nested_generators_of_tensors: Generator[ - Generator[TensorType, None, None], None, None - ], - len_outer_generator: Optional[int] = None, - ): + def __call__(self, nested_sequence_of_tensors: NestedLazyChunkSequence): """ - Aggregates tensors from a generator of generators. + Aggregates tensors from a nested sequence of tensors. Implement this method to specify how tensors, nested in two layers of generators, should be combined. Useful for complex data structures where tensors @@ -109,10 +102,7 @@ def __call__( class NestedListAggregator(NestedSequenceAggregator): def __call__( self, - nested_generators_of_tensors: Generator[ - Generator[TensorType, None, None], None, None - ], - len_outer_generator: Optional[int] = None, + nested_sequence_of_tensors: NestedLazyChunkSequence, ) -> List[List[TensorType]]: """ Aggregates tensors from a nested generator structure into a list of lists. @@ -120,22 +110,22 @@ def __call__( list structure. Args: - nested_generators_of_tensors: A generator of generators, where each inner - generator yields TensorType objects. - len_outer_generator: if the number of elements from the outer generator is - known from the context, this optional parameter can be used to improve - logging by adding a progressbar. + nested_sequence_of_tensors: Object wrapping a generator of generators, + where each inner generator yields TensorType objects. Returns: A list of lists, where each inner list contains tensors returned from one of the inner generators. """ - outer_gen = cast(Iterator[Iterator[TensorType]], nested_generators_of_tensors) - - if len_outer_generator is not None: + outer_gen = cast( + Iterator[Iterator[TensorType]], + nested_sequence_of_tensors.generator_factory(), + ) + len_outer_gen = nested_sequence_of_tensors.len_outer_generator + if len_outer_gen is not None: outer_gen = cast( Iterator[Iterator[TensorType]], - tqdm(outer_gen, total=len_outer_generator, desc="Row blocks"), + tqdm(outer_gen, total=len_outer_gen, desc="Row blocks"), ) return [list(tensor_gen) for tensor_gen in outer_gen] @@ -186,7 +176,7 @@ def compute(self, aggregator: Optional[SequenceAggregator] = None): """ if aggregator is None: aggregator = ListAggregator() - return aggregator(self.generator_factory(), len_generator=self.len_generator) + return aggregator(self) @log_duration(log_level=logging.INFO) def to_zarr( @@ -306,9 +296,7 @@ def compute(self, aggregator: Optional[NestedSequenceAggregator] = None): """ if aggregator is None: aggregator = NestedListAggregator() - return aggregator( - self.generator_factory(), len_outer_generator=self.len_outer_generator - ) + return aggregator(self) @log_duration(log_level=logging.INFO) def to_zarr( diff --git a/src/pydvl/influence/torch/util.py b/src/pydvl/influence/torch/util.py index 581894af2..17813421b 100644 --- a/src/pydvl/influence/torch/util.py +++ b/src/pydvl/influence/torch/util.py @@ -25,7 +25,13 @@ from torch.utils.data import Dataset from tqdm import tqdm -from ..array import NestedSequenceAggregator, NumpyConverter, SequenceAggregator +from ..array import ( + LazyChunkSequence, + NestedLazyChunkSequence, + NestedSequenceAggregator, + NumpyConverter, + SequenceAggregator, +) logger = logging.getLogger(__name__) @@ -405,8 +411,7 @@ class TorchCatAggregator(SequenceAggregator[torch.Tensor]): def __call__( self, - tensor_generator: Generator[torch.Tensor, None, None], - len_generator: Optional[int] = None, + tensor_sequence: LazyChunkSequence[torch.Tensor], ): """ Aggregates tensors from a single-level generator into a single tensor by @@ -414,17 +419,15 @@ def __call__( of tensors into one larger tensor. Args: - tensor_generator: A generator that yields `torch.Tensor` objects. - len_generator: if the number of elements from the generator is - known, this optional parameter can be used to improve logging - by adding a progressbar. + tensor_sequence: Object wrapping a generator that yields `torch.Tensor` + objects. Returns: A single tensor formed by concatenating all tensors from the generator. The concatenation is performed along the default dimension (0). """ - t_gen = cast(Iterator[torch.Tensor], tensor_generator) - + t_gen = cast(Iterator[torch.Tensor], tensor_sequence.generator_factory()) + len_generator = tensor_sequence.len_generator if len_generator is not None: t_gen = cast( Iterator[torch.Tensor], tqdm(t_gen, total=len_generator, desc="Blocks") @@ -440,11 +443,7 @@ class NestedTorchCatAggregator(NestedSequenceAggregator[torch.Tensor]): """ def __call__( - self, - nested_generators_of_tensors: Generator[ - Generator[torch.Tensor, None, None], None, None - ], - len_outer_generator: Optional[int] = None, + self, nested_sequence_of_tensors: NestedLazyChunkSequence[torch.Tensor] ): """ Aggregates tensors from a nested generator structure into a single tensor by @@ -453,11 +452,8 @@ def __call__( form the final tensor. Args: - nested_generators_of_tensors: A generator of generators, where each inner - generator yields `torch.Tensor` objects. - len_outer_generator: if the number of elements from the outer generator is - known from the context, this optional parameter can be used to improve - logging by adding a progressbar. + nested_sequence_of_tensors: Object wrapping a generator of generators, + where each inner generator yields `torch.Tensor` objects. Returns: A single tensor formed by concatenating all tensors from the nested @@ -465,8 +461,11 @@ def __call__( """ - outer_gen = cast(Iterator[Iterator[torch.Tensor]], nested_generators_of_tensors) - + outer_gen = cast( + Iterator[Iterator[torch.Tensor]], + nested_sequence_of_tensors.generator_factory(), + ) + len_outer_generator = nested_sequence_of_tensors.len_outer_generator if len_outer_generator is not None: outer_gen = cast( Iterator[Iterator[torch.Tensor]],