From cafa32a117a76cb47d57c70c9460505c93f75d3d Mon Sep 17 00:00:00 2001
From: Kristof Schroeder
Date: Fri, 3 May 2024 01:28:19 +0200
Subject: [PATCH] Fix missing move to model device for EkfacInfluence implementation

---
 .../influence/torch/influence_function_model.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/src/pydvl/influence/torch/influence_function_model.py b/src/pydvl/influence/torch/influence_function_model.py
index 46a5fa16e..4a6cb638c 100644
--- a/src/pydvl/influence/torch/influence_function_model.py
+++ b/src/pydvl/influence/torch/influence_function_model.py
@@ -1195,7 +1195,7 @@ def _get_kfac_blocks(
             data, disable=not self.progress, desc="K-FAC blocks - batch progress"
         ):
             data_len += x.shape[0]
-            pred_y = self.model(x)
+            pred_y = self.model(x.to(self.model_device))
             loss = empirical_cross_entropy_loss_fn(pred_y)
             loss.backward()

@@ -1319,7 +1319,7 @@ def _update_diag(
             data, disable=not self.progress, desc="Update Diagonal - batch progress"
         ):
             data_len += x.shape[0]
-            pred_y = self.model(x)
+            pred_y = self.model(x.to(self.model_device))
             loss = empirical_cross_entropy_loss_fn(pred_y)
             loss.backward()

@@ -1526,7 +1526,10 @@ def influences_from_factors_by_layer(
             influences = {}
             for layer_id, layer_z_test in z_test_factors.items():
                 end_idx = start_idx + layer_z_test.shape[1]
-                influences[layer_id] = layer_z_test @ total_grad[:, start_idx:end_idx].T
+                influences[layer_id] = (
+                    layer_z_test.to(self.model_device)
+                    @ total_grad[:, start_idx:end_idx].T
+                )
                 start_idx = end_idx
             return influences
         elif mode == InfluenceMode.Perturbation:
@@ -1539,7 +1542,7 @@ def influences_from_factors_by_layer(
                 end_idx = start_idx + layer_z_test.shape[1]
                 influences[layer_id] = torch.einsum(
                     "ia,j...a->ij...",
-                    layer_z_test,
+                    layer_z_test.to(self.model_device),
                     total_mixed_grad[:, start_idx:end_idx],
                 )
                 start_idx = end_idx
@@ -1626,7 +1629,7 @@ def explore_hessian_regularization(
             being dictionaries containing the influences for each layer of the model,
             with the layer name as key.
         """
-        grad = self._loss_grad(x, y)
+        grad = self._loss_grad(x.to(self.model_device), y.to(self.model_device))
         influences_by_reg_value = {}
         for reg_value in regularization_values:
             reg_factors = self._solve_hvp_by_layer(
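
For context, every hunk in this patch applies the same pattern: batches produced by a DataLoader live on the CPU by default, so they must be moved to the device holding the model's parameters before the forward pass, otherwise PyTorch raises a device-mismatch RuntimeError when the model sits on CUDA. Below is a minimal, self-contained sketch of that pattern; SmallNet and run_batches are hypothetical stand-ins rather than pyDVL code, and the local model_device variable only mirrors the self.model_device attribute referenced in the diff.

import torch
from torch.utils.data import DataLoader, TensorDataset


class SmallNet(torch.nn.Module):
    """Tiny illustrative model, not part of pyDVL."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


def run_batches(model: torch.nn.Module, data: DataLoader) -> None:
    # Resolve the model's device from its parameters, analogous to the
    # self.model_device attribute used in the patch above.
    model_device = next(model.parameters()).device
    for x, y in data:
        # Without .to(model_device) a CUDA model would fail with a
        # device-mismatch RuntimeError when fed these CPU tensors.
        pred_y = model(x.to(model_device))
        loss = torch.nn.functional.cross_entropy(pred_y, y.to(model_device))
        loss.backward()


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SmallNet().to(device)
    dataset = TensorDataset(torch.randn(8, 4), torch.randint(0, 2, (8,)))
    run_batches(model, DataLoader(dataset, batch_size=4))

Resolving the device from the model at call time, as sketched here and as the patch does via self.model_device, keeps the influence computation independent of how the caller prepared the data; an alternative would be to require callers to move batches themselves, but that pushes the burden onto every call site.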