SDPA integration for nvFuser (#951)
Priya2698 authored Aug 20, 2024
1 parent bc8a5fe commit 3548ba8
Showing 5 changed files with 560 additions and 232 deletions.
250 changes: 248 additions & 2 deletions thunder/executors/nvfuserex_impl.py
@@ -16,6 +16,8 @@

import thunder.core.dtypes as dtypes
import thunder.torch as ltorch
from thunder.torch import TensorLike

from thunder.core import prims, utils
from thunder.core.baseutils import BoundSymbolInterface
from thunder.core.prims import PrimIDs
@@ -34,16 +36,29 @@
from thunder.core.utils import OrderedSet, check, check_same_dtype
from thunder.core.trace import TraceCtx, from_trace, TraceProvenance
from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, Symbol, has_tags
from thunder.core.devices import Device, DeviceType
from thunder.core.devices import Device, DeviceType, cpu
import thunder.core.codeutils as codeutils
from thunder.core.codeutils import Printable
from thunder.core.transform_common import dce, cse_single_bsym, replace_redundant_inputs, NON_FUNCTIONAL_OPS
from thunder.core.profile import add_markers
from thunder.core.compile_data import get_compile_option

from thunder.executors.utils import Region
from thunder.core.transforms import (
get_grad,
put_grads,
)

from thunder.executors.utils import (
Region,
_input_dtype_check_fused_scaled_dot_product_attention,
_input_shape_check_fused_scaled_dot_product_attention,
_fused_sdp_choice,
SpdaBackend,
)

from thunder.executors.passes import update_fusion_call_ctx
from thunder.extend import FUEL_LEVEL, FusionExecutor, register_executor, add_default_executor
from thunder.executors.nvfuserex import nvfuser_version

# NOTE This impl file is here because nvFuser may not be available, so it's imported conditionally
# by nvfuserex.py when nvFuser is available.
@@ -2208,3 +2223,234 @@ def matmul(


register_supported(PrimIDs.MATMUL, matmul, _matmul_check)


# Registering SDPA operators for nvFuser
# SDPA requires both an execution transform and a grad transform, since the forward and backward passes are dispatched through different implementations.
# For each transform, a new operator is registered with nvfuserex (ex.register_operator) and then added to the translation map (register_supported).
# The operators are tagged with OpTags.RANDOM_OP to prevent rematerialization in the backward pass.
# Finally, the complete rule is registered through ex.register_supported, with the execution and grad transforms wrapping these operators.
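
For context, a minimal opt-in sketch (not part of this commit; it assumes that keyword arguments passed to thunder.jit are surfaced as compile options, so the nv_enable_sdpa flag queried by the checker below can be supplied that way):

import torch
import thunder

def attn(q, k, v):
    # Calls into scaled_dot_product_attention, which this commit teaches the
    # nvFuser executor to claim when the checker below accepts the inputs.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)

# Hypothetical opt-in: nv_enable_sdpa is read via get_compile_option in the checker.
jfn = thunder.jit(attn, nv_enable_sdpa=True)

q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16) for _ in range(3))
out = jfn(q, k, v)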


# SDPA Forward
def _scaled_dot_product_flash_attention_forward_meta(
query: TensorLike,
key: TensorLike,
value: TensorLike,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: None | float = None,
) -> tuple[TensorProxy, TensorProxy, TensorProxy, TensorProxy]:
# Reference metadata:
# * query (batch_size, num_heads, query_seq_len, E)
# * key (batch_size, num_heads, key_seq_len, E)
# * value (batch_size, num_heads, key_seq_len, Ev)
# * output (batch_size, num_heads, query_seq_len, Ev)

# at::_scaled_dot_product_flash_attention returns {output, log_sumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask}.
# In nvFuser, we only save {output, log_sumexp, philox_seed/offset} for backward, since the other outputs are not required for non-nested input tensors.
# For non-nested tensors, cum_seq_q/k are undefined, max_q/k can be inferred from the input sizes, and we set `return_debug_mask=False`, so `debug_attn_mask` is a 1D zero tensor.

batch_size, num_heads, query_seq_len, E = query.shape
key_seq_len = key.shape[2]

return (
output := TensorProxy(like=query, shape=(batch_size, num_heads, query_seq_len, E)),
log_sumexp := TensorProxy(
shape=(batch_size, num_heads, query_seq_len), dtype=dtypes.float32, device=query.device, requires_grad=False
),
philox_seed := TensorProxy(shape=(), dtype=dtypes.int64, device=cpu, requires_grad=False),
philox_offset := TensorProxy(shape=(), dtype=dtypes.int64, device=cpu, requires_grad=False),
)
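
For concreteness, a hypothetical query of shape (2, 8, 128, 64) with a key/value sequence length of 256 would make the meta function above produce:

# output:        TensorProxy like query, shape (2, 8, 128, 64)
# log_sumexp:    TensorProxy of shape (2, 8, 128), float32, on query.device
# philox_seed:   0-dim TensorProxy, int64, on the CPU device
# philox_offset: 0-dim TensorProxy, int64, on the CPU device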


def _scaled_dot_product_flash_attention_forward(
query: TensorProxy,
key: TensorProxy,
value: TensorProxy,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: None | float = None,
fd: FusionDefinition,
lc_to_nv_map: dict,
) -> Any:

inputs = [query, key, value, dropout_p, is_causal, scale]
nv_inputs = []
for inp in inputs:
nv_inp = getnv(inp, fd, lc_to_nv_map) if inp is not None else None
nv_inputs.append(nv_inp)

return fd.ops.sdpfa_fwd(*nv_inputs)


nv_sdpfa_fwd = ex.register_operator(
"nv_sdpfa_fwd",
meta=_scaled_dot_product_flash_attention_forward_meta,
fn=_scaled_dot_product_flash_attention_forward,
tags=[prims.OpTags.RANDOM_OP],
)

register_supported(nv_sdpfa_fwd.id, _scaled_dot_product_flash_attention_forward, None)


# SDPA Backward
def _scaled_dot_product_flash_attention_backward_meta(
grad_out: TensorLike,
query: TensorLike,
key: TensorLike,
value: TensorLike,
out: TensorLike,
logsumexp: TensorLike,
dropout_p: float,
is_causal: bool,
philox_seed: TensorLike,
philox_offset: TensorLike,
*,
scale: None | float = None,
) -> tuple[TensorProxy, TensorProxy, TensorProxy]:

batch_size, num_heads, query_seq_len, E = query.shape
key_seq_len = key.shape[2]

# Reference metadata:
# https://github.com/pytorch/pytorch/blob/f57b00704e498a676854a02974ca9e0c42188b23/torch/_meta_registrations.py#L5043-L5063
grad_query = TensorProxy(like=query, shape=(batch_size, num_heads, query_seq_len, E))
grad_key = TensorProxy(like=key, shape=(batch_size, num_heads, key_seq_len, E))
grad_value = TensorProxy(like=value, shape=(batch_size, num_heads, key_seq_len, E))
return (grad_query, grad_key, grad_value)


def _scaled_dot_product_flash_attention_backward(
grad_out: TensorProxy,
query: TensorProxy,
key: TensorProxy,
value: TensorProxy,
out: TensorProxy,
logsumexp: TensorProxy,
dropout_p: float,
is_causal: bool,
philox_seed: TensorProxy,
philox_offset: TensorProxy,
*,
scale: None | float = None,
fd: FusionDefinition,
lc_to_nv_map: dict,
) -> Any:

inputs = [grad_out, query, key, value, out, logsumexp, dropout_p, is_causal, philox_seed, philox_offset, scale]
nv_inputs = []
for inp in inputs:
nv_inp = getnv(inp, fd, lc_to_nv_map) if inp is not None else None
nv_inputs.append(nv_inp)

return fd.ops.sdpfa_bwd(*nv_inputs)


nv_sdpfa_bwd = ex.register_operator(
"nv_sdpfa_bwd",
meta=_scaled_dot_product_flash_attention_backward_meta,
fn=_scaled_dot_product_flash_attention_backward,
tags=[prims.OpTags.RANDOM_OP],
)

register_supported(nv_sdpfa_bwd.id, _scaled_dot_product_flash_attention_backward, None)


# Checker for SDPA
def _scaled_dot_product_flash_attention_check(
query: Proxy,
key: Proxy,
value: Proxy,
attn_mask: Proxy | None,
dropout_p: float,
is_causal: bool,
*,
scale: None | float = None,
) -> bool:

# fd.ops.sdpfa_fwd and fd.ops.sdpfa_bwd were added in nvFuser versions 0.2.9 and 0.2.10, respectively.
if nvfuser_version() < LooseVersion("0.2.10"):
return False

enable_sdpa: None | bool = get_compile_option("nv_enable_sdpa", "Enable nvFuser flash attention SDPA.")

if not enable_sdpa:
return False

# Flash attn does not support attn_mask currently.
if attn_mask is not None:
return False

if not are_supported_tensors(query, key, value):
return False

# FP64 is not supported by flash attention
supported_dtypes = (dtypes.float16, dtypes.bfloat16)
_input_dtype_check_fused_scaled_dot_product_attention(query, key, value, None, supported_dtypes)
_input_shape_check_fused_scaled_dot_product_attention(query, key, value, None)

# nvFuser only implements flash attention currently.
backend = _fused_sdp_choice(query, key, value, None, dropout_p, is_causal, scale)
return backend == SpdaBackend.FLASH_ATTENTION


# SDPA execution_transform -- calls nv_sdpfa_fwd operator registered above
def scaled_dot_product_flash_attention(
query: TensorProxy,
key: TensorProxy,
value: TensorProxy,
attn_mask: None | TensorProxy = None,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: None | float = None,
):
(attn_output, logsumexp, philox_seed, philox_offset) = nv_sdpfa_fwd(
query, key, value, dropout_p, is_causal, scale=scale
)
return attn_output


# SDPA grad_transform -- calls nv_sdpfa_fwd and nv_sdpfa_bwd registered above
def scaled_dot_product_flash_attention_grad(
query: Proxy,
key: Proxy,
value: Proxy,
attn_mask: None | Proxy,
dropout_p: float = 0.0,
is_causal: bool = False,
*,
scale: None | float = None,
):

(attn_output, logsumexp, philox_seed, philox_offset) = nv_sdpfa_fwd(
query, key, value, dropout_p, is_causal, scale=scale
)
grad_out = get_grad(attn_output)
grad_query, grad_key, grad_val = nv_sdpfa_bwd(
grad_out,
query,
key,
value,
attn_output,
logsumexp,
dropout_p,
is_causal,
philox_seed,
philox_offset,
scale=scale,
)
put_grads((query, key, value), (grad_query, grad_key, grad_val))
return attn_output


# Register the complete rule for SDPA in nvfuser executor
ex.register_supported(
ltorch.scaled_dot_product_attention,
checker=_scaled_dot_product_flash_attention_check,
execution_transform=scaled_dot_product_flash_attention,
grad_transform=scaled_dot_product_flash_attention_grad,
)
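
Continuing the hypothetical usage sketch from above (same assumptions about the nv_enable_sdpa flag), a backward pass is expected to route through nv_sdpfa_bwd via the grad transform:

q, k, v = (
    torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True) for _ in range(3)
)
out = thunder.jit(attn, nv_enable_sdpa=True)(q, k, v)
out.sum().backward()  # grads for q, k, v come from scaled_dot_product_flash_attention_grad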