Shard MHA. #3115
Changes from all commits
@@ -234,17 +234,17 @@ def definition(self) -> None:
dtype=DataType.BFloat16,
)
self.mha_linear0_weight = self.define_tensor(
shape=[e * 3, e],
shape=[d, e * 3 // d, e],
contiguity=True,
dtype=DataType.BFloat16,
)
self.mha_linear0_bias = self.define_tensor(
shape=[e * 3],
shape=[d, e * 3 // d],
contiguity=True,
dtype=DataType.BFloat16,
)
self.mha_linear1_weight = self.define_tensor(
shape=[e, e],
shape=[d, e, e // d],
contiguity=True,
dtype=DataType.BFloat16,
)

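As a side note on the shape changes above: the new [d, e * 3 // d, e] layout is the original [e * 3, e] weight with its output rows grouped into d contiguous blocks, one per device. A minimal standalone PyTorch sketch (toy sizes, not part of the PR) that checks this:

```python
import torch

d, e = 2, 8  # toy sizes standing in for the device count and hidden size

# Unsharded column-parallel weight of mha_linear0: [e * 3, e].
w = torch.arange(e * 3 * e, dtype=torch.float32).reshape(e * 3, e)

# The PR's DID-friendly layout with a leading device dimension of size d.
w_did = w.reshape(d, e * 3 // d, e)

# Device r's shard is a contiguous block of e*3//d output rows of the original weight.
rows_per_device = e * 3 // d
for r in range(d):
    assert torch.equal(w_did[r], w[r * rows_per_device : (r + 1) * rows_per_device])
```
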
@@ -332,35 +332,49 @@ def definition(self) -> None:
T78 = self.ops.linear(T77, self.mha_linear0_weight, self.mha_linear0_bias)
T91 = self.ops.slice(
T78,
start_indices=[0, 0, 0],
end_indices=[b, s, e],
strides=[1, 1, 1],
start_indices=[0, 0, 0, 0],
end_indices=[d, b, s, e // d],
)
T104 = self.ops.slice(
T78,
start_indices=[0, 0, e],
end_indices=[b, s, e * 2],
strides=[1, 1, 1],
start_indices=[0, 0, 0, e // d],
end_indices=[d, b, s, e * 2 // d],
)
T117 = self.ops.slice(
T78,
start_indices=[0, 0, e * 2],
end_indices=[b, s, e * 3],
strides=[1, 1, 1],
start_indices=[0, 0, 0, e * 2 // d],
end_indices=[d, b, s, e * 3 // d],
)
T123 = self.ops.reshape(T104, new_shape=[b, s, h, e // h])
T124 = self.ops.permute(T123, dims=[0, 2, 1, 3])
T130 = self.ops.reshape(T91, new_shape=[b, s, h, e // h])
T131 = self.ops.permute(T130, dims=[0, 2, 1, 3])
T137 = self.ops.reshape(T117, new_shape=[b, s, h, e // h])
T138 = self.ops.permute(T137, dims=[0, 2, 1, 3])
T123 = self.ops.reshape(T104, new_shape=[d, b, s, h // d, e // h])
T124 = self.ops.permute(T123, dims=[0, 1, 3, 2, 4])
T130 = self.ops.reshape(T91, new_shape=[d, b, s, h // d, e // h])
T131 = self.ops.permute(T130, dims=[0, 1, 3, 2, 4])
T137 = self.ops.reshape(T117, new_shape=[d, b, s, h // d, e // h])
T138 = self.ops.permute(T137, dims=[0, 1, 3, 2, 4])
S139 = self.define_scalar(0.100000, dtype=DataType.Double)
S140 = self.define_scalar(True, dtype=DataType.Bool)
T141, T142, T143, T144 = self.ops.sdpfa_fwd(T131, T124, T138, S139, S140, None)
T145 = self.ops.permute(T141, dims=[0, 2, 1, 3])
T146 = self.ops.stride_order(T145, stride_order=[3, 2, 1, 0])
T151 = self.ops.reshape(T146, new_shape=[b, s, e])
T152 = self.ops.linear(T151, self.mha_linear1_weight, self.mha_linear1_bias)
T145 = self.ops.permute(T141, dims=[0, 1, 3, 2, 4])
T146 = self.ops.stride_order(T145, stride_order=[4, 3, 2, 1, 0])
T151 = self.ops.reshape(T146, new_shape=[d, b, s, e // d])
# TODO(#3125): nvFuser is missing an API to construct a sharded linear
# like this. Therefore, I decomposed it by hand.
# T152 = self.ops.linear(T151, self.mha_linear1_weight, self.mha_linear1_bias)
# [d,b,s,e/d] [d,e,e/d] [e]
T152_local_matmul = self.ops.matmul(
T151,
self.ops.broadcast_in_dim(
self.ops.permute(self.mha_linear1_weight, [0, 2, 1]),
[d, 1, e // d, e],
[0, 2, 3],
),
)
T152_matmul = self.ops.sum(T152_local_matmul, [0])  # allreduce

Review comment: I wonder what would happen currently if we do not decompose the matmul and the allreduce...

Reply: The first thing that'll break is that

T152_biasadd = self.ops.add(
T152_matmul,
self.ops.broadcast_in_dim(self.mha_linear1_bias, [1, 1, e], [2]),
)
T152 = self.ops.cast(T152_biasadd, dtype=DataType.BFloat16)
T153 = self.ops.cast(T152, dtype=DataType.Float)
T154 = self.ops.cast(T33, dtype=DataType.Float)
T155 = self.ops.mul(T153, T154)

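To make the hand-decomposed mha_linear1 easier to follow, here is a minimal standalone PyTorch sketch (toy sizes and tensor names are assumptions, not part of the PR) of the same row-parallel pattern: each device multiplies its slice of the activations by its weight shard, the partial products are summed over the device dimension in place of the allreduce, and the bias is added once at the end.

```python
import torch
import torch.nn.functional as F

d, b, s, e = 2, 3, 4, 8  # toy sizes, not the PR's real dimensions

x = torch.randn(b, s, e)   # unsharded input to mha_linear1
w = torch.randn(e, e)      # unsharded weight
bias = torch.randn(e)

reference = F.linear(x, w, bias)  # what an unsharded ops.linear would compute

# Row-parallel sharding: split the input features (last dim of x, second dim of w)
# into d contiguous chunks, mirroring the [d, b, s, e/d] and [d, e, e/d] layouts.
x_shards = x.reshape(b, s, d, e // d).permute(2, 0, 1, 3)  # [d, b, s, e/d]
w_shards = w.reshape(e, d, e // d).permute(1, 0, 2)        # [d, e, e/d]

# Per-device partial matmul, like T152_local_matmul. The unsqueeze plays the role
# of broadcast_in_dim so the weight shard broadcasts over the batch dimension.
partial = torch.matmul(x_shards, w_shards.permute(0, 2, 1).unsqueeze(1))  # [d, b, s, e]

# Summing over the device dimension stands in for the allreduce; the bias is added once.
out = partial.sum(dim=0) + bias

torch.testing.assert_close(out, reference, rtol=1e-4, atol=1e-4)
```
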
@@ -408,8 +422,7 @@ def definition(self) -> None:
T214 = self.ops.add(S213, T210)
T215 = self.ops.mul(T212, T214)
T216 = self.ops.cast(T215, dtype=DataType.BFloat16)
# TODO(#3125): nvFuser is missing an API to construct a sharded linear
# like this. Therefore, I decomposed it by hand.
# TODO(#3125): same as mha_linear1.
# T217 = self.ops.linear(T216, self.mlp_linear1_weight, self.mlp_linear1_bias)
# [b,s,e] [d,b,s,4h/d] [d,e,4h/d] [e]
T217_local_matmul = self.ops.matmul(

@@ -448,6 +461,8 @@ def definition(self) -> None:

def multidevice_schedule(self):
mesh = self.sched._create_device_mesh(range(self._num_devices))
# Assign the mesh to inputs and weights. nvFuser will propagate it to
# downstream tensors.
for in_tv in [
self.input,
self.layernorm0_weight,

@@ -465,9 +480,17 @@ def multidevice_schedule(self):
]:
self.sched._set_device_mesh(in_tv, mesh)

self.sched.parallelize(self.mlp_linear0_weight, 0, nvfuser.ParallelType.mesh_x)
self.sched.parallelize(self.mlp_linear0_bias, 0, nvfuser.ParallelType.mesh_x)
self.sched.parallelize(self.mlp_linear1_weight, 0, nvfuser.ParallelType.mesh_x)
# Parallelize the device dimension of certain weights. nvFuser will try
# to propagate shardings to downstream tensors.
for in_tv in [
self.mha_linear0_weight,
self.mha_linear0_bias,
self.mha_linear1_weight,
self.mlp_linear0_weight,
self.mlp_linear0_bias,
self.mlp_linear1_weight,
]:
self.sched.parallelize(in_tv, 0, nvfuser.ParallelType.mesh_x)


@pytest.mark.skipif(

@@ -479,20 +502,35 @@ def test_transformer_forward(mpi_test):
d = mpi_test.size
rank = mpi_test.rank

b, s, h, e = 2, 2048, 96, 12288
b, s, h, e = 1, 2048, 96, 12288

assert (
e % h == 0
), f"The hidden size ({e}) has to be divisible by the number of heads ({h})."

if e * 4 % d != 0:
if h % d != 0:
pytest.skip(
f"We only support even split, so {e} * 4 has to be divisible by {d}."
f"We only support even DID split, so the number of heads ({h}) has to be divisible by the number of GPUs ({d})."
)

assert e * 4 % d == 0, (
"This is required to evenly DID split MLP. This condition is implied "
"by the previous two checks; a fail would indicate a programming "
"error. So I use `assert` instead of `pytest.skip`."
)

torch.cuda.set_device(mpi_test.local_rank)

# To reduce memory footprint, create unsharded data on CPU and copy only
# the needed slice to GPU.
mha_linear0_weight = torch.randn(d, e * 3 // d, e, dtype=torch.bfloat16)
mha_linear0_bias = torch.randn(d, e * 3 // d, dtype=torch.bfloat16)
mha_linear1_weight = torch.randn(d, e, e // d, dtype=torch.bfloat16)
mha_linear1_bias = torch.randn(e, dtype=torch.bfloat16, device="cuda")
mlp_linear0_weight = torch.randn(d, e * 4 // d, e, dtype=torch.bfloat16)
mlp_linear0_bias = torch.randn(d, e * 4 // d, dtype=torch.bfloat16)
mlp_linear1_weight = torch.randn(d, e, e * 4 // d, dtype=torch.bfloat16)
mlp_linear1_bias = torch.randn(e, dtype=torch.bfloat16, device="cuda")
# See TransformerForwardFusion.definition for the meanings of these
# arguments. They are passed in in the same order as the `define_scalar`s
# and `define_tensor`s.

@@ -504,16 +542,16 @@ def test_transformer_forward(mpi_test):
torch.randn(b, s, e, dtype=torch.bfloat16, device="cuda"),
torch.randn(e, dtype=torch.bfloat16, device="cuda"),
torch.randn(e, dtype=torch.bfloat16, device="cuda"),
torch.randn(e * 3, e, dtype=torch.bfloat16, device="cuda"),
torch.randn(e * 3, dtype=torch.bfloat16, device="cuda"),
torch.randn(e, e, dtype=torch.bfloat16, device="cuda"),
torch.randn(e, dtype=torch.bfloat16, device="cuda"),
mha_linear0_weight[rank : rank + 1].cuda(),
mha_linear0_bias[rank : rank + 1].cuda(),
mha_linear1_weight[rank : rank + 1].cuda(),
mha_linear1_bias,
torch.randn(e, dtype=torch.bfloat16, device="cuda"),
torch.randn(e, dtype=torch.bfloat16, device="cuda"),
mlp_linear0_weight[rank : rank + 1].cuda(),
mlp_linear0_bias[rank : rank + 1].cuda(),
mlp_linear1_weight[rank : rank + 1].cuda(),
torch.randn(e, dtype=torch.bfloat16, device="cuda"),
mlp_linear1_bias,
]

fd = TransformerForwardFusion(d, b, s, h, e)

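The [rank : rank + 1] slices above are how each rank materializes the weights whose dimension 0 was DID-parallelized in multidevice_schedule: the device dimension is kept, but with extent 1 on every rank. A rough standalone PyTorch sketch of that convention (toy sizes, assumed for illustration; this is not nvFuser API):

```python
import torch

# Toy stand-in for a DID-parallelized [d, n, k] weight.
d, n, k = 2, 6, 4
logical = torch.randn(d, n, k)

# Each rank holds only its own slice, keeping the device dimension with extent 1,
# like `mha_linear0_weight[rank : rank + 1].cuda()` in the test inputs.
local_shards = [logical[r : r + 1] for r in range(d)]
assert all(shard.shape == (1, n, k) for shard in local_shards)

# Concatenating the per-rank shards along the device dimension recovers the
# logical tensor.
assert torch.equal(torch.cat(local_shards, dim=0), logical)
```
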
@@ -530,7 +568,7 @@ def test_transformer_forward(mpi_test):
mha_dropout,
layernorm1_avg,
layernorm1_invstd,
output,
out,
) = outs

# TODO(#2962): validate the numbers as well. Currently, the numbers are off

@@ -543,12 +581,12 @@ def assert_shape_dtype(

assert_shape_dtype(layernorm0_avg, [b, s], torch.float32)
assert_shape_dtype(layernorm0_invstd, [b, s, 1], torch.float32)
assert_shape_dtype(mha_linear0, [b, s, e * 3], torch.bfloat16)
assert_shape_dtype(sdpa_out, [b, h, s, e // h], torch.bfloat16)
assert_shape_dtype(sdpa_logsum_exp, [b, h, s], torch.float32)
assert_shape_dtype(mha_linear0, [1, b, s, e * 3 // d], torch.bfloat16)
assert_shape_dtype(sdpa_out, [1, b, h // d, s, e // h], torch.bfloat16)
assert_shape_dtype(sdpa_logsum_exp, [1, b, h // d, s], torch.float32)
assert_shape_dtype(sdpa_seed, [], torch.int64)
assert_shape_dtype(sdpa_offset, [], torch.int64)
assert_shape_dtype(mha_dropout, [b, s, e], torch.float32)
assert_shape_dtype(layernorm1_avg, [b, s], torch.float32)
assert_shape_dtype(layernorm1_invstd, [b, s, 1], torch.float32)
assert_shape_dtype(output, [b, s, e], torch.bfloat16)
assert_shape_dtype(out, [b, s, e], torch.bfloat16)

Review comment: IIUC, we pass from shape [d, b, s, e//d] to [d, b, s, h//d, e//h]. Nothing illegal about it but it looks surprising to me so I just want to make sure.

Reply: I double checked -- it looks right. MHA is head parallel according to Figure 3b in https://arxiv.org/pdf/1909.08053.
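To back up the reply above, here is a minimal standalone PyTorch sketch (toy sizes, assumed for illustration) showing that reshaping each device's [b, s, e/d] slice into [b, s, h/d, e/h] yields exactly a block of h/d whole heads, i.e. the head-parallel split of Megatron-LM's Figure 3b:

```python
import torch

# Toy sizes standing in for the PR's b, s, h, e and the device count d.
d, b, s, h, e = 2, 3, 5, 4, 16
assert h % d == 0 and e % h == 0

x = torch.arange(b * s * e, dtype=torch.float32).reshape(b, s, e)

# Unsharded path: split e into h heads, then give each device a block of h/d heads.
heads = x.reshape(b, s, h, e // h)
per_device_heads = heads.reshape(b, s, d, h // d, e // h).permute(2, 0, 1, 3, 4)

# Sharded path used in the PR: each device holds an e/d slice of the hidden dim
# ([d, b, s, e/d]) and reshapes it to [d, b, s, h/d, e/h].
shards = x.reshape(b, s, d, e // d).permute(2, 0, 1, 3)  # [d, b, s, e/d]
reshaped = shards.reshape(d, b, s, h // d, e // h)

# Both paths assign the same elements to the same device and head.
assert torch.equal(per_device_heads, reshaped)
```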