Write a sharded transformer block in nvFuser API. #2199
Note to myself: I'll first try to get a single-device nvFuser Python definition from Thunder, and then we can manually shard it using nvFuser's API. @Priya2698 pointed me to the `nv_enable_linear` flag (https://github.com/Lightning-AI/lightning-thunder/blob/90a0f4c0d0a90d1e94684a847f3adfe2230985b4/thunder/tests/test_nvfuser.py#L875) that I'll need to turn on to enable `prims.linear` via nvFuser. I'll probably need to set `nv_enable_bookend=False` as well.
Note to myself: I'll start with the following benchmark, which exercises one transformer layer in nanoGPT:
cc @Priya2698
With the following patch:

```diff
diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index c955da06..4767ab9c 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -2201,6 +2201,7 @@ def _linear_check(a: TensorProxy, b: TensorProxy, bias: TensorProxy | None) -> b
         return False
     enable_linear: None | bool = get_compile_option("nv_enable_linear", "Enable nvFuser matmul.")
+    enable_linear = True
     if not enable_linear:
         return False
     # Verify linear inputs and bias (optional) are supported tensors.
@@ -2210,6 +2211,7 @@ def _linear_check(a: TensorProxy, b: TensorProxy, bias: TensorProxy | None) -> b
         return False
     # nvFuser only supports 2D inputs in v0.2.3.
+    import pdb; pdb.set_trace()
     if not a.ndim == 2:
         return False
     return True
```

The printed Python definition is, unsurprisingly, five fusions, none of which contains a matmul or linear, because the 3D activations still fail the `a.ndim == 2` check.
Below is a WAR (workaround) for the Thunder check above, but it ran into an nvFuser issue.

```diff
diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index c955da06..137da102 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -4,6 +4,7 @@ from numbers import Number
 from typing import Union, List, Any, Optional, Dict, Set, Tuple, Type
 from types import NoneType
 from collections.abc import Callable, Mapping, Hashable, Sequence
+import math
 import os
 import time
 from copy import copy
@@ -796,7 +797,7 @@ instantiated) this heuristic actually leads to worse code.
     enable_bookend: None | bool = get_compile_option("nv_enable_bookend", bookend_help)
     # Set default value.
     if enable_bookend is None:
-        enable_bookend = True
+        enable_bookend = False
     assert isinstance(enable_bookend, bool)
     if enable_bookend:
@@ -2200,7 +2201,7 @@ def _linear_check(a: TensorProxy, b: TensorProxy, bias: TensorProxy | None) -> b
     if nv_version < LooseVersion("0.2.3"):
         return False
-    enable_linear: None | bool = get_compile_option("nv_enable_linear", "Enable nvFuser matmul.")
+    enable_linear = True
     if not enable_linear:
         return False
     # Verify linear inputs and bias (optional) are supported tensors.
@@ -2209,8 +2210,11 @@ def _linear_check(a: TensorProxy, b: TensorProxy, bias: TensorProxy | None) -> b
     if bias is not None and not is_supported_tensor(bias):
         return False
-    # nvFuser only supports 2D inputs in v0.2.3.
-    if not a.ndim == 2:
+    if a.ndim < 2:
+        return False
+    if b.ndim != 2:
+        return False
+    if bias is not None and bias.ndim != 1:
         return False
     return True
@@ -2226,7 +2230,10 @@ def linear(
     nva = getnv(a, fd, lc_to_nv_map)
     nvb = getnv(b, fd, lc_to_nv_map)
     nvbias = None if bias is None else getnv(bias, fd, lc_to_nv_map)
-    return fd.ops.linear(nva, nvb, nvbias)
+
+    nva_2d = fd.ops.reshape(nva, (math.prod(a.shape[:-1]), a.shape[-1]))
+    nvc_2d = fd.ops.linear(nva_2d, nvb, nvbias)
+    return fd.ops.reshape(nvc_2d, a.shape[:-1] + (b.shape[-2],))
 
 
 register_supported(PrimIDs.LINEAR, linear, _linear_check)
```

The resulting fusion definition and inputs, which reproduce the nvFuser issue:

```python
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd: FusionDefinition) -> None:
    T0 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T3 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T4 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T5 = fd.ops.cast(T4, dtype=DataType.Float)
    T6, T7 = fd.ops.var_mean(T5, dims=[2], correction=0, keepdim=False)
    S8 = fd.define_scalar(16, dtype=DataType.Int)
    S9 = fd.define_scalar(128, dtype=DataType.Int)
    S10 = fd.define_scalar(1, dtype=DataType.Int)
    V11 = fd.define_vector([S8, S9, S10], dtype=DataType.Int)
    T12 = fd.ops.broadcast_in_dim(T6, shape=V11, broadcast_dims=[0, 1])
    S13 = fd.define_scalar(16, dtype=DataType.Int)
    S14 = fd.define_scalar(128, dtype=DataType.Int)
    S15 = fd.define_scalar(1, dtype=DataType.Int)
    V16 = fd.define_vector([S13, S14, S15], dtype=DataType.Int)
    T17 = fd.ops.broadcast_in_dim(T7, shape=V16, broadcast_dims=[0, 1])
    S18 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T19 = fd.ops.add(T12, S18)
    T20 = fd.ops.rsqrt(T19)
    S21 = fd.define_scalar(16, dtype=DataType.Int)
    S22 = fd.define_scalar(128, dtype=DataType.Int)
    S23 = fd.define_scalar(1600, dtype=DataType.Int)
    V24 = fd.define_vector([S21, S22, S23], dtype=DataType.Int)
    T25 = fd.ops.broadcast_in_dim(T17, shape=V24, broadcast_dims=[0, 1, 2])
    T26 = fd.ops.sub(T5, T25)
    S27 = fd.define_scalar(16, dtype=DataType.Int)
    S28 = fd.define_scalar(128, dtype=DataType.Int)
    S29 = fd.define_scalar(1600, dtype=DataType.Int)
    V30 = fd.define_vector([S27, S28, S29], dtype=DataType.Int)
    T31 = fd.ops.broadcast_in_dim(T20, shape=V30, broadcast_dims=[0, 1, 2])
    T32 = fd.ops.mul(T26, T31)
    S33 = fd.define_scalar(16, dtype=DataType.Int)
    S34 = fd.define_scalar(128, dtype=DataType.Int)
    S35 = fd.define_scalar(1600, dtype=DataType.Int)
    V36 = fd.define_vector([S33, S34, S35], dtype=DataType.Int)
    T37 = fd.ops.broadcast_in_dim(T3, shape=V36, broadcast_dims=[2])
    T38 = fd.ops.cast(T37, dtype=DataType.Float)
    T39 = fd.ops.mul(T32, T38)
    S40 = fd.define_scalar(16, dtype=DataType.Int)
    S41 = fd.define_scalar(128, dtype=DataType.Int)
    S42 = fd.define_scalar(1600, dtype=DataType.Int)
    V43 = fd.define_vector([S40, S41, S42], dtype=DataType.Int)
    T44 = fd.ops.broadcast_in_dim(T2, shape=V43, broadcast_dims=[2])
    T45 = fd.ops.cast(T44, dtype=DataType.Float)
    T46 = fd.ops.add(T39, T45)
    T47 = fd.ops.cast(T46, dtype=DataType.BFloat16)
    S48 = fd.define_scalar(2048, dtype=DataType.Int)
    S49 = fd.define_scalar(1600, dtype=DataType.Int)
    V50 = fd.define_vector([S48, S49], dtype=DataType.Int)
    T51 = fd.ops.reshape(T47, new_shape=V50)
    T52 = fd.ops.linear(T51, T1, T0)
    fd.add_output(T52)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.randn((4800,), dtype=torch.bfloat16, device='cuda:0').as_strided((4800,), (1,)),
    torch.randn((7680000,), dtype=torch.bfloat16, device='cuda:0').as_strided((4800, 1600), (1600, 1)),
    torch.randn((1600,), dtype=torch.bfloat16, device='cuda:0').as_strided((1600,), (1,)),
    torch.randn((1600,), dtype=torch.bfloat16, device='cuda:0').as_strided((1600,), (1,)),
    torch.randn((3276800,), dtype=torch.bfloat16, device='cuda:0').as_strided((16, 128, 1600), (204800, 1600, 1)),
]
fd.execute(inputs)
```
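The WAR above relies on `linear` acting only on the last dimension. Folding all leading dimensions into one (here [16, 128, 1600] becomes [2048, 1600], matching `S48`/`S49` in the reproducer), running a 2D linear, and unfolding gives the same result as applying the linear slice by slice. Below is a minimal pure-Python sketch that checks this on tiny shapes. Nested lists stand in for tensors, and `linear_2d`, `flatten_leading`, `unflatten_leading`, and `linear_nd_via_reshape` are hypothetical helpers, not Thunder or nvFuser APIs.

```python
import math

# Hypothetical helpers for illustration; nested Python lists stand in for tensors.


def linear_2d(x, w, bias=None):
    """y[m][n] = sum_k x[m][k] * w[n][k] (+ bias[n]); w is laid out [out, in] like F.linear."""
    out = []
    for row in x:
        y = [sum(a * b for a, b in zip(row, w_row)) for w_row in w]
        if bias is not None:
            y = [v + bv for v, bv in zip(y, bias)]
        out.append(y)
    return out


def flatten_leading(x, shape):
    """Collapse [d0, ..., d_{n-2}, K] into [prod(d0..d_{n-2}), K]."""
    rows = x
    for _ in range(len(shape) - 2):
        rows = [sub for r in rows for sub in r]
    return rows


def unflatten_leading(rows, leading):
    """Inverse of flatten_leading: regroup the rows into the leading dims."""
    for d in reversed(leading[1:]):
        rows = [rows[i:i + d] for i in range(0, len(rows), d)]
    return rows


def linear_nd_via_reshape(x, shape, w, bias=None):
    """Mirrors the WAR: reshape to 2D, run a 2D linear, reshape back."""
    y_2d = linear_2d(flatten_leading(x, shape), w, bias)
    return unflatten_leading(y_2d, list(shape[:-1]))


def linear_nd_reference(x, w, bias=None):
    """Applies the 2D linear to each [S, K] slice of a 3D input independently."""
    return [linear_2d(s, w, bias) for s in x]


# Tiny [2, 3, 2] input and a [3, 2] weight (out=3, in=2).
x = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
     [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]]
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
bias = [0.5, -0.5, 0.0]
assert linear_nd_via_reshape(x, (2, 3, 2), w, bias) == linear_nd_reference(x, w, bias)

# Shape bookkeeping for the reproducer: [16, 128, 1600] flattens to [2048, 1600].
assert math.prod((16, 128)) == 2048
```

The output shape follows `a.shape[:-1] + (b.shape[-2],)` from the WAR, so for the reproducer's [4800, 1600] weight the result would be [16, 128, 4800].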
The matmul scheduler failed at `Fuser/csrc/scheduler/matmul_utils.cpp` line 275 (commit 7f7126d).
It looks like it assumes both operands are broadcast. I was under the impression that we removed that assumption for #1628. What am I missing? @zasdfgbnm
FYI, below is the complete fusion after preseg optimizations. The MmaOp is indeed part of the beautiful broadcast+broadcast+mma+add+float2bfloat subgraph, which is good. However, because of other ops in the fusion, this subgraph is not handed to the matmul scheduler directly. Instead, it's decomposed into singletons, and the segmenter has trouble merging them back into the expected subgraph.
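Spelled out, the broadcast+broadcast+mma+add+float2bfloat subgraph computes a linear: A [M, K] is broadcast to [M, 1, K], B [N, K] to [1, N, K], the two are multiplied and reduced over K (the part the MmaOp represents), then the bias is added and the result is cast to bfloat16. The pure-Python sketch below illustrates only that math; it skips the cast, says nothing about how nvFuser schedules MmaOp, and its function names are hypothetical.

```python
def broadcast_mul_sum_linear(a, b, bias):
    """a: [M, K], b: [N, K] (linear weight layout), bias: [N] -> [M, N]."""
    m, k = len(a), len(a[0])
    n = len(b)
    # Broadcast a to [M, 1, K] and b to [1, N, K]; multiply over the [M, N, K] space.
    prod = [[[a[i][kk] * b[j][kk] for kk in range(k)] for j in range(n)] for i in range(m)]
    # Reduce over K: together with the multiply, this is what the MmaOp stands for.
    acc = [[sum(prod[i][j]) for j in range(n)] for i in range(m)]
    # Epilogue: bias add (the fusion would then cast to BFloat16, omitted here).
    return [[acc[i][j] + bias[j] for j in range(n)] for i in range(m)]


def direct_linear(a, b, bias):
    """Reference: plain dot products, y = a @ b^T + bias."""
    return [[sum(x * y for x, y in zip(row, b_row)) + bj for b_row, bj in zip(b, bias)] for row in a]


a = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
b = [[1.0, 0.0, -1.0], [0.5, 0.5, 0.5]]
bias = [0.0, 1.0]
assert broadcast_mul_sum_linear(a, b, bias) == direct_linear(a, b, bias)
```

When the segmenter splits this subgraph into singletons, each piece (the broadcasts, the multiply-reduce, the bias add, the cast) ends up in its own segment, which is why the matmul scheduler never sees the whole pattern.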
This issue looks related to #2127. @wujingyue What do you get after #2221? While the ATen evaluation for matmul/linear will drop these assumptions once the new IR nodes are merged, at present we make the same assumption in pattern matching as well.
For #2199. Broadcasts before Mma are optional. `matmul_expr_eval` still has problems with this, but I'll file a separate issue for that.
That's right. I already merged #2221, so you can reproduce this by running the reproducer in #2199 (comment). Anyhow, I'll run my experiments with `matmul_expr_eval` disabled, so #2221 is sufficient to unblock me for now. That being said, there's a variation of the same problem for the ATen evaluation: the segmenter isn't guaranteed to put the cast into the same segment, just as it didn't for the broadcasts. I think the new IR nodes will help with that, but I'm not sure, so I'll leave that to you.
With some more hacks (which I'll try to find a way to submit), I'm getting some useful nvFusions that will hopefully make a good starting point. Now the forward pass runs two nvFusions. The first one has one `fd.ops.linear`, which I suspect is the input linear layer. The second one has three `fd.ops.linear`, which I suspect are the output linear layer followed by the two-layer MLP. I'll confirm this and try to include SDPA as well.
Yes, the new IR nodes will fix this issue since we won't evaluate a decomposed IR. The pattern matching will become redundant and will be removed once the API is modified to use the new IR nodes.
@Priya2698 What do you think about the drafted WAR for nvFuser not supporting 3D linear? Should I submit that WAR in Thunder or wait for your PRs?
It looks like the WAR will still run into the segmentation issue because of the reshapes. If you don't necessarily need that change in Thunder to proceed, adding the new nodes will lift that restriction anyway. I estimate the new PRs will be ready within a couple of days, early next week. We can go ahead with the WAR if it unblocks you in the interim.
Cool, I closed Lightning-AI/lightning-thunder#391.
For NVIDIA/Fuser#2199.

To run them,

```
NVFUSER_DISABLE=matmul_expr_eval python before_sdpa.py
NVFUSER_DISABLE=matmul_expr_eval python after_sdpa.py
```

`matmul_expr_eval` is disabled because of a known limitation that will be fixed soon.

Currently, the two files implement the parts before and after SDPA; I'll try to include SDPA as well. For your understanding, code around `fd.ops.uniform` corresponds to dropout, code around `fd.ops.tanh` corresponds to an approximated GELU layer, and code around `fd.ops.var_mean` corresponds to layernorm.
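For matching those op patterns by eye, the pure-Python sketch below writes out the textbook formulas for the three layers. The layernorm follows the op sequence in the earlier reproducer (`var_mean` with `correction=0`, `rsqrt(var + 1e-5)`, scale, then shift); the GELU is the standard tanh approximation; the dropout is the usual inverted-dropout formulation driven by a uniform draw, which is an assumption about what Thunder emits rather than a transcription of it. All function names are hypothetical.

```python
import math
import random


def layer_norm(x, weight, bias, eps=1e-5):
    """Normalize one row as the fusion does: biased var_mean, then rsqrt(var + eps)."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # correction=0 -> divide by n
    rstd = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * rstd * w + b for v, w, b in zip(x, weight, bias)]


def gelu_tanh(x):
    """Tanh approximation: 0.5 x (1 + tanh(sqrt(2/pi) (x + 0.044715 x^3)))."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def gelu_exact(x):
    """Erf-based GELU, for comparison with the approximation."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def dropout(x, p, rng):
    """Inverted dropout: keep with probability 1 - p (uniform < 1 - p), rescale by 1 / (1 - p)."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    scale = 1.0 / (1.0 - p)
    return [v * scale if rng.random() < 1.0 - p else 0.0 for v in x]


row = layer_norm([1.0, 2.0, 3.0, 4.0], [1.0] * 4, [0.0] * 4)
print(row)                                 # roughly zero mean, unit variance
print(gelu_tanh(1.0), gelu_exact(1.0))     # close, but not identical
print(dropout([1.0] * 8, 0.5, random.Random(0)))
```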
@cowanmeg Lightning-AI/lightning-thunder@bf84b04 checked in the forward pass of a single-device transformer block, modulo SDPA. See the message of that commit for more details. With that, we should be able to work on this in parallel: I'll try to include SDPA and backprop, and you'll try to build a sharded version. How does that sound?
Thanks @wujingyue! This is super helpful. I'll start working on the sharding soon!
I annotated the sharding of the MLP layer of the example: https://gist.github.com/cowanmeg/75b4144a3627df74efcfc12dda01a2a3 Some comments: While we discuss our design for (2), I will manually translate these programs and decompose the LinearOp myself. Regardless, this is necessary, since we need to logically split sharded axes in the compute definition because of our RFactor restriction. For MLP, this isn't too hard and would let us get a small example working.
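To make "logically split sharded axes" concrete: an axis of extent N sharded over D devices is rewritten as [D, N/D], and the outer dimension of extent D is the one bound to devices, so device d owns the contiguous slice [d * N/D, (d + 1) * N/D). The pure-Python sketch below shows that mapping, for example for a 6400-wide MLP hidden dimension (assuming nanoGPT's 4 × n_embd MLP with the n_embd = 1600 seen in the reproducer) over 8 devices. `outer_split` is a hypothetical helper, not nvFuser's split API, and it assumes N divides evenly.

```python
def outer_split(values, num_devices):
    """Rewrite a length-N axis as [num_devices, N // num_devices] (outer split).

    Index d of the outer dimension is the shard that device d would own.
    """
    n = len(values)
    if n % num_devices != 0:
        raise ValueError("this sketch assumes the axis divides evenly across devices")
    per_device = n // num_devices
    return [values[d * per_device:(d + 1) * per_device] for d in range(num_devices)]


shards = outer_split(list(range(6400)), 8)
print([(s[0], s[-1]) for s in shards])  # first and last index owned by each device
```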
FYI, Lightning-AI/lightning-thunder@af6bfc1 added the forward pass of the whole transformer block (i.e., with SDPA). Caveat: the speed is probably far from SOL because nvFuser can't fuse matmul+softmax+matmul at the moment. #2278 is going to add an SDPA IR node so we can fall back to the existing flash attention implementation in ATen. When that's done, the fusion definition will simply show the SDPA node instead of the decomposed form.
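For reference, the decomposed form is matmul, scale, causal mask, softmax, matmul. The single-head pure-Python sketch below mirrors that math (no dropout, causal mask as in nanoGPT's attention); it is an illustration of the formula, not Thunder's trace, and the function names are hypothetical.

```python
import math


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def causal_sdpa(q, k, v):
    """q, k, v: [seq, head_dim] lists. Returns softmax(q k^T / sqrt(d) + causal mask) v."""
    seq, d = len(q), len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i in range(seq):
        # First matmul (row i of q k^T), scaled; future positions j > i are masked to -inf.
        scores = [
            sum(a * b for a, b in zip(q[i], k[j])) * scale if j <= i else -math.inf
            for j in range(seq)
        ]
        probs = softmax(scores)
        # Second matmul: probability-weighted sum of the value rows.
        out.append([sum(probs[j] * v[j][c] for j in range(seq)) for c in range(len(v[0]))])
    return out


zeros = [[0.0, 0.0]] * 3
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(causal_sdpa(zeros, zeros, v))  # row i averages v[0..i] when all scores are equal
```

A fused kernel (or the ATen flash attention fallback) computes the same result without materializing the full `scores` matrix, which is where the decomposed form loses speed.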
Lightning-AI/lightning-thunder@b06bf4e adds the backprop. It's hard to verify because https://github.com/Lightning-AI/lightning-thunder/blob/bc3925be04e7ec58d9b24fb7ac55fbe007862a65/thunder/core/rematerialization.py#L569 mixes in some ops from the forward pass. However, when I try to print the backprop trace before rematerialization (see below), I do see 12
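When reading the backprop trace, it helps to recall what a single linear contributes: for y = x Wᵀ + b, the gradients are dx = dy W, dW = dyᵀ x, and db = the column sums of dy, so each forward linear adds two matmuls to the backward pass. The pure-Python sketch below checks those formulas against finite differences of a scalar loss whose gradient with respect to y is exactly dy; the helper names are hypothetical.

```python
def matmul(a, b):
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def linear_fwd(x, w, b):
    """y = x @ w^T + b, with w laid out [out, in]."""
    y = matmul(x, transpose(w))
    return [[v + bj for v, bj in zip(row, b)] for row in y]


def linear_bwd(x, w, dy):
    dx = matmul(dy, w)             # [M, N] @ [N, K] -> [M, K]
    dw = matmul(transpose(dy), x)  # [N, M] @ [M, K] -> [N, K]
    db = [sum(col) for col in zip(*dy)]
    return dx, dw, db


def loss(x, w, b, dy):
    """Scalar loss sum(y * dy), so dL/dy == dy."""
    y = linear_fwd(x, w, b)
    return sum(yv * g for yr, gr in zip(y, dy) for yv, g in zip(yr, gr))


x = [[1.0, 2.0], [3.0, -1.0]]
w = [[0.5, -1.0], [2.0, 0.25], [1.0, 1.0]]
b = [0.1, 0.2, 0.3]
dy = [[1.0, 0.0, -1.0], [0.5, 2.0, 1.0]]
dx, dw, db = linear_bwd(x, w, dy)
print(dx, dw, db)
```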
@cowanmeg Here's how you can get a Thunder trace to help you understand the backprop nvFusion. The Thunder trace tends to be more concise than nvFusion and has shapes annotated. Also, you can dump the intermediate traces to see where the end trace comes from.
FYI, Lightning-AI/lightning-thunder@e19f6ea tries to update the test case to use the GPT-3 config, the one used in the two most recent Megatron papers: https://arxiv.org/pdf/2104.04473 and https://arxiv.org/pdf/2205.05198. It hits #2359 at the moment.
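As a rough sanity check on the sizes involved, assume "the GPT-3 config" means the 175B GPT-3 shape (hidden size h = 12288, 96 layers). One transformer block then holds 12h² + 13h parameters: QKV (3h² + 3h), attention output projection (h² + h), MLP (4h² + 4h and 4h² + h), and two layernorms (4h). The sketch below does that arithmetic; embeddings are excluded, and `block_params` is a hypothetical helper.

```python
def block_params(h):
    qkv = 3 * h * h + 3 * h        # fused QKV projection + bias
    attn_out = h * h + h           # attention output projection + bias
    mlp_fc1 = 4 * h * h + 4 * h    # h -> 4h
    mlp_fc2 = 4 * h * h + h        # 4h -> h
    layernorms = 2 * 2 * h         # two layernorms, weight + bias each
    return qkv + attn_out + mlp_fc1 + mlp_fc2 + layernorms


h, layers = 12288, 96
per_block = block_params(h)
total = per_block * layers
print(f"{per_block:,} params per block, ~{total / 1e9:.1f}B across {layers} blocks")
# 1,812,099,072 params per block, ~174.0B across 96 blocks
```

Adding the embedding tables brings this to roughly the advertised 175B, which is why a single block of this shape is already a meaningful sharding target.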
Manually sharded tensor-parallel multilayer perceptron (MLP) layer. The input is the MLP layer from nanoGPT, manually translated and sharded. See #2199 for where we get the initial compute trace.
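The Megatron-LM MLP partitioning (https://arxiv.org/pdf/1909.08053) shards the first linear's weight by output features (column parallel) and the second linear's weight by input features (row parallel), so GELU runs locally on each shard and a single sum (an allreduce across devices) produces the output. The pure-Python sketch below simulates the devices in one process and checks the result against the unsharded MLP. It is illustrative only, not the nvFuser program from this PR; biases are omitted for brevity, and all names are hypothetical.

```python
import math


def gelu(x):
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))


def linear(x, w):
    """x: [M, K], w: [N, K] -> [M, N] (no bias)."""
    return [[sum(a * b for a, b in zip(row, w_row)) for w_row in w] for row in x]


def mlp(x, w1, w2):
    hidden = [[gelu(v) for v in row] for row in linear(x, w1)]
    return linear(hidden, w2)


def sharded_mlp(x, w1, w2, num_devices):
    hidden = len(w1)
    if hidden % num_devices != 0:
        raise ValueError("this sketch assumes the hidden size divides evenly")
    per = hidden // num_devices
    partials = []
    for d in range(num_devices):
        # Column parallel: device d owns output features [d*per, (d+1)*per) of w1.
        w1_d = w1[d * per:(d + 1) * per]
        # Row parallel: device d owns the matching input features (columns) of w2.
        w2_d = [row[d * per:(d + 1) * per] for row in w2]
        h_d = [[gelu(v) for v in row] for row in linear(x, w1_d)]  # GELU stays local
        partials.append(linear(h_d, w2_d))
    # Allreduce: sum the per-device partial outputs.
    m, n = len(x), len(w2)
    return [[sum(p[i][j] for p in partials) for j in range(n)] for i in range(m)]


x = [[0.1 * (i + j) - 0.3 for j in range(4)] for i in range(2)]
w1 = [[((3 * i + j) % 7 - 3) / 5.0 for j in range(4)] for i in range(8)]
w2 = [[((i + 2 * j) % 5 - 2) / 4.0 for j in range(8)] for i in range(4)]
print(mlp(x, w1, w2))
print(sharded_mlp(x, w1, w2, 4))  # same values up to floating-point summation order
```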
For NVIDIA/Fuser#2199. Both use the GPT-3 sizes to be consistent with the Megatron paper (https://arxiv.org/pdf/2104.04473). With #541 fixed, I saw one fusion for forward and one for backward. Yay! SDPA is still in decomposed form; that is pending NVIDIA/Fuser#2483 and changes to Thunder's executors.
For #2199. I've been maintaining the nvFusions and the inputs in a branch. This PR checks them into nvFuser's main for convenience.

```shell
$ pytest benchmarks/python/test_transformer.py
============================= test session starts ==============================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.5.0
Test order randomisation NOT enabled. Enable with --random-order or --random-order-bucket=<bucket_type>
benchmark: 4.0.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /opt/pytorch/nvfuser
plugins: xdist-3.6.1, hypothesis-6.104.2, timestamper-0.0.10, cov-5.0.0, timeout-2.3.1, random-order-1.1.1, benchmark-4.0.0, shard-0.1.2
collected 2 items
Running 2 items in this shard

benchmarks/python/test_transformer.py ..                                 [100%]

------------------------------------------------------ benchmark: 2 tests ------------------------------------------------------
Name (time in ms)          Min               Max               Mean              StdDev          Median            IQR             Outliers  OPS             Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------
test_transformer_forward   54.7508 (1.0)     72.9630 (1.0)     67.7275 (1.0)     4.8153 (1.38)   68.7300 (1.0)     2.0469 (1.01)   2;2       14.7650 (1.0)   10      1
test_transformer_backward  174.6965 (3.19)   187.7991 (2.57)   183.9202 (2.72)   3.4975 (1.0)    184.6459 (2.69)   2.0344 (1.0)    2;1       5.4371 (0.37)   10      1
--------------------------------------------------------------------------------------------------------------------------------

Legend:
  Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile.
  OPS: Operations Per Second, computed as 1 / Mean
============================== 2 passed in 8.56s ===============================
```
For #2199. Thanks to Lightning-AI/lightning-thunder#951, I'm now able to generate microbenchmarks with SDPA nodes!

```
$ pytest benchmarks/python/test_transformer.py
```

Before:

```
------------------------------------------------------ benchmark: 2 tests ------------------------------------------------------
Name (time in ms)          Min               Max               Mean              StdDev          Median            IQR             Outliers  OPS             Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------
test_transformer_forward   53.0883 (1.0)     69.7684 (1.0)     65.8204 (1.0)     6.0816 (1.62)   68.9426 (1.0)     4.2709 (2.16)   2;2       15.1929 (1.0)   10      1
test_transformer_backward  174.3857 (3.28)   187.1334 (2.68)   184.6143 (2.80)   3.7561 (1.0)    185.1308 (2.69)   1.9769 (1.0)    1;1       5.4167 (0.36)   10      1
--------------------------------------------------------------------------------------------------------------------------------
```

After:

```
------------------------------------------------------ benchmark: 2 tests ------------------------------------------------------
Name (time in ms)          Min               Max               Mean              StdDev          Median            IQR             Outliers  OPS             Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------
test_transformer_forward   53.3807 (1.0)     66.7263 (1.0)     63.6231 (1.0)     3.7131 (1.15)   64.7397 (1.0)     1.0460 (1.0)    1;2       15.7176 (1.0)   10      1
test_transformer_backward  160.4337 (3.01)   171.0229 (2.56)   168.4271 (2.65)   3.2160 (1.0)    169.6143 (2.62)   3.7713 (3.61)   1;1       5.9373 (0.38)   10      1
--------------------------------------------------------------------------------------------------------------------------------
```
This is a fresh dump from the latest https://github.com/Lightning-AI/lightning-thunder/tree/wjy/sharded. The main differences are:
1. The code size of the fusion is cut in half because shapes are inlined.
2. Two more outputs are added and therefore cached for backprop.

For #2199.
All tensors are replicated to all devices to start with. Future PRs will try to shard them. For #2199.
This PR tries to parallelize inputs according to https://arxiv.org/pdf/1909.08053. `propagate_shardings` is able to propagate parallelization to intermediate tensors and outputs. Fixes #2199.
This is to unblock @cowanmeg's and @samnordmann's distributed matmul experiments.
I'll start with the tensor parallelism proposed by the original Megatron-LM paper.