Skip to content

Commit

Permalink
Fix test_eager_matches_sdpa_inference for XPU backend (huggingfac…
Browse files Browse the repository at this point in the history
…e#34889)

* Use torch.nn.attention.sdpa_kernel instead of deprecated torch.backends.cuda.sdp_kernel

Signed-off-by: Dmitry Rogozhkin <[email protected]>

* Fix test_eager_matches_sdpa_inference for XPU backend

As of PyTorch 2.5, the XPU backend supports only torch.nn.attention.SDPBackend.MATH,
which is implemented at the PyTorch level using aten operators and is device-agnostic
with respect to the implementation of each aten operator. Thus, we can reuse the
CUDA (or CPU) MATH-path tolerances (atol/rtol) for XPU.

Fixes: huggingface#34888
Signed-off-by: Dmitry Rogozhkin <[email protected]>

* Use torch.amp.autocast instead of deprecated torch.cuda.amp.autocast in nemotron

Signed-off-by: Dmitry Rogozhkin <[email protected]>

---------

Signed-off-by: Dmitry Rogozhkin <[email protected]>
  • Loading branch information
dvrogozh authored Dec 2, 2024
1 parent f41d5d8 commit 3183047
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/nemotron/modeling_nemotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(

def forward(self, input: Tensor) -> Tensor:
args = _cast_if_autocast_enabled(input, self.normalized_shape, self.weight + 1, self.bias, self.eps)
with torch.cuda.amp.autocast(enabled=False):
with torch.amp.autocast(input.device.type, enabled=False):
return F.layer_norm(*args)


Expand Down
10 changes: 8 additions & 2 deletions tests/models/mimi/test_modeling_mimi.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
)

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, sdpa_kernel


if is_torch_available():
Expand Down Expand Up @@ -636,7 +636,7 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):

# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with torch.backends.cuda.sdp_kernel(
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
Expand All @@ -653,6 +653,12 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
Expand Down
20 changes: 16 additions & 4 deletions tests/models/musicgen/test_modeling_musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, sdpa_kernel
from ...test_pipeline_mixin import PipelineTesterMixin


Expand Down Expand Up @@ -607,7 +607,7 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):

# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with torch.backends.cuda.sdp_kernel(
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
Expand All @@ -629,6 +629,12 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
Expand Down Expand Up @@ -1343,7 +1349,7 @@ def test_sdpa_can_dispatch_on_flash(self):
if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
inputs_dict[name] = inp.to(torch.float16)

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
with sdpa_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
_ = model(**inputs_dict)

@require_flash_attn
Expand Down Expand Up @@ -1669,7 +1675,7 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
# TODO: test gradients as well (& for FA2 as well!)
# Ignore copy
with torch.no_grad():
with torch.backends.cuda.sdp_kernel(
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
Expand All @@ -1691,6 +1697,12 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
Expand Down
20 changes: 16 additions & 4 deletions tests/models/musicgen_melody/test_modeling_musicgen_melody.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, sdpa_kernel
from ...test_pipeline_mixin import PipelineTesterMixin


Expand Down Expand Up @@ -615,7 +615,7 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):

# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with torch.backends.cuda.sdp_kernel(
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
Expand All @@ -637,6 +637,12 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
Expand Down Expand Up @@ -1333,7 +1339,7 @@ def test_sdpa_can_dispatch_on_flash(self):
if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
inputs_dict[name] = inp.to(torch.float16)

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
with sdpa_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
_ = model(**inputs_dict)

@require_flash_attn
Expand Down Expand Up @@ -1632,7 +1638,7 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
# TODO: test gradients as well (& for FA2 as well!)
# Ignore copy
with torch.no_grad():
with torch.backends.cuda.sdp_kernel(
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
Expand All @@ -1654,6 +1660,12 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
Expand Down
24 changes: 23 additions & 1 deletion tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,22 @@ def _deepspeed_zero3(ds_config):
unset_hf_deepspeed_config()


def sdpa_kernel(enable_flash, enable_math, enable_mem_efficient):
    """Return a context manager selecting the allowed SDPA backends.

    Args:
        enable_flash: Truthy to allow the flash-attention backend.
        enable_math: Truthy to allow the math backend.
        enable_mem_efficient: Truthy to allow the memory-efficient backend.

    Returns:
        The object returned by ``torch.backends.cuda.sdp_kernel`` when the
        installed torch release is older than 2.3, otherwise the object
        returned by ``torch.nn.attention.sdpa_kernel`` for the list of
        enabled ``SDPBackend`` members.
    """
    installed_release = version.parse(torch.__version__).release
    if installed_release < version.parse("2.3").release:
        # Older torch: forward the three flags unchanged to the legacy
        # flag-based context manager.
        return torch.backends.cuda.sdp_kernel(
            enable_flash=enable_flash,
            enable_math=enable_math,
            enable_mem_efficient=enable_mem_efficient,
        )

    # Newer torch: translate the flags into an explicit backend list, keeping
    # the flash -> math -> mem-efficient ordering.
    backend_enum = torch.nn.attention.SDPBackend
    flag_to_backend = (
        (enable_flash, backend_enum.FLASH_ATTENTION),
        (enable_math, backend_enum.MATH),
        (enable_mem_efficient, backend_enum.EFFICIENT_ATTENTION),
    )
    selected = [backend for enabled, backend in flag_to_backend if enabled]
    return torch.nn.attention.sdpa_kernel(selected)


@require_torch
class ModelTesterMixin:
model_tester = None
Expand Down Expand Up @@ -4175,7 +4191,7 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):

# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with torch.backends.cuda.sdp_kernel(
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
Expand All @@ -4198,6 +4214,12 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol):
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4
Expand Down

0 comments on commit 3183047

Please sign in to comment.