Skip to content

Commit

Permalink
Refactor n_parameters of TorchComposableInfluence to be callable befo…
Browse files Browse the repository at this point in the history
…re fit
  • Loading branch information
schroedk committed Jun 13, 2024
1 parent 6e20ec5 commit ef32847
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/pydvl/influence/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,11 @@ def block_names(self) -> List[str]:

@property
def n_parameters(self):
return sum(block.op.input_size for _, block in self.block_mapper.items())
return sum(
param.numel()
for block in self.parameter_dict.values()
for param in block.values()
)

@abstractmethod
def with_regularization(
Expand Down

0 comments on commit ef32847

Please sign in to comment.