From 4ec9edfbde6b640a2f7c9228876daff85e202ebd Mon Sep 17 00:00:00 2001
From: Can Balioglu
Date: Fri, 3 Nov 2023 21:21:52 -0400
Subject: [PATCH] Fixes PyTorch 2.1 compatibility issues (#132)

---
 fairseq2n/CMakeLists.txt                    |  2 +-
 src/fairseq2/optim/optimizer_base.py        |  7 ++---
 tests/unit/nn/transformer/test_attention.py | 33 ++++++++-------------
 3 files changed, 17 insertions(+), 25 deletions(-)

diff --git a/fairseq2n/CMakeLists.txt b/fairseq2n/CMakeLists.txt
index a51fd25a7..547436256 100644
--- a/fairseq2n/CMakeLists.txt
+++ b/fairseq2n/CMakeLists.txt
@@ -164,7 +164,7 @@ if(FAIRSEQ2N_THREAD_LIB STREQUAL "tbb")
     find_package(TBB 2021.8 REQUIRED)
 endif()
 
-find_package(Torch 1.12 REQUIRED)
+find_package(Torch 1.13 REQUIRED)
 
 if(FAIRSEQ2N_BUILD_PYTHON_BINDINGS)
     find_package(Python3 REQUIRED COMPONENTS Interpreter Development.Module)
diff --git a/src/fairseq2/optim/optimizer_base.py b/src/fairseq2/optim/optimizer_base.py
index 1185ab4cb..987fb78ca 100644
--- a/src/fairseq2/optim/optimizer_base.py
+++ b/src/fairseq2/optim/optimizer_base.py
@@ -10,14 +10,13 @@
 import torch
 from torch.optim import Optimizer
 
-from fairseq2.typing import finaloverride
-
 
 class OptimizerBase(ABC, Optimizer):
     """Represents the base class for all optimizers."""
 
-    @finaloverride
-    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
+    def step(  # type: ignore[override]
+        self, closure: Optional[Callable[[], float]] = None
+    ) -> Optional[float]:
         loss = None
 
         prev_grad = torch.is_grad_enabled()
diff --git a/tests/unit/nn/transformer/test_attention.py b/tests/unit/nn/transformer/test_attention.py
index cc32abbe5..8c0f089de 100644
--- a/tests/unit/nn/transformer/test_attention.py
+++ b/tests/unit/nn/transformer/test_attention.py
@@ -13,7 +13,7 @@
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.transformer import CustomAttentionMask, NaiveSDPA, TorchSDPA
 from fairseq2.utils.version import is_pt2_or_greater
-from tests.common import assert_close, device, tmp_rng_seed
+from tests.common import assert_close, device
 
 
 class TestScaledDotProductAttention:
@@ -21,26 +21,22 @@ class TestScaledDotProductAttention:
         not is_pt2_or_greater(), reason="requires PyTorch 2.0.0 or greater"
     )
     # fmt: off
-    @pytest.mark.parametrize("use_key_padding_mask,use_attn_mask,attn_dropout_p,training",
+    @pytest.mark.parametrize("use_key_padding_mask,use_attn_mask,training",
        [
-            (False, False, 0.0, True),
-            (True, True, 0.0, True),
-            (False, True, 0.5, True),
-            (True, False, 0.5, True),
-            (False, False, 0.5, False),
-            (False, True, 0.9, False),
+            (False, False, True),
+            (True, True, True),
+            (False, True, True),
+            (True, False, True),
+            (False, False, False),
+            (False, True, False),
        ],
    )
    # fmt: on
    def test_torch_sdpa(
-        self,
-        use_key_padding_mask: bool,
-        use_attn_mask: bool,
-        attn_dropout_p: float,
-        training: bool,
+        self, use_key_padding_mask: bool, use_attn_mask: bool, training: bool
    ) -> None:
-        torch_sdpa = TorchSDPA(attn_dropout_p=attn_dropout_p)
-        naive_sdpa = NaiveSDPA(attn_dropout_p=attn_dropout_p)
+        torch_sdpa = TorchSDPA()
+        naive_sdpa = NaiveSDPA()
 
        if training:
            torch_sdpa.eval()
@@ -48,11 +44,8 @@ def test_torch_sdpa(
 
        kwargs = self._get_sdpa_args(use_key_padding_mask, use_attn_mask)
 
-        with tmp_rng_seed(device):
-            attn1, _ = torch_sdpa(**kwargs)
-
-        with tmp_rng_seed(device):
-            attn2, _ = naive_sdpa(**kwargs)
+        attn1, _ = torch_sdpa(**kwargs)
+        attn2, _ = naive_sdpa(**kwargs)
 
        assert_close(attn1, attn2)
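
Note on the optimizer_base.py change: the patch drops fairseq2's @finaloverride decorator from step() and instead marks the override with "# type: ignore[override]", presumably because PyTorch 2.1 types Optimizer.step() with overloads that a single Optional[...] closure signature cannot match exactly. The following is a minimal, hypothetical sketch of that override pattern on a plain SGD subclass; ClosureSGD and the toy closure are illustrative only and are not part of fairseq2 or this patch.

from typing import Callable, Optional

import torch
from torch.optim import SGD


class ClosureSGD(SGD):
    def step(  # type: ignore[override]
        self, closure: Optional[Callable[[], float]] = None
    ) -> Optional[float]:
        loss = None

        # Evaluate the closure with gradients enabled, mirroring the
        # standard torch.optim closure contract.
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        super().step()

        return loss


# Toy usage: one SGD step driven through the typed closure.
param = torch.nn.Parameter(torch.ones(2))
opt = ClosureSGD([param], lr=0.1)


def closure() -> float:
    opt.zero_grad()
    loss = (param * param).sum()
    loss.backward()
    return float(loss)


print(opt.step(closure))  # prints 2.0; param is updated to 0.8, 0.8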
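
Note on the test_attention.py change: with the dropout parametrization and tmp_rng_seed gone, TorchSDPA and NaiveSDPA are compared only with dropout disabled, where their outputs should agree deterministically; with dropout enabled, the fused PyTorch kernel draws its mask from its own RNG stream, so the two paths generally cannot be made to match even under a fixed seed. The sketch below illustrates the same no-dropout equivalence with plain PyTorch; it does not use fairseq2's TorchSDPA/NaiveSDPA and is not part of the patch.

import math

import torch
import torch.nn.functional as F

# Random query/key/value in (batch, heads, seq_len, head_dim) layout.
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
v = torch.randn(2, 4, 8, 16)

# Fused path: dispatches to flash/memory-efficient/math kernels.
fused = F.scaled_dot_product_attention(q, k, v)

# Naive path: explicit softmax(Q K^T / sqrt(d)) V.
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
naive = torch.softmax(scores, dim=-1) @ v

# Without dropout the two implementations agree to numerical tolerance.
torch.testing.assert_close(fused, naive, atol=1e-5, rtol=1e-5)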