diff --git a/tests/python/test_multidevice.py b/tests/python/test_multidevice.py
index 4181067bcf8..2082cd4bb1e 100644
--- a/tests/python/test_multidevice.py
+++ b/tests/python/test_multidevice.py
@@ -234,17 +234,17 @@ def definition(self) -> None:
             dtype=DataType.BFloat16,
         )
         self.mha_linear0_weight = self.define_tensor(
-            shape=[e * 3, e],
+            shape=[d, e * 3 // d, e],
             contiguity=True,
             dtype=DataType.BFloat16,
         )
         self.mha_linear0_bias = self.define_tensor(
-            shape=[e * 3],
+            shape=[d, e * 3 // d],
             contiguity=True,
             dtype=DataType.BFloat16,
         )
         self.mha_linear1_weight = self.define_tensor(
-            shape=[e, e],
+            shape=[d, e, e // d],
             contiguity=True,
             dtype=DataType.BFloat16,
         )
@@ -332,35 +332,49 @@ def definition(self) -> None:
         T78 = self.ops.linear(T77, self.mha_linear0_weight, self.mha_linear0_bias)
         T91 = self.ops.slice(
             T78,
-            start_indices=[0, 0, 0],
-            end_indices=[b, s, e],
-            strides=[1, 1, 1],
+            start_indices=[0, 0, 0, 0],
+            end_indices=[d, b, s, e // d],
         )
         T104 = self.ops.slice(
             T78,
-            start_indices=[0, 0, e],
-            end_indices=[b, s, e * 2],
-            strides=[1, 1, 1],
+            start_indices=[0, 0, 0, e // d],
+            end_indices=[d, b, s, e * 2 // d],
         )
         T117 = self.ops.slice(
             T78,
-            start_indices=[0, 0, e * 2],
-            end_indices=[b, s, e * 3],
-            strides=[1, 1, 1],
+            start_indices=[0, 0, 0, e * 2 // d],
+            end_indices=[d, b, s, e * 3 // d],
         )
-        T123 = self.ops.reshape(T104, new_shape=[b, s, h, e // h])
-        T124 = self.ops.permute(T123, dims=[0, 2, 1, 3])
-        T130 = self.ops.reshape(T91, new_shape=[b, s, h, e // h])
-        T131 = self.ops.permute(T130, dims=[0, 2, 1, 3])
-        T137 = self.ops.reshape(T117, new_shape=[b, s, h, e // h])
-        T138 = self.ops.permute(T137, dims=[0, 2, 1, 3])
+        T123 = self.ops.reshape(T104, new_shape=[d, b, s, h // d, e // h])
+        T124 = self.ops.permute(T123, dims=[0, 1, 3, 2, 4])
+        T130 = self.ops.reshape(T91, new_shape=[d, b, s, h // d, e // h])
+        T131 = self.ops.permute(T130, dims=[0, 1, 3, 2, 4])
+        T137 = self.ops.reshape(T117, new_shape=[d, b, s, h // d, e // h])
+        T138 = self.ops.permute(T137, dims=[0, 1, 3, 2, 4])
         S139 = self.define_scalar(0.100000, dtype=DataType.Double)
         S140 = self.define_scalar(True, dtype=DataType.Bool)
         T141, T142, T143, T144 = self.ops.sdpfa_fwd(T131, T124, T138, S139, S140, None)
-        T145 = self.ops.permute(T141, dims=[0, 2, 1, 3])
-        T146 = self.ops.stride_order(T145, stride_order=[3, 2, 1, 0])
-        T151 = self.ops.reshape(T146, new_shape=[b, s, e])
-        T152 = self.ops.linear(T151, self.mha_linear1_weight, self.mha_linear1_bias)
+        T145 = self.ops.permute(T141, dims=[0, 1, 3, 2, 4])
+        T146 = self.ops.stride_order(T145, stride_order=[4, 3, 2, 1, 0])
+        T151 = self.ops.reshape(T146, new_shape=[d, b, s, e // d])
+        # TODO(#3125): nvFuser is missing an API to construct a sharded linear
+        # like this. Therefore, I decomposed it by hand.
+        # T152 = self.ops.linear(T151, self.mha_linear1_weight, self.mha_linear1_bias)
+        # [d,b,s,e/d] [d,e,e/d] [e]
+        T152_local_matmul = self.ops.matmul(
+            T151,
+            self.ops.broadcast_in_dim(
+                self.ops.permute(self.mha_linear1_weight, [0, 2, 1]),
+                [d, 1, e // d, e],
+                [0, 2, 3],
+            ),
+        )
+        T152_matmul = self.ops.sum(T152_local_matmul, [0])  # allreduce
+        T152_biasadd = self.ops.add(
+            T152_matmul,
+            self.ops.broadcast_in_dim(self.mha_linear1_bias, [1, 1, e], [2]),
+        )
+        T152 = self.ops.cast(T152_biasadd, dtype=DataType.BFloat16)
         T153 = self.ops.cast(T152, dtype=DataType.Float)
         T154 = self.ops.cast(T33, dtype=DataType.Float)
         T155 = self.ops.mul(T153, T154)
@@ -408,8 +422,7 @@ def definition(self) -> None:
         T214 = self.ops.add(S213, T210)
         T215 = self.ops.mul(T212, T214)
         T216 = self.ops.cast(T215, dtype=DataType.BFloat16)
-        # TODO(#3125): nvFuser is missing an API to construct a sharded linear
-        # like this. Therefore, I decomposed it by hand.
+        # TODO(#3125): same as mha_linear1.
         # T217 = self.ops.linear(T216, self.mlp_linear1_weight, self.mlp_linear1_bias)
         # [b,s,e] [d,b,s,4h/d] [d,e,4h/d] [e]
         T217_local_matmul = self.ops.matmul(
@@ -448,6 +461,8 @@ def definition(self) -> None:
 
     def multidevice_schedule(self):
         mesh = self.sched._create_device_mesh(range(self._num_devices))
+        # Assign the mesh to inputs and weights. nvFuser will propagate it to
+        # downstream tensors.
         for in_tv in [
             self.input,
             self.layernorm0_weight,
@@ -465,9 +480,17 @@ def multidevice_schedule(self):
         ]:
             self.sched._set_device_mesh(in_tv, mesh)
 
-        self.sched.parallelize(self.mlp_linear0_weight, 0, nvfuser.ParallelType.mesh_x)
-        self.sched.parallelize(self.mlp_linear0_bias, 0, nvfuser.ParallelType.mesh_x)
-        self.sched.parallelize(self.mlp_linear1_weight, 0, nvfuser.ParallelType.mesh_x)
+        # Parallelize the device dimension of certain weights. nvFuser will try
+        # to propagate shardings to downstream tensors.
+        for in_tv in [
+            self.mha_linear0_weight,
+            self.mha_linear0_bias,
+            self.mha_linear1_weight,
+            self.mlp_linear0_weight,
+            self.mlp_linear0_bias,
+            self.mlp_linear1_weight,
+        ]:
+            self.sched.parallelize(in_tv, 0, nvfuser.ParallelType.mesh_x)
 
 
 @pytest.mark.skipif(
@@ -479,20 +502,35 @@ def test_transformer_forward(mpi_test):
     d = mpi_test.size
     rank = mpi_test.rank
 
-    b, s, h, e = 2, 2048, 96, 12288
+    b, s, h, e = 1, 2048, 96, 12288
+
+    assert (
+        e % h == 0
+    ), f"The hidden size ({e}) has to be divisible by the number of heads ({h})."
 
-    if e * 4 % d != 0:
+    if h % d != 0:
         pytest.skip(
-            f"We only support even split, so {e} * 4 has to be divisible by {d}."
+            f"We only support even DID split, so the number of heads ({h}) has to be divisible by the number of GPUs ({d})."
         )
 
+    assert e * 4 % d == 0, (
+        "This is required to evenly DID split MLP. This condition is implied "
+        "by the previous two checks; a fail would indicate a programming "
+        "error. So I use `assert` instead of `pytest.skip`."
+    )
+
     torch.cuda.set_device(mpi_test.local_rank)
 
     # To reduce memory footprint, create unsharded data on CPU and copy only
    # the needed slice to GPU.
+    mha_linear0_weight = torch.randn(d, e * 3 // d, e, dtype=torch.bfloat16)
+    mha_linear0_bias = torch.randn(d, e * 3 // d, dtype=torch.bfloat16)
+    mha_linear1_weight = torch.randn(d, e, e // d, dtype=torch.bfloat16)
+    mha_linear1_bias = torch.randn(e, dtype=torch.bfloat16, device="cuda")
     mlp_linear0_weight = torch.randn(d, e * 4 // d, e, dtype=torch.bfloat16)
     mlp_linear0_bias = torch.randn(d, e * 4 // d, dtype=torch.bfloat16)
     mlp_linear1_weight = torch.randn(d, e, e * 4 // d, dtype=torch.bfloat16)
+    mlp_linear1_bias = torch.randn(e, dtype=torch.bfloat16, device="cuda")
     # See TransformerForwardFusion.definition for the meanings of these
     # arguments. They are passed in in the same order as the `define_scalar`s
     # and `define_tensor`s.
@@ -504,16 +542,16 @@ def test_transformer_forward(mpi_test):
         torch.randn(b, s, e, dtype=torch.bfloat16, device="cuda"),
         torch.randn(e, dtype=torch.bfloat16, device="cuda"),
         torch.randn(e, dtype=torch.bfloat16, device="cuda"),
-        torch.randn(e * 3, e, dtype=torch.bfloat16, device="cuda"),
-        torch.randn(e * 3, dtype=torch.bfloat16, device="cuda"),
-        torch.randn(e, e, dtype=torch.bfloat16, device="cuda"),
-        torch.randn(e, dtype=torch.bfloat16, device="cuda"),
+        mha_linear0_weight[rank : rank + 1].cuda(),
+        mha_linear0_bias[rank : rank + 1].cuda(),
+        mha_linear1_weight[rank : rank + 1].cuda(),
+        mha_linear1_bias,
         torch.randn(e, dtype=torch.bfloat16, device="cuda"),
         torch.randn(e, dtype=torch.bfloat16, device="cuda"),
         mlp_linear0_weight[rank : rank + 1].cuda(),
         mlp_linear0_bias[rank : rank + 1].cuda(),
         mlp_linear1_weight[rank : rank + 1].cuda(),
-        torch.randn(e, dtype=torch.bfloat16, device="cuda"),
+        mlp_linear1_bias,
     ]
 
     fd = TransformerForwardFusion(d, b, s, h, e)
@@ -530,7 +568,7 @@ def test_transformer_forward(mpi_test):
         mha_dropout,
         layernorm1_avg,
         layernorm1_invstd,
-        output,
+        out,
     ) = outs
 
     # TODO(#2962): validate the numbers as well. Currently, the numbers are off
@@ -543,12 +581,12 @@ def assert_shape_dtype(
 
     assert_shape_dtype(layernorm0_avg, [b, s], torch.float32)
     assert_shape_dtype(layernorm0_invstd, [b, s, 1], torch.float32)
-    assert_shape_dtype(mha_linear0, [b, s, e * 3], torch.bfloat16)
-    assert_shape_dtype(sdpa_out, [b, h, s, e // h], torch.bfloat16)
-    assert_shape_dtype(sdpa_logsum_exp, [b, h, s], torch.float32)
+    assert_shape_dtype(mha_linear0, [1, b, s, e * 3 // d], torch.bfloat16)
+    assert_shape_dtype(sdpa_out, [1, b, h // d, s, e // h], torch.bfloat16)
+    assert_shape_dtype(sdpa_logsum_exp, [1, b, h // d, s], torch.float32)
     assert_shape_dtype(sdpa_seed, [], torch.int64)
     assert_shape_dtype(sdpa_offset, [], torch.int64)
     assert_shape_dtype(mha_dropout, [b, s, e], torch.float32)
     assert_shape_dtype(layernorm1_avg, [b, s], torch.float32)
     assert_shape_dtype(layernorm1_invstd, [b, s, 1], torch.float32)
-    assert_shape_dtype(output, [b, s, e], torch.bfloat16)
+    assert_shape_dtype(out, [b, s, e], torch.bfloat16)
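The hand-decomposed mha_linear1 in the patch is a row-parallel linear: the input feature dimension and the weight's in_features are split into d shards, each shard computes a partial matmul, the partials are summed (the allreduce), and the bias is added once. Below is a minimal standalone PyTorch sketch (not from the patch) checking that this decomposition reproduces an ordinary linear; the shard layout used here (contiguous chunks of the feature dimension) is an assumption for illustration, not necessarily the exact layout produced by the fusion's reshapes.

import torch

d, b, s, e = 2, 3, 8, 16
x = torch.randn(b, s, e)
weight = torch.randn(e, e)  # [out_features, in_features], the unsharded analog of mha_linear1_weight
bias = torch.randn(e)

# Reference: the unsharded linear that T152 would have been.
reference = torch.nn.functional.linear(x, weight, bias)

# Shard the in_features dimension of both the input and the weight into d chunks.
x_shards = x.view(b, s, d, e // d).permute(2, 0, 1, 3)  # [d, b, s, e/d]
w_shards = weight.view(e, d, e // d).permute(1, 0, 2)   # [d, e, e/d]

# Per-shard partial matmul, sum over the shard dimension (the allreduce), then a
# single bias add -- mirroring T152_local_matmul, T152_matmul, and T152_biasadd.
partial = torch.matmul(x_shards, w_shards.transpose(1, 2).unsqueeze(1))  # [d, b, s, e]
decomposed = partial.sum(dim=0) + bias

torch.testing.assert_close(decomposed, reference)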