Annotate inputs and outputs of transformer tests (#2858)
cowanmeg authored Aug 28, 2024
1 parent 7fdcb3c commit 2eeef46
Showing 1 changed file with 66 additions and 22 deletions.
88 changes: 66 additions & 22 deletions benchmarks/python/test_transformer.py
@@ -44,101 +44,119 @@


def transformer_forward_fusion(fd: FusionDefinition) -> None:
# x: input
T0 = fd.define_tensor(
shape=[1, -1, -1],
contiguity=[None, True, True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[2, 1, 0],
)
# layer_norm0.weight
T1 = fd.define_tensor(
shape=[-1],
contiguity=[True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[0],
)
# layer_norm0.bias
T2 = fd.define_tensor(
shape=[-1],
contiguity=[True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[0],
)
# MHA linear0.weight
T3 = fd.define_tensor(
shape=[-1, -1],
contiguity=[True, True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[1, 0],
)
# MHA linear0.bias
T4 = fd.define_tensor(
shape=[-1],
contiguity=[True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[0],
)
# MHA linear1.weight
T5 = fd.define_tensor(
shape=[-1, -1],
contiguity=[True, True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[1, 0],
)
# MHA linear1.bias
T6 = fd.define_tensor(
shape=[-1],
contiguity=[True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[0],
)
# MHA dropout.rng_offset
S7 = fd.define_scalar(None, dtype=DataType.Int)
# MHA dropout.rng_seed
S8 = fd.define_scalar(None, dtype=DataType.Int)
# layer_norm1.weight
T9 = fd.define_tensor(
shape=[-1],
contiguity=[True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[0],
)
# layer_norm1.bias
T10 = fd.define_tensor(
shape=[-1],
contiguity=[True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[0],
)
# MLP linear0.weight
T11 = fd.define_tensor(
shape=[-1, -1],
contiguity=[True, True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[1, 0],
)
# MLP linear0.bias
T12 = fd.define_tensor(
shape=[-1],
contiguity=[True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[0],
)
# MLP linear1.weight
T13 = fd.define_tensor(
shape=[-1, -1],
contiguity=[True, True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[1, 0],
)
# MLP linear1.bias
T14 = fd.define_tensor(
shape=[-1],
contiguity=[True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[0],
)
# MLP dropout.rng_offset
S15 = fd.define_scalar(None, dtype=DataType.Int)
# MLP dropout.rng_seed
S16 = fd.define_scalar(None, dtype=DataType.Int)

T17 = fd.ops.cast(T0, dtype=DataType.Float)
T18, T19 = fd.ops.var_mean(T17, dims=[2], correction=0, keepdim=False)
S20 = fd.define_scalar(1, dtype=DataType.Int)
@@ -322,15 +340,15 @@ def transformer_forward_fusion(fd: FusionDefinition) -> None:
T186 = fd.ops.mul(T184, S185)
T187 = fd.ops.add(T113, T186)
T188 = fd.ops.cast(T187, dtype=DataType.BFloat16)
fd.add_output(T19)
fd.add_output(T32)
fd.add_output(T87)
fd.add_output(T88)
fd.add_output(T89)
fd.add_output(T90)
fd.add_output(T115)
fd.add_output(T128)
fd.add_output(T188)
fd.add_output(T19) # layer_norm0.welford_out.avg
fd.add_output(T32) # layer_norm0.invstd
fd.add_output(T87) # MHA sdpa.output
fd.add_output(T88) # MHA sdpa.logsum_exp
fd.add_output(T89) # MHA sdpa.philox_seed
fd.add_output(T90) # MHA sdpa.philox_offset
fd.add_output(T115) # layer_norm1.welford_out.avg
fd.add_output(T128) # layer_norm1.invstd
fd.add_output(T188) # output
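
For reference, a fusion recorded this way is executed through nvFuser's Python frontend by passing inputs positionally, in the same order as the define_tensor / define_scalar calls above, and the outputs come back in add_output order, matching the annotations. The sketch below is illustrative only and not part of this commit; the concrete sizes (batch 1, sequence 2048, hidden 12288, i.e. a GPT-3-style block), the helper name bf16, the variable names, and the RNG scalar values are all assumptions.

# Illustrative sketch (not part of this commit): executing the forward fusion.
import torch
from nvfuser import FusionDefinition

with FusionDefinition() as fwd_fd:
    transformer_forward_fusion(fwd_fd)

h = 12288  # assumed hidden size

def bf16(*shape):
    return torch.randn(*shape, dtype=torch.bfloat16, device="cuda")

x = bf16(1, 2048, h)                              # x: input
ln0_w, ln0_b = bf16(h), bf16(h)                   # layer_norm0 weight/bias
mha_l0_w, mha_l0_b = bf16(3 * h, h), bf16(3 * h)  # MHA linear0 weight/bias
mha_l1_w, mha_l1_b = bf16(h, h), bf16(h)          # MHA linear1 weight/bias
ln1_w, ln1_b = bf16(h), bf16(h)                   # layer_norm1 weight/bias
mlp_l0_w, mlp_l0_b = bf16(4 * h, h), bf16(4 * h)  # MLP linear0 weight/bias
mlp_l1_w, mlp_l1_b = bf16(h, 4 * h), bf16(h)      # MLP linear1 weight/bias
rng_offset, rng_seed = 0, 123456                  # assumed dropout RNG state

fwd_inputs = [
    x, ln0_w, ln0_b,
    mha_l0_w, mha_l0_b, mha_l1_w, mha_l1_b, rng_offset, rng_seed,
    ln1_w, ln1_b,
    mlp_l0_w, mlp_l0_b, mlp_l1_w, mlp_l1_b, rng_offset, rng_seed,
]
# Returned in add_output order, i.e. exactly the annotated outputs above.
(ln0_avg, ln0_invstd,
 sdpa_out, sdpa_lse, philox_seed, philox_offset,
 ln1_avg, ln1_invstd,
 y) = fwd_fd.execute(fwd_inputs)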


def test_transformer_forward(
@@ -392,145 +410,171 @@ def test_transformer_forward(


def transformer_backward_fusion(fd: FusionDefinition) -> None:
# x: input
T0 = fd.define_tensor(
shape=[1, -1, -1],
contiguity=[None, True, True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[2, 1, 0],
)
# layer_norm0.welford_out.avg
T1 = fd.define_tensor(
shape=[1, -1],
contiguity=[None, True],
dtype=DataType.Float,
is_cpu=False,
stride_order=[1, 0],
)
# layer_norm0.invstd
T2 = fd.define_tensor(
shape=[1, -1, 1],
contiguity=[None, True, None],
dtype=DataType.Float,
is_cpu=False,
stride_order=[2, 1, 0],
)
# layer_norm0.weight
T3 = fd.define_tensor(
shape=[-1],
contiguity=[True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[0],
)
# layer_norm0.bias
T4 = fd.define_tensor(
shape=[-1],
contiguity=[True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[0],
)
# MHA linear0.weight
T5 = fd.define_tensor(
shape=[-1, -1],
contiguity=[True, True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[1, 0],
)
# MHA linear0.bias
T6 = fd.define_tensor(
shape=[-1],
contiguity=[True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[0],
)
# MHA sdpa.output
T7 = fd.define_tensor(
shape=[1, -1, -1, -1],
contiguity=[None, True, True, True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[3, 1, 2, 0],
)
# MHA linear1.weight
T8 = fd.define_tensor(
shape=[-1, -1],
contiguity=[True, True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[1, 0],
)
# MHA linear1.bias
T9 = fd.define_tensor(
shape=[-1],
contiguity=[True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[0],
)
# MHA dropout.rng_offset
S10 = fd.define_scalar(None, dtype=DataType.Int)
# MHA dropout.rng_seed
S11 = fd.define_scalar(None, dtype=DataType.Int)
# layer_norm1.welford_out.avg
T12 = fd.define_tensor(
shape=[1, -1],
contiguity=[None, True],
dtype=DataType.Float,
is_cpu=False,
stride_order=[1, 0],
)
# layer_norm1.invstd
T13 = fd.define_tensor(
shape=[1, -1, 1],
contiguity=[None, True, None],
dtype=DataType.Float,
is_cpu=False,
stride_order=[2, 1, 0],
)
# layer_norm1.weight
T14 = fd.define_tensor(
shape=[-1],
contiguity=[True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[0],
)
# layer_norm1.bias
T15 = fd.define_tensor(
shape=[-1],
contiguity=[True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[0],
)
# MLP linear0.weight
T16 = fd.define_tensor(
shape=[-1, -1],
contiguity=[True, True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[1, 0],
)
# MLP linear0.bias
T17 = fd.define_tensor(
shape=[-1],
contiguity=[True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[0],
)
# MLP dropout.rng_offset
S18 = fd.define_scalar(None, dtype=DataType.Int)
# MLP dropout.rng_seed
S19 = fd.define_scalar(None, dtype=DataType.Int)
# dy: incoming grad
T20 = fd.define_tensor(
shape=[1, -1, -1],
contiguity=[None, True, True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[2, 1, 0],
)
# MLP linear1.weight
T21 = fd.define_tensor(
shape=[-1, -1],
contiguity=[True, True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[1, 0],
)
# MHA sdpa.logsum_exp
T22 = fd.define_tensor(
shape=[1, -1, -1],
contiguity=[None, True, True],
dtype=DataType.Float,
is_cpu=False,
stride_order=[2, 1, 0],
)
# MHA sdpa.philox_seed
T23 = fd.define_tensor(shape=[], contiguity=[], dtype=DataType.Int, is_cpu=False)
# MHA sdpa.philox_offset
T24 = fd.define_tensor(shape=[], contiguity=[], dtype=DataType.Int, is_cpu=False)

T25 = fd.ops.cast(T0, dtype=DataType.Float)
S26 = fd.define_scalar(1, dtype=DataType.Int)
S27 = fd.define_scalar(2048, dtype=DataType.Int)
@@ -979,19 +1023,19 @@ def transformer_backward_fusion(fd: FusionDefinition) -> None:
T454 = fd.ops.add(T384, T453)
T455 = fd.ops.add(T304, T454)
T456 = fd.ops.cast(T455, dtype=DataType.BFloat16)
fd.add_output(T184)
fd.add_output(T186)
fd.add_output(T223)
fd.add_output(T225)
fd.add_output(T228)
fd.add_output(T232)
fd.add_output(T324)
fd.add_output(T326)
fd.add_output(T373)
fd.add_output(T376)
fd.add_output(T379)
fd.add_output(T383)
fd.add_output(T456)
fd.add_output(T184) # MLP linear1.weight_grad
fd.add_output(T186) # MLP linear1.bias_grad
fd.add_output(T223) # MLP linear0.weight_grad
fd.add_output(T225) # MLP linear0.bias_grad
fd.add_output(T228) # layer_norm1.bias_grad
fd.add_output(T232) # layer_norm1.weight_grad
fd.add_output(T324) # MHA linear1.weight_grad
fd.add_output(T326) # MHA linear1.bias_grad
fd.add_output(T373) # MHA linear0.weight_grad
fd.add_output(T376) # MHA linear0.bias_grad
fd.add_output(T379) # layer_norm0.bias_grad
fd.add_output(T383) # layer_norm0.weight_grad
fd.add_output(T456) # dx output grad
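
The backward annotations make the forward/backward contract explicit: apart from the parameters and the incoming gradient dy, every tensor the backward fusion consumes is one of the forward fusion's annotated outputs (the layer_norm Welford statistics and the SDPA output, logsumexp, and Philox state). The sketch below continues the illustrative example given after the forward fusion and is likewise an assumption about how the pieces fit together, not part of this commit.

# Illustrative sketch (not part of this commit), continuing the forward example:
# the saved forward outputs become backward inputs, in definition order.
with FusionDefinition() as bwd_fd:
    transformer_backward_fusion(bwd_fd)

grad_y = torch.randn_like(y)              # dy: incoming grad
bwd_inputs = [
    x,                                    # T0        x: input
    ln0_avg, ln0_invstd,                  # T1, T2    layer_norm0 saved stats
    ln0_w, ln0_b,                         # T3, T4
    mha_l0_w, mha_l0_b,                   # T5, T6
    sdpa_out,                             # T7        MHA sdpa.output
    mha_l1_w, mha_l1_b,                   # T8, T9
    rng_offset, rng_seed,                 # S10, S11  MHA dropout RNG state
    ln1_avg, ln1_invstd,                  # T12, T13  layer_norm1 saved stats
    ln1_w, ln1_b,                         # T14, T15
    mlp_l0_w, mlp_l0_b,                   # T16, T17
    rng_offset, rng_seed,                 # S18, S19  MLP dropout RNG state
    grad_y,                               # T20       dy
    mlp_l1_w,                             # T21
    sdpa_lse, philox_seed, philox_offset, # T22-T24   MHA sdpa saved state
]
# Returned in add_output order: the annotated parameter grads, then dx.
grads = bwd_fd.execute(bwd_inputs)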


def test_transformer_backward(
