diff --git a/src/pydvl/influence/torch/base.py b/src/pydvl/influence/torch/base.py index f4ccbf580..8cfdda401 100644 --- a/src/pydvl/influence/torch/base.py +++ b/src/pydvl/influence/torch/base.py @@ -724,7 +724,11 @@ def block_names(self) -> List[str]: @property def n_parameters(self): - return sum(block.op.input_size for _, block in self.block_mapper.items()) + return sum( + param.numel() + for block in self.parameter_dict.values() + for param in block.values() + ) @abstractmethod def with_regularization(