Skip to content

Commit

Permalink
Add log_duration decorator with level INFO to lazy object compute and…
Browse files Browse the repository at this point in the history
… to_zarr methods
  • Loading branch information
schroedk committed Apr 21, 2024
1 parent f09f020 commit 898cdc1
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/pydvl/influence/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
using the Zarr library.
"""

import logging
from abc import ABC, abstractmethod
from typing import Callable, Generator, Generic, List, Optional, Tuple, Union

import zarr
from numpy.typing import NDArray
from zarr.storage import StoreLike

from ..utils import log_duration
from .base_influence_function_model import TensorType


Expand Down Expand Up @@ -119,6 +121,7 @@ def __init__(
):
self.generator_factory = generator_factory

@log_duration(log_level=logging.INFO)
def compute(self, aggregator: Optional[SequenceAggregator] = None):
"""
Computes and optionally aggregates the chunks of the array using the provided
Expand All @@ -139,6 +142,7 @@ def compute(self, aggregator: Optional[SequenceAggregator] = None):
aggregator = ListAggregator()
return aggregator(self.generator_factory())

@log_duration(log_level=logging.INFO)
def to_zarr(
self,
path_or_url: Union[str, StoreLike],
Expand Down Expand Up @@ -223,6 +227,7 @@ def __init__(
):
self.generator_factory = generator_factory

@log_duration(log_level=logging.INFO)
def compute(self, aggregator: Optional[NestedSequenceAggregator] = None):
"""
Computes and optionally aggregates the chunks of the array using the provided
Expand All @@ -244,6 +249,7 @@ def compute(self, aggregator: Optional[NestedSequenceAggregator] = None):
aggregator = NestedListAggregator()
return aggregator(self.generator_factory())

@log_duration(log_level=logging.INFO)
def to_zarr(
self,
path_or_url: Union[str, StoreLike],
Expand Down

0 comments on commit 898cdc1

Please sign in to comment.