From 11114ab8bee2b0ddcb9f4d5a9378a5da32428ff3 Mon Sep 17 00:00:00 2001
From: redwrasse
Date: Wed, 22 Mar 2023 20:12:19 +0000
Subject: [PATCH] rename to need_attn_weights to match elsewhere (#97102)

Change variable spelling from `need_atten_weights` to `need_attn_weights` to
match the naming convention elsewhere in PyTorch.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97102
Approved by: https://github.com/drisspg
---
 .../src/ATen/native/nested/NestedTensorTransformerFunctions.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h
index 0f623f896d0f3..cee721d7bc8f6 100644
--- a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h
+++ b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h
@@ -89,7 +89,7 @@ TORCH_API Tensor flash_attention_helper(
     const Tensor& key,
     const Tensor& value,
     double dropout_p,
-    bool need_atten_weights,
+    bool need_attn_weights,
     bool is_causal);
 
 TORCH_API std::tuple mem_efficient_helper_nested_unpacked(
@@ -97,7 +97,7 @@ TORCH_API std::tuple mem_efficient_helper_nested_unpacked(
     const Tensor& key,
     const Tensor& value,
     double dropout_p,
-    bool need_atten_weights,
+    bool need_attn_weights,
     bool is_causal);
 } // namespace native
 } // namespace at