From df28f5bb24fd9428a4775a431a8249029ad9aec0 Mon Sep 17 00:00:00 2001
From: Yuanshuo Cui
Date: Mon, 12 Aug 2024 10:42:53 -0700
Subject: [PATCH] Fix async test set up (#1314)

Summary:
Pull Request resolved: https://github.com/pytorch/captum/pull/1314

Follow up of D56764316 to properly handle async test case.

Reviewed By: cyrjano

Differential Revision: D60151017

fbshipit-source-id: 52714e422d98cf28c4e46793ce42611bb1bdb8b9
---
 captum/attr/_core/feature_ablation.py | 4 ++--
 tests/attr/test_feature_ablation.py   | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py
index b94879ec9..32f11d6ae 100644
--- a/captum/attr/_core/feature_ablation.py
+++ b/captum/attr/_core/feature_ablation.py
@@ -831,12 +831,12 @@ def _accumulate_for_single_input(
         weight: List[Tensor],
     ) -> None:
         if total_attrib:
-            total_attrib[idx] += attrib
+            total_attrib[idx] = attrib[idx]
         else:
             total_attrib.extend(attrib)
         if self.use_weights:
             if weights:
-                weights[idx] += weight
+                weights[idx] = weight[idx]
             else:
                 weights.extend(weight)

diff --git a/tests/attr/test_feature_ablation.py b/tests/attr/test_feature_ablation.py
index 9fb3af66d..9f05a1cb3 100644
--- a/tests/attr/test_feature_ablation.py
+++ b/tests/attr/test_feature_ablation.py
@@ -472,14 +472,14 @@ def forward_func(inp):
         abl = FeatureAblation(forward_func)
         abl.use_futures = True
-        inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True)
+        inp = torch.tensor([[20.0, 50.0, 30.0], [10.0, 40.0, 20.0]], requires_grad=True)
         self._ablation_test_assert(
             ablation_algo=abl,
             test_input=inp,
             baselines=None,
             target=0,
             perturbations_per_eval=(1,),
-            expected_ablation=torch.tensor([[80.0, 200.0, 120.0]]),
+            expected_ablation=torch.tensor([[80.0, 200.0, 120.0], [40.0, 160.0, 80.0]]),
         )

     def test_unassociated_output_3d_tensor(self) -> None: