Skip to content

Commit

Permalink
Shard MHA. (#3115)
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue authored Oct 15, 2024
1 parent 5556357 commit 66c4bed
Showing 1 changed file with 78 additions and 40 deletions.
118 changes: 78 additions & 40 deletions tests/python/test_multidevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,17 +234,17 @@ def definition(self) -> None:
dtype=DataType.BFloat16,
)
self.mha_linear0_weight = self.define_tensor(
shape=[e * 3, e],
shape=[d, e * 3 // d, e],
contiguity=True,
dtype=DataType.BFloat16,
)
self.mha_linear0_bias = self.define_tensor(
shape=[e * 3],
shape=[d, e * 3 // d],
contiguity=True,
dtype=DataType.BFloat16,
)
self.mha_linear1_weight = self.define_tensor(
shape=[e, e],
shape=[d, e, e // d],
contiguity=True,
dtype=DataType.BFloat16,
)
Expand Down Expand Up @@ -332,35 +332,49 @@ def definition(self) -> None:
T78 = self.ops.linear(T77, self.mha_linear0_weight, self.mha_linear0_bias)
T91 = self.ops.slice(
T78,
start_indices=[0, 0, 0],
end_indices=[b, s, e],
strides=[1, 1, 1],
start_indices=[0, 0, 0, 0],
end_indices=[d, b, s, e // d],
)
T104 = self.ops.slice(
T78,
start_indices=[0, 0, e],
end_indices=[b, s, e * 2],
strides=[1, 1, 1],
start_indices=[0, 0, 0, e // d],
end_indices=[d, b, s, e * 2 // d],
)
T117 = self.ops.slice(
T78,
start_indices=[0, 0, e * 2],
end_indices=[b, s, e * 3],
strides=[1, 1, 1],
start_indices=[0, 0, 0, e * 2 // d],
end_indices=[d, b, s, e * 3 // d],
)
T123 = self.ops.reshape(T104, new_shape=[b, s, h, e // h])
T124 = self.ops.permute(T123, dims=[0, 2, 1, 3])
T130 = self.ops.reshape(T91, new_shape=[b, s, h, e // h])
T131 = self.ops.permute(T130, dims=[0, 2, 1, 3])
T137 = self.ops.reshape(T117, new_shape=[b, s, h, e // h])
T138 = self.ops.permute(T137, dims=[0, 2, 1, 3])
T123 = self.ops.reshape(T104, new_shape=[d, b, s, h // d, e // h])
T124 = self.ops.permute(T123, dims=[0, 1, 3, 2, 4])
T130 = self.ops.reshape(T91, new_shape=[d, b, s, h // d, e // h])
T131 = self.ops.permute(T130, dims=[0, 1, 3, 2, 4])
T137 = self.ops.reshape(T117, new_shape=[d, b, s, h // d, e // h])
T138 = self.ops.permute(T137, dims=[0, 1, 3, 2, 4])
S139 = self.define_scalar(0.100000, dtype=DataType.Double)
S140 = self.define_scalar(True, dtype=DataType.Bool)
T141, T142, T143, T144 = self.ops.sdpfa_fwd(T131, T124, T138, S139, S140, None)
T145 = self.ops.permute(T141, dims=[0, 2, 1, 3])
T146 = self.ops.stride_order(T145, stride_order=[3, 2, 1, 0])
T151 = self.ops.reshape(T146, new_shape=[b, s, e])
T152 = self.ops.linear(T151, self.mha_linear1_weight, self.mha_linear1_bias)
T145 = self.ops.permute(T141, dims=[0, 1, 3, 2, 4])
T146 = self.ops.stride_order(T145, stride_order=[4, 3, 2, 1, 0])
T151 = self.ops.reshape(T146, new_shape=[d, b, s, e // d])
# TODO(#3125): nvFuser is missing an API to construct a sharded linear
# like this. Therefore, I decomposed it by hand.
# T152 = self.ops.linear(T151, self.mha_linear1_weight, self.mha_linear1_bias)
# [b,s,e] [d,b,s,e/d] [d,e,e/d] [e]
T152_local_matmul = self.ops.matmul(
T151,
self.ops.broadcast_in_dim(
self.ops.permute(self.mha_linear1_weight, [0, 2, 1]),
[d, 1, e // d, e],
[0, 2, 3],
),
)
T152_matmul = self.ops.sum(T152_local_matmul, [0]) # allreduce
T152_biasadd = self.ops.add(
T152_matmul,
self.ops.broadcast_in_dim(self.mha_linear1_bias, [1, 1, e], [2]),
)
T152 = self.ops.cast(T152_biasadd, dtype=DataType.BFloat16)
T153 = self.ops.cast(T152, dtype=DataType.Float)
T154 = self.ops.cast(T33, dtype=DataType.Float)
T155 = self.ops.mul(T153, T154)
Expand Down Expand Up @@ -408,8 +422,7 @@ def definition(self) -> None:
T214 = self.ops.add(S213, T210)
T215 = self.ops.mul(T212, T214)
T216 = self.ops.cast(T215, dtype=DataType.BFloat16)
# TODO(#3125): nvFuser is missing an API to construct a sharded linear
# like this. Therefore, I decomposed it by hand.
# TODO(#3125): same as mha_linear1.
# T217 = self.ops.linear(T216, self.mlp_linear1_weight, self.mlp_linear1_bias)
# [b,s,e] [d,b,s,4e/d] [d,e,4e/d] [e]
T217_local_matmul = self.ops.matmul(
Expand Down Expand Up @@ -448,6 +461,8 @@ def definition(self) -> None:

def multidevice_schedule(self):
mesh = self.sched._create_device_mesh(range(self._num_devices))
# Assign the mesh to inputs and weights. nvFuser will propagate it to
# downstream tensors.
for in_tv in [
self.input,
self.layernorm0_weight,
Expand All @@ -465,9 +480,17 @@ def multidevice_schedule(self):
]:
self.sched._set_device_mesh(in_tv, mesh)

self.sched.parallelize(self.mlp_linear0_weight, 0, nvfuser.ParallelType.mesh_x)
self.sched.parallelize(self.mlp_linear0_bias, 0, nvfuser.ParallelType.mesh_x)
self.sched.parallelize(self.mlp_linear1_weight, 0, nvfuser.ParallelType.mesh_x)
# Parallelize the device dimension of certain weights. nvFuser will try
# to propagate shardings to downstream tensors.
for in_tv in [
self.mha_linear0_weight,
self.mha_linear0_bias,
self.mha_linear1_weight,
self.mlp_linear0_weight,
self.mlp_linear0_bias,
self.mlp_linear1_weight,
]:
self.sched.parallelize(in_tv, 0, nvfuser.ParallelType.mesh_x)


@pytest.mark.skipif(
Expand All @@ -479,20 +502,35 @@ def test_transformer_forward(mpi_test):
d = mpi_test.size
rank = mpi_test.rank

b, s, h, e = 2, 2048, 96, 12288
b, s, h, e = 1, 2048, 96, 12288

assert (
e % h == 0
), f"The hidden size ({e}) has to be divisible by the number of heads ({h})."

if e * 4 % d != 0:
if h % d != 0:
pytest.skip(
f"We only support even split, so {e} * 4 has to be divisible by {d}."
f"We only support even DID split, so the number of heads ({h}) has to be divisible by the number of GPUs ({d})."
)

assert e * 4 % d == 0, (
"This is required to evenly DID split MLP. This condition is implied "
"by the previous two checks; a fail would indicate a programming "
"error. So I use `assert` instead of `pytest.skip`."
)

torch.cuda.set_device(mpi_test.local_rank)

# To reduce memory footprint, create unsharded data on CPU and copy only
# the needed slice to GPU.
mha_linear0_weight = torch.randn(d, e * 3 // d, e, dtype=torch.bfloat16)
mha_linear0_bias = torch.randn(d, e * 3 // d, dtype=torch.bfloat16)
mha_linear1_weight = torch.randn(d, e, e // d, dtype=torch.bfloat16)
mha_linear1_bias = torch.randn(e, dtype=torch.bfloat16, device="cuda")
mlp_linear0_weight = torch.randn(d, e * 4 // d, e, dtype=torch.bfloat16)
mlp_linear0_bias = torch.randn(d, e * 4 // d, dtype=torch.bfloat16)
mlp_linear1_weight = torch.randn(d, e, e * 4 // d, dtype=torch.bfloat16)
mlp_linear1_bias = torch.randn(e, dtype=torch.bfloat16, device="cuda")
# See TransformerForwardFusion.definition for the meanings of these
# arguments. They are passed in the same order as the `define_scalar`s
# and `define_tensor`s.
Expand All @@ -504,16 +542,16 @@ def test_transformer_forward(mpi_test):
torch.randn(b, s, e, dtype=torch.bfloat16, device="cuda"),
torch.randn(e, dtype=torch.bfloat16, device="cuda"),
torch.randn(e, dtype=torch.bfloat16, device="cuda"),
torch.randn(e * 3, e, dtype=torch.bfloat16, device="cuda"),
torch.randn(e * 3, dtype=torch.bfloat16, device="cuda"),
torch.randn(e, e, dtype=torch.bfloat16, device="cuda"),
torch.randn(e, dtype=torch.bfloat16, device="cuda"),
mha_linear0_weight[rank : rank + 1].cuda(),
mha_linear0_bias[rank : rank + 1].cuda(),
mha_linear1_weight[rank : rank + 1].cuda(),
mha_linear1_bias,
torch.randn(e, dtype=torch.bfloat16, device="cuda"),
torch.randn(e, dtype=torch.bfloat16, device="cuda"),
mlp_linear0_weight[rank : rank + 1].cuda(),
mlp_linear0_bias[rank : rank + 1].cuda(),
mlp_linear1_weight[rank : rank + 1].cuda(),
torch.randn(e, dtype=torch.bfloat16, device="cuda"),
mlp_linear1_bias,
]

fd = TransformerForwardFusion(d, b, s, h, e)
Expand All @@ -530,7 +568,7 @@ def test_transformer_forward(mpi_test):
mha_dropout,
layernorm1_avg,
layernorm1_invstd,
output,
out,
) = outs

# TODO(#2962): validate the numbers as well. Currently, the numbers are off
Expand All @@ -543,12 +581,12 @@ def assert_shape_dtype(

assert_shape_dtype(layernorm0_avg, [b, s], torch.float32)
assert_shape_dtype(layernorm0_invstd, [b, s, 1], torch.float32)
assert_shape_dtype(mha_linear0, [b, s, e * 3], torch.bfloat16)
assert_shape_dtype(sdpa_out, [b, h, s, e // h], torch.bfloat16)
assert_shape_dtype(sdpa_logsum_exp, [b, h, s], torch.float32)
assert_shape_dtype(mha_linear0, [1, b, s, e * 3 // d], torch.bfloat16)
assert_shape_dtype(sdpa_out, [1, b, h // d, s, e // h], torch.bfloat16)
assert_shape_dtype(sdpa_logsum_exp, [1, b, h // d, s], torch.float32)
assert_shape_dtype(sdpa_seed, [], torch.int64)
assert_shape_dtype(sdpa_offset, [], torch.int64)
assert_shape_dtype(mha_dropout, [b, s, e], torch.float32)
assert_shape_dtype(layernorm1_avg, [b, s], torch.float32)
assert_shape_dtype(layernorm1_invstd, [b, s, 1], torch.float32)
assert_shape_dtype(output, [b, s, e], torch.bfloat16)
assert_shape_dtype(out, [b, s, e], torch.bfloat16)

0 comments on commit 66c4bed

Please sign in to comment.