diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h
index c25fdc37087..7d5f41a84ac 100644
--- a/csrc/ir/internal_nodes.h
+++ b/csrc/ir/internal_nodes.h
@@ -2138,16 +2138,16 @@ class MatmulOp : public Expr {
   std::string toString(int indent_size = 0) const override;
   std::string toInlineString(int indent_size = 0) const override;
 
-  Val* out() const {
-    return output(0);
+  TensorView* out() const {
+    return output(0)->as<TensorView>();
   }
 
-  Val* inA() const {
-    return input(0);
+  TensorView* inA() const {
+    return input(0)->as<TensorView>();
   }
 
-  Val* inB() const {
-    return input(1);
+  TensorView* inB() const {
+    return input(1)->as<TensorView>();
   }
 
   std::vector<PolymorphicValue> evaluate(
@@ -2256,20 +2256,20 @@ class SdpaFwdOp : public Expr {
   std::string toString(int indent_size = 0) const override;
   std::string toInlineString(int indent_size = 0) const override;
 
-  Val* attn_out() const {
-    return output(0);
+  TensorView* attn_out() const {
+    return output(0)->as<TensorView>();
   }
 
-  Val* query() const {
-    return input(0);
+  TensorView* query() const {
+    return input(0)->as<TensorView>();
   }
 
-  Val* key() const {
-    return input(1);
+  TensorView* key() const {
+    return input(1)->as<TensorView>();
   }
 
-  Val* value() const {
-    return input(2);
+  TensorView* value() const {
+    return input(2)->as<TensorView>();
   }
 
   Val* dropout_p() const {
diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp
index 0d4c4f3bd42..24b4d7309ed 100644
--- a/csrc/ir/nodes.cpp
+++ b/csrc/ir/nodes.cpp
@@ -4430,8 +4430,26 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
   // https://github.com/NVIDIA/Fuser/issues/2563
   bool handle_device_dim = false;
   if (query.dim() == 5) {
-    NVF_CHECK(key.dim() == 5 && value.dim() == 5);
     handle_device_dim = true;
+
+    NVF_CHECK(key.dim() == 5 && value.dim() == 5);
+
+    auto query_domain =
+        TensorDomain::noReductions(this->query()->getLogicalDomain());
+    auto key_domain =
+        TensorDomain::noReductions(this->key()->getLogicalDomain());
+    auto value_domain =
+        TensorDomain::noReductions(this->value()->getLogicalDomain());
+    NVF_CHECK(
+        query_domain.front()->isDeviceDim(),
+        "Only support DID parallelization on outermost axis");
+    NVF_CHECK(
+        key_domain.front()->isDeviceDim(),
+        "Only support DID parallelization on outermost axis");
+    NVF_CHECK(
+        value_domain.front()->isDeviceDim(),
+        "Only support DID parallelization on outermost axis");
+
     query = query.squeeze(0);
     key = key.squeeze(0);
     value = value.squeeze(0);
diff --git a/csrc/ops/composite.cpp b/csrc/ops/composite.cpp
index 400fa264b70..705766d3145 100644
--- a/csrc/ops/composite.cpp
+++ b/csrc/ops/composite.cpp
@@ -444,35 +444,21 @@ SdpfaFwdResult sdpfa_fwd(
   auto query_domain = TensorDomain::noReductions(query->getLogicalDomain());
   auto key_domain = TensorDomain::noReductions(key->getLogicalDomain());
   auto value_domain = TensorDomain::noReductions(value->getLogicalDomain());
-
-  // Temporary handling of DID parallelization see
-  // https://github.com/NVIDIA/Fuser/issues/2563
-  bool has_device_dim = (query_domain.size() == 5);
-  if (has_device_dim) {
-    NVF_CHECK(
-        query_domain[0]->isDeviceDim(),
-        "Only support DID parallelization on outermost axis");
-    NVF_CHECK(
-        key_domain[0]->isDeviceDim(),
-        "Only support DID parallelization on outermost axis");
-    NVF_CHECK(
-        value_domain[0]->isDeviceDim(),
-        "Only support DID parallelization on outermost axis");
-  }
-
-  auto concrete_query_size = TensorDomain::noDevices(query_domain).size();
-  auto concrete_key_size = TensorDomain::noDevices(key_domain).size();
-  auto concrete_value_size = TensorDomain::noDevices(value_domain).size();
-
   NVF_CHECK(
-      concrete_query_size == 4 && concrete_key_size == 4 &&
-          concrete_value_size == 4,
-      "Expected query, key, and value to be 4D but got: ",
-      concrete_query_size,
-      " ",
-      concrete_key_size,
-      " ,and ",
-      concrete_value_size);
+      query_domain.size() == key_domain.size(),
+      query_domain,
+      " vs ",
+      key_domain);
+  NVF_CHECK(
+      query_domain.size() == value_domain.size(),
+      query_domain,
+      " vs ",
+      value_domain);
+  NVF_CHECK(
+      query_domain.size() == 4 || query_domain.size() == 5,
+      "Expect Q/K/V to be either 4D or 5D. If 5D, the first dimension is "
+      "expected to be device parallel during expression evaluation: ",
+      query_domain);
 
   NVF_CHECK(
       !dropout_p || dropout_p->isScalar(),
diff --git a/tests/python/test_multidevice.py b/tests/python/test_multidevice.py
index d1e9b3f5485..a8905da236e 100644
--- a/tests/python/test_multidevice.py
+++ b/tests/python/test_multidevice.py
@@ -4,9 +4,11 @@
 
 import pytest
 import torch
+from torch.nn.attention import SDPBackend
 
 import mpi_fixtures
 import nvfuser
+import utils
 from nvfuser import DataType, FusionDefinition
 
 
@@ -55,8 +57,8 @@ def multidevice_schedule(self):
         self.sched._set_device_mesh(self.t2, mesh)
         self.sched.parallelize(self.t0, 0, nvfuser.ParallelType.mesh_x)
 
-    fn = Model()
-    outputs = fn.execute([sharded_input])
+    fd = Model()
+    outputs = fd.execute([sharded_input])
     torch.testing.assert_close(outputs[0], unsharded_input.relu() * 2)
 
 
@@ -98,8 +100,8 @@ def multidevice_schedule(self):
     unsharded_bias_tensor = torch.randn(d * h, device="cuda")
     bias_tensor = unsharded_bias_tensor.view([d, h])[rank : rank + 1]
 
-    fn = Model(d, b, s, h)
-    out_tensors = fn.execute([inp_tensor, weight_tensor, bias_tensor])
+    fd = Model(d, b, s, h)
+    out_tensors = fd.execute([inp_tensor, weight_tensor, bias_tensor])
 
     # [b, s, d*h]
     unsharded_out_tensor = torch.nn.functional.linear(
@@ -112,3 +114,81 @@ def multidevice_schedule(self):
     torch.testing.assert_close(
         out_tensors[0], expected_out_tensor, rtol=1.3e-6, atol=1e-4
     )
+
+
+@pytest.mark.skipif(
+    utils.is_pre_ampere(),
+    reason="Flash Attention is only supported on Ampere and newer devices.",
+)
+@pytest.mark.mpi
+def test_sdpa(mpi_test):
+    d, b, s, a, h = mpi_test.size, 2, 1024, 12, 768
+
+    if a % d != 0:
+        pytest.skip(f"We only support even split, so {a} has to be divisible by {d}.")
+
+    class Model(FusionDefinition):
+        def definition(self) -> None:
+            self.q = self.define_tensor(
+                shape=[d, -1, -1, -1, -1],
+                contiguity=[True if d > 1 else None, True, True, True, True],
+                dtype=DataType.BFloat16,
+            )
+            self.k = self.define_tensor(
+                shape=[d, -1, -1, -1, -1],
+                contiguity=[True if d > 1 else None, True, True, True, True],
+                dtype=DataType.BFloat16,
+            )
+            self.v = self.define_tensor(
+                shape=[d, -1, -1, -1, -1],
+                contiguity=[True if d > 1 else None, True, True, True, True],
+                dtype=DataType.BFloat16,
+            )
+            # TODO(#3123): support sharded dropout and change this to a
+            # positive probability.
+            dropout_p = self.define_scalar(0.0, dtype=DataType.Double)
+            is_causal = self.define_scalar(True, dtype=DataType.Bool)
+            sdpa_result = self.ops.sdpfa_fwd(
+                self.q, self.k, self.v, dropout_p, is_causal, scale=None
+            )
+            attn = sdpa_result[0]
+            self.add_output(attn)
+
+        def multidevice_schedule(self) -> None:
+            mesh = self.sched._create_device_mesh(range(d))
+            for t in [self.q, self.k, self.v]:
+                self.sched._set_device_mesh(t, mesh)
+                self.sched.parallelize(t, 0, nvfuser.ParallelType.mesh_x)
+
+    torch.cuda.set_device(mpi_test.local_rank)
+    torch.manual_seed(0)
+    q, k, v = [
+        torch.randn(b, a, s, h // a, dtype=torch.bfloat16, device="cuda")
+        for _ in range(3)
+    ]
+
+    with torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
+        expected_attn = torch.nn.functional.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            dropout_p=0.0,
+            is_causal=True,
+            scale=None,
+        )
+
+    rank = mpi_test.rank
+
+    # Head-parallelize Q, K, V or the attention output of an SDPA.
+    def head_parallelize(t: torch.Tensor) -> torch.Tensor:
+        assert t.shape == torch.Size([b, a, s, h // a])
+        return t.view([b, d, a // d, s, h // a]).transpose(0, 1)[rank : rank + 1]
+
+    fd = Model()
+    attn = fd.execute([head_parallelize(q), head_parallelize(k), head_parallelize(v)])[
+        0
+    ]
+    # Use the default rtol for bfloat16 and a relaxed atol.
+    torch.testing.assert_close(
+        attn, head_parallelize(expected_attn), rtol=1.6e-2, atol=1e-3
+    )
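
For reference, the final assertion in test_sdpa compares the sharded nvFuser output against head_parallelize(expected_attn). This is valid because scaled dot-product attention treats the head dimension as a batch dimension, so slicing heads before the computation produces the same values as slicing the full output afterwards. Below is a minimal standalone sketch of that property in plain PyTorch, not part of the diff; the sizes are made up for illustration and no nvFuser or MPI is involved.

import torch
import torch.nn.functional as F

# Illustrative sizes: batch, heads, sequence length, head size.
b, a, s, e = 2, 4, 16, 8
q, k, v = [torch.randn(b, a, s, e) for _ in range(3)]

# Attention over all heads at once.
full = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Attention over the first half of the heads only, as one shard would see it.
shard = F.scaled_dot_product_attention(
    q[:, : a // 2], k[:, : a // 2], v[:, : a // 2], is_causal=True
)

# Per-head independence: the sharded result matches the corresponding slice.
torch.testing.assert_close(shard, full[:, : a // 2])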