Shard MHA. #3115

Merged 2 commits on Oct 15, 2024
118 changes: 78 additions & 40 deletions tests/python/test_multidevice.py
@@ -234,17 +234,17 @@ def definition(self) -> None:
dtype=DataType.BFloat16,
)
self.mha_linear0_weight = self.define_tensor(
shape=[e * 3, e],
shape=[d, e * 3 // d, e],
contiguity=True,
dtype=DataType.BFloat16,
)
self.mha_linear0_bias = self.define_tensor(
shape=[e * 3],
shape=[d, e * 3 // d],
contiguity=True,
dtype=DataType.BFloat16,
)
self.mha_linear1_weight = self.define_tensor(
shape=[e, e],
shape=[d, e, e // d],
contiguity=True,
dtype=DataType.BFloat16,
)
@@ -332,35 +332,49 @@ def definition(self) -> None:
T78 = self.ops.linear(T77, self.mha_linear0_weight, self.mha_linear0_bias)
T91 = self.ops.slice(
T78,
start_indices=[0, 0, 0],
end_indices=[b, s, e],
strides=[1, 1, 1],
start_indices=[0, 0, 0, 0],
end_indices=[d, b, s, e // d],
)
T104 = self.ops.slice(
T78,
start_indices=[0, 0, e],
end_indices=[b, s, e * 2],
strides=[1, 1, 1],
start_indices=[0, 0, 0, e // d],
end_indices=[d, b, s, e * 2 // d],
)
T117 = self.ops.slice(
T78,
start_indices=[0, 0, e * 2],
end_indices=[b, s, e * 3],
strides=[1, 1, 1],
start_indices=[0, 0, 0, e * 2 // d],
end_indices=[d, b, s, e * 3 // d],
)
T123 = self.ops.reshape(T104, new_shape=[b, s, h, e // h])
T124 = self.ops.permute(T123, dims=[0, 2, 1, 3])
T130 = self.ops.reshape(T91, new_shape=[b, s, h, e // h])
T131 = self.ops.permute(T130, dims=[0, 2, 1, 3])
T137 = self.ops.reshape(T117, new_shape=[b, s, h, e // h])
T138 = self.ops.permute(T137, dims=[0, 2, 1, 3])
T123 = self.ops.reshape(T104, new_shape=[d, b, s, h // d, e // h])
Collaborator:

IIUC, we pass from shape [d, b, s, e//d] to [d, b, s, h//d, e//h]. Nothing illegal about it, but it looks surprising to me, so I just want to make sure.

Collaborator (author):

I double-checked -- it looks right. MHA is head-parallel according to Figure 3b in https://arxiv.org/pdf/1909.08053.
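
To make the shape bookkeeping concrete, here is a small illustrative sketch (not part of the diff, and with made-up sizes rather than this test's): when e % h == 0 and h % d == 0, the last dimension e//d factors exactly into (h//d, e//h), i.e. each device holds h//d heads of head size e//h.

# Illustrative shape check only; d, b, s, h, e are hypothetical values.
import torch

d, b, s, h, e = 2, 1, 8, 4, 16
assert (h // d) * (e // h) == e // d  # 2 * 4 == 8, so the reshape is an even split

qkv_slice = torch.randn(d, b, s, e // d)            # one of the sliced Q/K/V tensors
per_head = qkv_slice.view(d, b, s, h // d, e // h)  # the reshape being discussed
print(per_head.shape)  # torch.Size([2, 1, 8, 2, 4])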

T124 = self.ops.permute(T123, dims=[0, 1, 3, 2, 4])
T130 = self.ops.reshape(T91, new_shape=[d, b, s, h // d, e // h])
T131 = self.ops.permute(T130, dims=[0, 1, 3, 2, 4])
T137 = self.ops.reshape(T117, new_shape=[d, b, s, h // d, e // h])
T138 = self.ops.permute(T137, dims=[0, 1, 3, 2, 4])
S139 = self.define_scalar(0.100000, dtype=DataType.Double)
S140 = self.define_scalar(True, dtype=DataType.Bool)
T141, T142, T143, T144 = self.ops.sdpfa_fwd(T131, T124, T138, S139, S140, None)
T145 = self.ops.permute(T141, dims=[0, 2, 1, 3])
T146 = self.ops.stride_order(T145, stride_order=[3, 2, 1, 0])
T151 = self.ops.reshape(T146, new_shape=[b, s, e])
T152 = self.ops.linear(T151, self.mha_linear1_weight, self.mha_linear1_bias)
T145 = self.ops.permute(T141, dims=[0, 1, 3, 2, 4])
T146 = self.ops.stride_order(T145, stride_order=[4, 3, 2, 1, 0])
T151 = self.ops.reshape(T146, new_shape=[d, b, s, e // d])
# TODO(#3125): nvFuser is missing an API to construct a sharded linear
# like this. Therefore, I decomposed it by hand.
# T152 = self.ops.linear(T151, self.mha_linear1_weight, self.mha_linear1_bias)
# [d,b,s,e/d] [d,e,e/d] [e]
T152_local_matmul = self.ops.matmul(
T151,
self.ops.broadcast_in_dim(
self.ops.permute(self.mha_linear1_weight, [0, 2, 1]),
[d, 1, e // d, e],
[0, 2, 3],
),
)
T152_matmul = self.ops.sum(T152_local_matmul, [0]) # allreduce
Collaborator:

I wonder what would happen currently if we do not decompose the matmul and the allreduce...

Collaborator (author):

The first thing that'll break is that linear will produce the wrong shape. linear, as implemented today, outputs a tensor of rank input_rank + weight_rank - 2 = 5. However, we want the shape to be [d,b,s,e] and thus 4D.
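
As an aside (not part of the diff): the hand decomposition used above -- a per-device matmul, then a sum over the device axis (marked # allreduce in the diff), then the bias -- can be checked against an unsharded reference in plain PyTorch. This is only an illustrative sketch with hypothetical sizes, not nvFuser code.

# Illustrative only: local matmul + sum over d + bias matches an unsharded linear.
import torch

d, b, s, e = 2, 1, 8, 16
x = torch.randn(d, b, s, e // d)  # sharded input, analogous to T151: [d, b, s, e/d]
w = torch.randn(d, e, e // d)     # sharded weight, analogous to mha_linear1_weight
bias = torch.randn(e)

local = torch.matmul(x, w.permute(0, 2, 1).unsqueeze(1))  # [d, b, s, e]
out = local.sum(dim=0) + bias                             # "allreduce" + bias: [b, s, e]

# Unsharded reference, reassembled along the e dimension.
x_full = x.permute(1, 2, 0, 3).reshape(b, s, e)
w_full = w.permute(1, 0, 2).reshape(e, e)
ref = torch.nn.functional.linear(x_full, w_full, bias)
assert torch.allclose(out, ref, atol=1e-4)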

T152_biasadd = self.ops.add(
T152_matmul,
self.ops.broadcast_in_dim(self.mha_linear1_bias, [1, 1, e], [2]),
)
T152 = self.ops.cast(T152_biasadd, dtype=DataType.BFloat16)
T153 = self.ops.cast(T152, dtype=DataType.Float)
T154 = self.ops.cast(T33, dtype=DataType.Float)
T155 = self.ops.mul(T153, T154)
@@ -408,8 +422,7 @@ def definition(self) -> None:
T214 = self.ops.add(S213, T210)
T215 = self.ops.mul(T212, T214)
T216 = self.ops.cast(T215, dtype=DataType.BFloat16)
# TODO(#3125): nvFuser is missing an API to construct a sharded linear
# like this. Therefore, I decomposed it by hand.
# TODO(#3125): same as mha_linear1.
# T217 = self.ops.linear(T216, self.mlp_linear1_weight, self.mlp_linear1_bias)
# [b,s,e] [d,b,s,4h/d] [d,e,4h/d] [e]
T217_local_matmul = self.ops.matmul(
@@ -448,6 +461,8 @@ def definition(self) -> None:

def multidevice_schedule(self):
mesh = self.sched._create_device_mesh(range(self._num_devices))
# Assign the mesh to inputs and weights. nvFuser will propagate it to
# downstream tensors.
for in_tv in [
self.input,
self.layernorm0_weight,
@@ -465,9 +480,17 @@ def multidevice_schedule(self):
]:
self.sched._set_device_mesh(in_tv, mesh)

self.sched.parallelize(self.mlp_linear0_weight, 0, nvfuser.ParallelType.mesh_x)
self.sched.parallelize(self.mlp_linear0_bias, 0, nvfuser.ParallelType.mesh_x)
self.sched.parallelize(self.mlp_linear1_weight, 0, nvfuser.ParallelType.mesh_x)
# Parallelize the device dimension of certain weights. nvFuser will try
# to propagate shardings to downstream tensors.
for in_tv in [
self.mha_linear0_weight,
self.mha_linear0_bias,
self.mha_linear1_weight,
self.mlp_linear0_weight,
self.mlp_linear0_bias,
self.mlp_linear1_weight,
]:
self.sched.parallelize(in_tv, 0, nvfuser.ParallelType.mesh_x)


@pytest.mark.skipif(
@@ -479,20 +502,35 @@ def test_transformer_forward(mpi_test):
d = mpi_test.size
rank = mpi_test.rank

b, s, h, e = 2, 2048, 96, 12288
b, s, h, e = 1, 2048, 96, 12288

assert (
e % h == 0
), f"The hidden size ({e}) has to be divisible by the number of heads ({h})."

if e * 4 % d != 0:
if h % d != 0:
pytest.skip(
f"We only support even split, so {e} * 4 has to be divisible by {d}."
f"We only support even DID split, so the number of heads ({h}) has to be divisible by the number of GPUs ({d})."
)

assert e * 4 % d == 0, (
"This is required to evenly DID split MLP. This condition is implied "
"by the previous two checks; a fail would indicate a programming "
"error. So I use `assert` instead of `pytest.skip`."
)

torch.cuda.set_device(mpi_test.local_rank)

# To reduce memory footprint, create unsharded data on CPU and copy only
# the needed slice to GPU.
mha_linear0_weight = torch.randn(d, e * 3 // d, e, dtype=torch.bfloat16)
mha_linear0_bias = torch.randn(d, e * 3 // d, dtype=torch.bfloat16)
mha_linear1_weight = torch.randn(d, e, e // d, dtype=torch.bfloat16)
mha_linear1_bias = torch.randn(e, dtype=torch.bfloat16, device="cuda")
mlp_linear0_weight = torch.randn(d, e * 4 // d, e, dtype=torch.bfloat16)
mlp_linear0_bias = torch.randn(d, e * 4 // d, dtype=torch.bfloat16)
mlp_linear1_weight = torch.randn(d, e, e * 4 // d, dtype=torch.bfloat16)
mlp_linear1_bias = torch.randn(e, dtype=torch.bfloat16, device="cuda")
# See TransformerForwardFusion.definition for the meanings of these
# arguments. They are passed in in the same order as the `define_scalar`s
# and `define_tensor`s.
@@ -504,16 +542,16 @@ def test_transformer_forward(mpi_test):
torch.randn(b, s, e, dtype=torch.bfloat16, device="cuda"),
torch.randn(e, dtype=torch.bfloat16, device="cuda"),
torch.randn(e, dtype=torch.bfloat16, device="cuda"),
torch.randn(e * 3, e, dtype=torch.bfloat16, device="cuda"),
torch.randn(e * 3, dtype=torch.bfloat16, device="cuda"),
torch.randn(e, e, dtype=torch.bfloat16, device="cuda"),
torch.randn(e, dtype=torch.bfloat16, device="cuda"),
mha_linear0_weight[rank : rank + 1].cuda(),
mha_linear0_bias[rank : rank + 1].cuda(),
mha_linear1_weight[rank : rank + 1].cuda(),
mha_linear1_bias,
torch.randn(e, dtype=torch.bfloat16, device="cuda"),
torch.randn(e, dtype=torch.bfloat16, device="cuda"),
mlp_linear0_weight[rank : rank + 1].cuda(),
mlp_linear0_bias[rank : rank + 1].cuda(),
mlp_linear1_weight[rank : rank + 1].cuda(),
torch.randn(e, dtype=torch.bfloat16, device="cuda"),
mlp_linear1_bias,
]

fd = TransformerForwardFusion(d, b, s, h, e)
@@ -530,7 +568,7 @@ def test_transformer_forward(mpi_test):
mha_dropout,
layernorm1_avg,
layernorm1_invstd,
output,
out,
) = outs

# TODO(#2962): validate the numbers as well. Currently, the numbers are off
@@ -543,12 +581,12 @@ def assert_shape_dtype(

assert_shape_dtype(layernorm0_avg, [b, s], torch.float32)
assert_shape_dtype(layernorm0_invstd, [b, s, 1], torch.float32)
assert_shape_dtype(mha_linear0, [b, s, e * 3], torch.bfloat16)
assert_shape_dtype(sdpa_out, [b, h, s, e // h], torch.bfloat16)
assert_shape_dtype(sdpa_logsum_exp, [b, h, s], torch.float32)
assert_shape_dtype(mha_linear0, [1, b, s, e * 3 // d], torch.bfloat16)
assert_shape_dtype(sdpa_out, [1, b, h // d, s, e // h], torch.bfloat16)
assert_shape_dtype(sdpa_logsum_exp, [1, b, h // d, s], torch.float32)
assert_shape_dtype(sdpa_seed, [], torch.int64)
assert_shape_dtype(sdpa_offset, [], torch.int64)
assert_shape_dtype(mha_dropout, [b, s, e], torch.float32)
assert_shape_dtype(layernorm1_avg, [b, s], torch.float32)
assert_shape_dtype(layernorm1_invstd, [b, s, 1], torch.float32)
assert_shape_dtype(output, [b, s, e], torch.bfloat16)
assert_shape_dtype(out, [b, s, e], torch.bfloat16)
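
A side note on the leading 1 in the sharded output shapes asserted above, and on the [rank : rank + 1] indexing used when building the inputs: slicing with [rank : rank + 1] keeps the DID axis with a local extent of 1, whereas [rank] would drop it. A minimal illustration with hypothetical sizes, not part of the diff:

# Illustrative only: [rank : rank + 1] keeps the leading device axis; [rank] drops it.
import torch

d, e = 2, 16
rank = 0
w = torch.randn(d, e * 3 // d, e)  # unsharded-on-CPU layout, like mha_linear0_weight
print(w[rank].shape)               # torch.Size([24, 16])    -- device axis dropped
print(w[rank : rank + 1].shape)    # torch.Size([1, 24, 16]) -- device axis kept, extent 1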