Tensor parallel MLP #2360
Conversation
…ointwise_scheduler
This is not a full reversion of NVIDIA#2419, which also renamed `INPUT_C` and `OUTPUT_D`, and made some modifications to `map{Linear,Matmul}OpIterDomains`. This preserves those changes but allows us to keep distinguishing A and B operands. Fixes NVIDIA#2434
Some followup items not addressed by this PR:
(1) Line 102 in beb2287 (embedded code reference).
(2) Symbolic TVs: this error is encountered with symbolic TVs:
C++ exception with description "ext_opt.hasValue() INTERNAL ASSERT FAILED at csrc/dynamic_transform.cpp:276, Could not evaluate dynamic extent: i3. Exception raised from DynamicTransformConcretizationInfo at csrc/dynamic_transform.cpp:276"
(3) Improved sharding propagation: the TVs where sharding propagation breaks down are (a) broadcasts where a device-dim axis is broadcast, and (b) the rand_like operator, because it creates a fresh new TV. The current pass assumes that one of the inputs is already sharded and propagates from producer to consumer. To support these cases we need to propagate up from the consumer.
Almost LGTM. Nice work!
int64_t h = 128;
int64_t h4 = 4 * h;
// TODO: error with dynamic shape
Can you clarify this? Are you saying the following code would fail if changed to makeContigTensor?
Correct. This is follow-up item (2).
I think it's another instance of #2462. Please revisit when it's fixed.
We are using FusionExecutorCache, so sadly that did not fix the error. I did narrow down what is causing the error and opened #2481.
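For context, a minimal sketch of the contrast behind follow-up item (2), assuming the test's usual shape helpers; the names, dtypes, and sizes below are illustrative, not the PR's literal code:

Fusion fusion;
FusionGuard fg(&fusion);

const int64_t sb = 64, h = 128;  // placeholder sizes

// Concrete extents: the form the test uses today.
TensorView* x_concrete =
    makeContigConcreteTensor({sb, h}, DataType::BFloat16);
fusion.addInput(x_concrete);

// Symbolic extents: switching the inputs to this form currently trips the
// "Could not evaluate dynamic extent" assert during concretization, which
// is what #2481 tracks.
TensorView* x_symbolic = makeContigTensor(2, DataType::BFloat16);
fusion.addInput(x_symbolic);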
TensorView* gelu_ = castOp(DataType::BFloat16, gelu);

// Linear #2
gelu_ = segment_set(gelu_);
Can you comment on why this is needed?
Without the segment set, the reduction scheduler gets called instead of the matmul scheduler. Can add a comment in the code as well.
btw, IIUC we should use the matmul op for both nvFuser matmul and ATen matmul in the future, and that will address the segmentation issue we see here. We will just need to update our resharding passes to handle matmul and linear ops appropriately.
Can add a comment in the code as well.
That'd always be helpful. Thank you!
update our resharding passes to handle matmul and linear ops appropriately
I'm surprised using matmul/linear changes resharding at all because they are done locally. What do you mean?
Ahh, I should clarify: it would be in the insertResharding pass, which automatically adds set operations where necessary. The workaround is very simple; we just add the set manually.
SG. I'm fine leaving the segment_set as is. However, in the next (or next next :) PR, we should improve insertResharding to avoid the segment_set. Neither Thunder nor nvFuser's Python API has segment_set, and it's too low-level for the framework to add properly anyway. How does that sound?
If we use matmul instead of the broadcast+mul+sum, no segment_set is needed, so in the next PR we will have a version with no segment_set :)
TensorView* linear_int2 = mul(linear_int0, linear_int1);
TensorView* linear_int3 = sum(linear_int2, {-1});
Use matmul or linear? cc @Priya2698
The first linear layer can be replaced with a linear op because there is no communication between the bias add and the local matmul. The second linear layer will have to be a matmul, since the pattern is matmul + allreduce + bias add.
I do have tests replacing the second linear layer with matmul, but it was generating very high error. My assumption is that testValidate isn't meant to be used when chaining multiple matmuls together? I also fiddled around with the datatypes of the matmul (i.e. bfloat, float, double), but the error was still high.
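For reference, a rough sketch of the two formulations being compared; the tensor names, shapes, and the availability of a bias-free linear overload are assumptions, not the PR's exact code:

Fusion fusion;
FusionGuard fg(&fusion);

const int64_t M = 64, K = 128, N = 512;  // placeholder sizes
TensorView* x = makeContigConcreteTensor({M, K});  // activations
TensorView* w = makeContigConcreteTensor({N, K});  // linear-style weight
fusion.addInput(x);
fusion.addInput(w);

// Broadcast + mul + sum, the formulation used in the diff above:
TensorView* x_b = broadcast(x, {false, true, false});  // [M, 1, K]
TensorView* w_b = broadcast(w, {true, false, false});  // [1, N, K]
TensorView* prod = mul(x_b, w_b);                      // [M, N, K]
TensorView* out_mul_sum = sum(prod, {-1});             // reduce K -> [M, N]

// Single-op alternative being suggested (matmul would instead want the
// weight laid out as [K, N]):
TensorView* out_linear = linear(x, w);                 // [M, N]

fusion.addOutput(out_mul_sum);
fusion.addOutput(out_linear);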
matmul + allreduce + bias add
Is it crazy to do matmul + add half bias + all_reduce? This may be slightly faster, but we can explore that later.
it was generating very high error
Does the second linear layer, in the PR as is, map to pointwise+reduction or ATen matmul? If the former, we'll need to fix it soon because pointwise+reduction is likely slower. If the latter, I'll be curious why ATen matmul via mul+sum has lower error than ATen matmul via matmul.
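For what it's worth, the half-bias idea is algebraically safe: the allreduce is a sum over the D shards, so adding b/D on each shard before the reduce is the same as adding b once afterwards:

$$\sum_{d=1}^{D}\left(x_d W_d + \tfrac{b}{D}\right) = \left(\sum_{d=1}^{D} x_d W_d\right) + b$$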
matmul + add half bias + all_reduce
That's not a bad idea! When we start benchmarking I will try this.
Currently, the second linear layer uses nvfuser matmul. I have an uncommitted version of the MLP test that uses ATen matmul instead; it generated a max error of ~3, while testValidate had a threshold of ~1e-5. It's just a guess, but I don't think I am setting the datatypes correctly in the ATen baseline.
matmul + allreduce + bias add
Also, mainstream models (e.g. Llama) don't seem to use bias, so this might not be relevant.
Currently, the second linear layer uses nvfuser matmul.
SG. In the next PR, let's try to use ATen matmul. This is how the integration will look and it gives the best performance at the moment. Even if the error is slightly higher than the threshold, I'd accept it because (1) we know the error is localized to matmul, and (2) checking in the code allows us to investigate in parallel.
Co-authored-by: Jingyue Wu <[email protected]>
!build
@wujingyue - I added the MLP test with ATen matmul. Note that the tolerance is bumped up a bit to pass validation. Validation error in output 0 (linear1) on line 583 in file tests/cpp/test_multidevice_matmul.cpp. Validation error in output 2 (linear2) on line 583 in file tests/cpp/test_multidevice_matmul.cpp.
!build |
// Linear #1
TensorView* matmul1;
if (use_aten_matmul) {
  // TODO: use linear op instead
cc @Priya2698
@cowanmeg reminded me of a practical limitation: currently, we split rfactor for DID, so w0 is 3D. linear (just like torch.linear) doesn't take a 3D weight. This limitation will eventually go away when we split leaf instead of rfactor, but it will likely exist for the rest of the year.
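A small sketch of that limitation, with an assumed device count and sizes (illustrative only, not the test's literal code):

Fusion fusion;
FusionGuard fg(&fusion);

const int64_t D = 2, h = 128, h4 = 4 * h;  // placeholder sizes

// With DID expressed as a split of the rfactor domain, the weight carries an
// explicit device dimension and is therefore 3D: [D, h4/D, h].
TensorView* w0 = makeContigConcreteTensor({D, h4 / D, h}, DataType::BFloat16);
fusion.addInput(w0);
w0->axis(0)->parallelize(ParallelType::DIDx);  // device dimension

// linear expects a 2D [out_features, in_features] weight (just like
// torch.nn.functional.linear), so the 3D sharded w0 cannot be passed to it
// until DID can be expressed as a leaf-domain split instead.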
Manually sharded tensor parallel multilayer perceptron (MLP) layer.
The input is a manually translated and sharded MLP layer taken from nanoGPT. See #2199 for where we get the initial compute trace.