Skip to content

Commit

Permalink
Fix missing move to model device for EkfacInfluence implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
schroedk committed May 2, 2024
1 parent 5327cac commit cafa32a
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions src/pydvl/influence/torch/influence_function_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,7 +1195,7 @@ def _get_kfac_blocks(
data, disable=not self.progress, desc="K-FAC blocks - batch progress"
):
data_len += x.shape[0]
pred_y = self.model(x)
pred_y = self.model(x.to(self.model_device))
loss = empirical_cross_entropy_loss_fn(pred_y)
loss.backward()

Expand Down Expand Up @@ -1319,7 +1319,7 @@ def _update_diag(
data, disable=not self.progress, desc="Update Diagonal - batch progress"
):
data_len += x.shape[0]
pred_y = self.model(x)
pred_y = self.model(x.to(self.model_device))
loss = empirical_cross_entropy_loss_fn(pred_y)
loss.backward()

Expand Down Expand Up @@ -1526,7 +1526,10 @@ def influences_from_factors_by_layer(
influences = {}
for layer_id, layer_z_test in z_test_factors.items():
end_idx = start_idx + layer_z_test.shape[1]
influences[layer_id] = layer_z_test @ total_grad[:, start_idx:end_idx].T
influences[layer_id] = (
layer_z_test.to(self.model_device)
@ total_grad[:, start_idx:end_idx].T
)
start_idx = end_idx
return influences
elif mode == InfluenceMode.Perturbation:
Expand All @@ -1539,7 +1542,7 @@ def influences_from_factors_by_layer(
end_idx = start_idx + layer_z_test.shape[1]
influences[layer_id] = torch.einsum(
"ia,j...a->ij...",
layer_z_test,
layer_z_test.to(self.model_device),
total_mixed_grad[:, start_idx:end_idx],
)
start_idx = end_idx
Expand Down Expand Up @@ -1626,7 +1629,7 @@ def explore_hessian_regularization(
being dictionaries containing the influences for each layer of the model,
with the layer name as key.
"""
grad = self._loss_grad(x, y)
grad = self._loss_grad(x.to(self.model_device), y.to(self.model_device))
influences_by_reg_value = {}
for reg_value in regularization_values:
reg_factors = self._solve_hvp_by_layer(
Expand Down

0 comments on commit cafa32a

Please sign in to comment.