Fix sdpfa_fwd to not assume the presence of DIDs. (#3116)
Similar to #3073, `sdpfa_fwd` shouldn't assume DIDs are available at
definition time. Instead, it now treats any extra preceding dimension as a
batch dimension at definition time and checks that it is device-parallel at
evaluation time.

This is required to land #3115.
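
To make the new contract concrete, here is a condensed sketch (adapted from the test added below; `d = 2` devices is an assumption for illustration) of defining a 5D Q/K/V with no DID information at all; only `multidevice_schedule`, which runs separately, marks axis 0 as device-parallel, and `SdpaFwdOp::evaluate` re-checks that at execution time:

    import nvfuser
    from nvfuser import DataType, FusionDefinition

    d = 2  # assumed number of devices, for illustration only

    class Sdpa(FusionDefinition):
        def definition(self) -> None:
            # Definition time: axis 0 of each 5D input is an ordinary batch
            # dimension; nothing here assumes or requires a DID.
            self.q = self.define_tensor(shape=[d, -1, -1, -1, -1], dtype=DataType.BFloat16)
            self.k = self.define_tensor(shape=[d, -1, -1, -1, -1], dtype=DataType.BFloat16)
            self.v = self.define_tensor(shape=[d, -1, -1, -1, -1], dtype=DataType.BFloat16)
            dropout_p = self.define_scalar(0.0, dtype=DataType.Double)
            is_causal = self.define_scalar(True, dtype=DataType.Bool)
            attn = self.ops.sdpfa_fwd(self.q, self.k, self.v, dropout_p, is_causal, scale=None)[0]
            self.add_output(attn)

        def multidevice_schedule(self) -> None:
            # Scheduling time: only now does axis 0 become device-parallel;
            # SdpaFwdOp::evaluate verifies this before squeezing it away.
            mesh = self.sched._create_device_mesh(range(d))
            for t in [self.q, self.k, self.v]:
                self.sched._set_device_mesh(t, mesh)
                self.sched.parallelize(t, 0, nvfuser.ParallelType.mesh_x)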
wujingyue authored Oct 7, 2024
1 parent 3476d34 commit 222315e
Showing 4 changed files with 131 additions and 47 deletions.
csrc/ir/internal_nodes.h (28 changes: 14 additions & 14 deletions)
@@ -2138,16 +2138,16 @@ class MatmulOp : public Expr {
   std::string toString(int indent_size = 0) const override;
   std::string toInlineString(int indent_size = 0) const override;

-  Val* out() const {
-    return output(0);
+  TensorView* out() const {
+    return output(0)->as<TensorView>();
   }

-  Val* inA() const {
-    return input(0);
+  TensorView* inA() const {
+    return input(0)->as<TensorView>();
   }

-  Val* inB() const {
-    return input(1);
+  TensorView* inB() const {
+    return input(1)->as<TensorView>();
   }

   std::vector<PolymorphicValue> evaluate(
@@ -2256,20 +2256,20 @@ class SdpaFwdOp : public Expr {
   std::string toString(int indent_size = 0) const override;
   std::string toInlineString(int indent_size = 0) const override;

-  Val* attn_out() const {
-    return output(0);
+  TensorView* attn_out() const {
+    return output(0)->as<TensorView>();
   }

-  Val* query() const {
-    return input(0);
+  TensorView* query() const {
+    return input(0)->as<TensorView>();
   }

-  Val* key() const {
-    return input(1);
+  TensorView* key() const {
+    return input(1)->as<TensorView>();
   }

-  Val* value() const {
-    return input(2);
+  TensorView* value() const {
+    return input(2)->as<TensorView>();
   }

   Val* dropout_p() const {
csrc/ir/nodes.cpp (20 changes: 19 additions & 1 deletion)
@@ -4430,8 +4430,26 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
   // https://github.com/NVIDIA/Fuser/issues/2563
   bool handle_device_dim = false;
   if (query.dim() == 5) {
-    NVF_CHECK(key.dim() == 5 && value.dim() == 5);
     handle_device_dim = true;
+
+    NVF_CHECK(key.dim() == 5 && value.dim() == 5);
+
+    auto query_domain =
+        TensorDomain::noReductions(this->query()->getLogicalDomain());
+    auto key_domain =
+        TensorDomain::noReductions(this->key()->getLogicalDomain());
+    auto value_domain =
+        TensorDomain::noReductions(this->value()->getLogicalDomain());
+    NVF_CHECK(
+        query_domain.front()->isDeviceDim(),
+        "Only support DID parallelization on outermost axis");
+    NVF_CHECK(
+        key_domain.front()->isDeviceDim(),
+        "Only support DID parallelization on outermost axis");
+    NVF_CHECK(
+        value_domain.front()->isDeviceDim(),
+        "Only support DID parallelization on outermost axis");
+
     query = query.squeeze(0);
     key = key.squeeze(0);
     value = value.squeeze(0);
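
The squeeze-and-compute path above can be illustrated in plain PyTorch, independent of nvfuser; the sizes below are arbitrary, and the final unsqueeze is an assumption about how `handle_device_dim` is consumed later in this function:

    import torch

    # Standalone PyTorch sketch of the evaluation path for a 5D local shard
    # [d=1, b, a/d, s, e]: drop the checked device axis, run the regular 4D
    # SDPA kernel, then (assumed) re-attach the axis on the output.
    q, k, v = [torch.randn(1, 2, 3, 8, 16) for _ in range(3)]
    out = torch.nn.functional.scaled_dot_product_attention(
        q.squeeze(0), k.squeeze(0), v.squeeze(0)
    )
    out = out.unsqueeze(0)  # mirrors what handle_device_dim presumably guards
    assert out.shape == torch.Size([1, 2, 3, 8, 16])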
csrc/ops/composite.cpp (42 changes: 14 additions & 28 deletions)
@@ -444,35 +444,21 @@ SdpfaFwdResult sdpfa_fwd(
   auto query_domain = TensorDomain::noReductions(query->getLogicalDomain());
   auto key_domain = TensorDomain::noReductions(key->getLogicalDomain());
   auto value_domain = TensorDomain::noReductions(value->getLogicalDomain());

-  // Temporary handling of DID parallelization see
-  // https://github.com/NVIDIA/Fuser/issues/2563
-  bool has_device_dim = (query_domain.size() == 5);
-  if (has_device_dim) {
-    NVF_CHECK(
-        query_domain[0]->isDeviceDim(),
-        "Only support DID parallelization on outermost axis");
-    NVF_CHECK(
-        key_domain[0]->isDeviceDim(),
-        "Only support DID parallelization on outermost axis");
-    NVF_CHECK(
-        value_domain[0]->isDeviceDim(),
-        "Only support DID parallelization on outermost axis");
-  }
-
-  auto concrete_query_size = TensorDomain::noDevices(query_domain).size();
-  auto concrete_key_size = TensorDomain::noDevices(key_domain).size();
-  auto concrete_value_size = TensorDomain::noDevices(value_domain).size();
-
   NVF_CHECK(
-      concrete_query_size == 4 && concrete_key_size == 4 &&
-          concrete_value_size == 4,
-      "Expected query, key, and value to be 4D but got: ",
-      concrete_query_size,
-      " ",
-      concrete_key_size,
-      " ,and ",
-      concrete_value_size);
+      query_domain.size() == key_domain.size(),
+      query_domain,
+      " vs ",
+      key_domain);
+  NVF_CHECK(
+      query_domain.size() == value_domain.size(),
+      query_domain,
+      " vs ",
+      value_domain);
+  NVF_CHECK(
+      query_domain.size() == 4 || query_domain.size() == 5,
+      "Expect Q/K/V to be either 4D or 5D. If 5D, the first dimension is "
+      "expected to be device parallel during expression evaluation: ",
+      query_domain);

   NVF_CHECK(
       !dropout_p || dropout_p->isScalar(),
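
With the definition-time DID checks removed, a plain 4D single-device definition and a 5D sharded definition now pass through the same rank checks; only mismatched ranks, or a rank other than 4 or 5, are rejected here. A minimal sketch through the Python frontend (API names as used in the test below; treat it as illustrative, not canonical):

    from nvfuser import DataType, FusionDefinition

    class SingleDeviceSdpa(FusionDefinition):
        def definition(self) -> None:
            # 4D Q/K/V: no DID is involved anywhere, so sdpfa_fwd must not
            # assume one exists at definition time.
            self.q = self.define_tensor(shape=[-1, -1, -1, -1], dtype=DataType.BFloat16)
            self.k = self.define_tensor(shape=[-1, -1, -1, -1], dtype=DataType.BFloat16)
            self.v = self.define_tensor(shape=[-1, -1, -1, -1], dtype=DataType.BFloat16)
            dropout_p = self.define_scalar(0.0, dtype=DataType.Double)
            is_causal = self.define_scalar(True, dtype=DataType.Bool)
            attn = self.ops.sdpfa_fwd(self.q, self.k, self.v, dropout_p, is_causal, scale=None)[0]
            self.add_output(attn)

    # A 3D input, or a 5D query paired with a 4D key, would now fail the
    # NVF_CHECKs above at definition time.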
tests/python/test_multidevice.py (88 changes: 84 additions & 4 deletions)
@@ -4,9 +4,11 @@

 import pytest
 import torch
+from torch.nn.attention import SDPBackend

 import mpi_fixtures
 import nvfuser
+import utils
 from nvfuser import DataType, FusionDefinition


@@ -55,8 +57,8 @@ def multidevice_schedule(self):
         self.sched._set_device_mesh(self.t2, mesh)
         self.sched.parallelize(self.t0, 0, nvfuser.ParallelType.mesh_x)

-    fn = Model()
-    outputs = fn.execute([sharded_input])
+    fd = Model()
+    outputs = fd.execute([sharded_input])
     torch.testing.assert_close(outputs[0], unsharded_input.relu() * 2)


@@ -98,8 +100,8 @@ def multidevice_schedule(self):
     unsharded_bias_tensor = torch.randn(d * h, device="cuda")
     bias_tensor = unsharded_bias_tensor.view([d, h])[rank : rank + 1]

-    fn = Model(d, b, s, h)
-    out_tensors = fn.execute([inp_tensor, weight_tensor, bias_tensor])
+    fd = Model(d, b, s, h)
+    out_tensors = fd.execute([inp_tensor, weight_tensor, bias_tensor])

     # [b, s, d*h]
     unsharded_out_tensor = torch.nn.functional.linear(
@@ -112,3 +114,81 @@ def multidevice_schedule(self):
     torch.testing.assert_close(
         out_tensors[0], expected_out_tensor, rtol=1.3e-6, atol=1e-4
     )
+
+
+@pytest.mark.skipif(
+    utils.is_pre_ampere(),
+    reason="Flash Attention is only supported on Ampere and newer devices.",
+)
+@pytest.mark.mpi
+def test_sdpa(mpi_test):
+    d, b, s, a, h = mpi_test.size, 2, 1024, 12, 768
+
+    if a % d != 0:
+        pytest.skip(f"We only support even split, so {a} has to be divisible by {d}.")
+
+    class Model(FusionDefinition):
+        def definition(self) -> None:
+            self.q = self.define_tensor(
+                shape=[d, -1, -1, -1, -1],
+                contiguity=[True if d > 1 else None, True, True, True, True],
+                dtype=DataType.BFloat16,
+            )
+            self.k = self.define_tensor(
+                shape=[d, -1, -1, -1, -1],
+                contiguity=[True if d > 1 else None, True, True, True, True],
+                dtype=DataType.BFloat16,
+            )
+            self.v = self.define_tensor(
+                shape=[d, -1, -1, -1, -1],
+                contiguity=[True if d > 1 else None, True, True, True, True],
+                dtype=DataType.BFloat16,
+            )
+            # TODO(#3123): support sharded dropout and change this to a
+            # positive probability.
+            dropout_p = self.define_scalar(0.0, dtype=DataType.Double)
+            is_causal = self.define_scalar(True, dtype=DataType.Bool)
+            sdpa_result = self.ops.sdpfa_fwd(
+                self.q, self.k, self.v, dropout_p, is_causal, scale=None
+            )
+            attn = sdpa_result[0]
+            self.add_output(attn)
+
+        def multidevice_schedule(self) -> None:
+            mesh = self.sched._create_device_mesh(range(d))
+            for t in [self.q, self.k, self.v]:
+                self.sched._set_device_mesh(t, mesh)
+                self.sched.parallelize(t, 0, nvfuser.ParallelType.mesh_x)
+
+    torch.cuda.set_device(mpi_test.local_rank)
+    torch.manual_seed(0)
+    q, k, v = [
+        torch.randn(b, a, s, h // a, dtype=torch.bfloat16, device="cuda")
+        for _ in range(3)
+    ]
+
+    with torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+        expected_attn = torch.nn.functional.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            dropout_p=0.0,
+            is_causal=True,
+            scale=None,
+        )
+
+    rank = mpi_test.rank
+
+    # Head-parallelize Q, K, V or the attention output of an SDPA.
+    def head_parallelize(t: torch.Tensor) -> torch.Tensor:
+        assert t.shape == torch.Size([b, a, s, h // a])
+        return t.view([b, d, a // d, s, h // a]).transpose(0, 1)[rank : rank + 1]
+
+    fd = Model()
+    attn = fd.execute([head_parallelize(q), head_parallelize(k), head_parallelize(v)])[
+        0
+    ]
+    # Use the default rtol for bfloat16 and a relaxed atol.
+    torch.testing.assert_close(
+        attn, head_parallelize(expected_attn), rtol=1.6e-2, atol=1e-3
+    )
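
As a sanity check on `head_parallelize`, the shape bookkeeping for a hypothetical run with d = 2 ranks works out as follows (standalone PyTorch, rank 0):

    import torch

    # [b, a, s, h//a] -> view -> [b, d, a//d, s, h//a] -> transpose(0, 1)
    # -> [d, b, a//d, s, h//a] -> [rank:rank+1] -> this rank's shard,
    # which keeps the device axis as a leading size-1 dimension.
    b, a, s, h, d, rank = 2, 12, 1024, 768, 2, 0
    t = torch.randn(b, a, s, h // a)
    shard = t.view([b, d, a // d, s, h // a]).transpose(0, 1)[rank : rank + 1]
    assert shard.shape == torch.Size([1, b, a // d, s, h // a])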
