Tensor parallel MLP #2360

Merged — 30 commits merged into NVIDIA:main on Jun 26, 2024
Conversation

cowanmeg (Collaborator) commented Jun 6, 2024

A manually sharded tensor-parallel multilayer perceptron (MLP) layer.

The input is a manually translated and sharded MLP layer taken from nanoGPT. See #2199 for where the initial compute trace comes from.
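
For context, a rough sketch of the sharding pattern being tested (not the PR's actual test code): the first linear layer is column-parallel and the second is row-parallel, so only the second needs an allreduce. API names (DeviceMesh, setDeviceMesh, ParallelType::DIDx, makeContigTensor) are those of nvFuser's multidevice tests; D, the shapes, and the variable names are placeholders.

// Sketch only; assumes nvFuser's multidevice C++ test API as of this PR.
Fusion fusion;
FusionGuard fg(&fusion);

const int64_t D = 2;     // number of devices (placeholder)
DeviceMesh mesh({0, 1}); // device ids (placeholder)

// x: [b*s, h]; w0 sharded as [D, 4h/D, h]; w1 sharded as [D, h, 4h/D].
TensorView* x = makeContigTensor(2, DataType::BFloat16);
TensorView* w0 = makeContigTensor(3, DataType::BFloat16);
TensorView* w1 = makeContigTensor(3, DataType::BFloat16);
for (TensorView* tv : {x, w0, w1}) {
  fusion.addInput(tv);
  tv->setDeviceMesh(mesh);
}
// Shard both weights across devices on their leading (device) dimension.
w0->axis(0)->parallelize(ParallelType::DIDx);
w1->axis(0)->parallelize(ParallelType::DIDx);

// Linear #1 (column parallel): each device computes its [b*s, 4h/D] slice
// locally; GeLU is elementwise, so it stays sharded the same way.
// Linear #2 (row parallel): each device produces a partial [b*s, h] result;
// summing over the device dimension lowers to an allreduce.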

@cowanmeg cowanmeg marked this pull request as draft June 6, 2024 21:48
cowanmeg (Collaborator, Author) commented Jun 24, 2024

Some follow-up not addressed by this PR:
(1) linear/matmul node: Replace the first linear layer with an aten linear and the second linear layer with an aten matmul. The linear op needs to be updated to allow 2D local weight tensors (3D+ once device dims are included):

TensorView* linear(TensorView* input, TensorView* weight, TensorView* bias) {

(2) Symbolic TVs: This error is encountered with symbolic TVs:
C++ exception with description "ext_opt.hasValue() INTERNAL ASSERT FAILED at csrc/dynamic_transform.cpp:276, Could not evaluate dynamic extent: i3"
Exception raised from DynamicTransformConcretizationInfo at csrc/dynamic_transform.cpp:276

(3) Improved sharding propagation: TVs where sharding propagation breaks down are (a) broadcasts where a device-dim axis is broadcast, and (b) the rand_like operator, because it creates a fresh new TV. The current pass assumes that one of the inputs is already sharded and propagates from producer to consumer; to support these cases we need to propagate up from the consumer. (A small sketch of both cases follows below.)
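
A minimal sketch of the two failure cases from (3) (illustrative only, not code from this PR; sharded_tv stands for any TensorView already parallelized with DIDx):

// (a) A broadcast involving a device-dim axis: the current pass cannot
//     carry the DIDx sharding forward onto the broadcast output.
TensorView* bcast = broadcast(sharded_tv, {true, false, false});

// (b) rand_like: it creates a fresh new TV, so there is no already-sharded
//     producer input to propagate from; the sharding would have to be
//     inferred from the consumer instead.
TensorView* rng = rand_like(sharded_tv);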

@cowanmeg cowanmeg marked this pull request as ready for review June 25, 2024 15:56
@cowanmeg cowanmeg requested a review from wujingyue June 25, 2024 15:56
@cowanmeg cowanmeg requested a review from samnordmann June 25, 2024 15:56
@wujingyue (Collaborator) left a comment

Almost LGTM. Nice work!

tests/cpp/test_multidevice_matmul.cpp (outdated review thread; resolved)
int64_t h = 128;
int64_t h4 = 4 * h;

// TODO: error with dynamic shape
Collaborator:
Can you clarify this? Are you saying the following code would fail if changed to makeContigTensor?

Collaborator (Author):
Correct. This is follow-up item (2).

Collaborator:
I think it's another instance of #2462. Please revisit when it's fixed.

Collaborator (Author):
We are using FusionExecutorCache, so sadly that did not fix the error. I did narrow down what is causing the error and opened issue #2481.

TensorView* gelu_ = castOp(DataType::BFloat16, gelu);

// Linear #2
gelu_ = segment_set(gelu_);
Collaborator:
Can you comment on why this is needed?

Collaborator (Author):
Without the segment_set, the reduction scheduler gets called instead of the matmul scheduler. I can add a comment in the code as well.
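
For reference, the in-code comment could read roughly like this (a suggestion, not the PR's final wording):

// Force a segment boundary here: without this segment_set, the fusion is
// not split before the second linear layer, and the reduction scheduler is
// picked for its mul+sum instead of the matmul scheduler.
gelu_ = segment_set(gelu_);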

Collaborator (Author):
By the way, IIUC we should use the matmul op for both the nvFuser matmul and the aten matmul in the future, and this will address the segmentation issue we see here. We will just need to update our resharding passes to handle matmul and linear ops appropriately.

Collaborator:
> Can add a comment in the code as well.

That'd always be helpful. Thank you!

> update our resharding passes to handle matmul and linear ops appropriately

I'm surprised using matmul/linear changes resharding at all, because they are done locally. What do you mean?

Collaborator (Author):
Ah, I should clarify: it would be in the insertResharding pass, which automatically adds set operations where necessary. The workaround is very simple; we just add the set manually.

Collaborator:
Sounds good. I'm fine leaving the segment_set as is. However, in the next (or next-next :) ) PR, we should improve insertResharding to avoid the segment_set. Neither Thunder nor nvFuser's Python API has segment_set, and it's too low-level for the framework to add segment_set properly anyway. How does that sound?

Collaborator (Author):
If we use matmul instead of broadcast+mul+sum, no segment_set is needed, so in the next PR we will have a version with no segment_set :)

Comment on lines 440 to 441
TensorView* linear_int2 = mul(linear_int0, linear_int1);
TensorView* linear_int3 = sum(linear_int2, {-1});
Collaborator:
Use matmul or linear? cc @Priya2698

Collaborator (Author):
The first linear layer can be replaced with a linear op because there is no communication between the bias add and the local matmul. The second linear layer will have to be a matmul, since the pattern is matmul + allreduce + bias add.
I do have tests replacing the second linear layer with matmul, but they were generating very high error. My assumption is that testValidate isn't meant to be used when chaining multiple matmuls together. I also fiddled with the data types of the matmul (i.e. bfloat, float, double), but the error was still high.
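
A rough sketch of that mapping (assuming the aten-backed linear/matmul ops and a linear op that accepts the DID-sharded 3D weight per follow-up item (1); names and shapes are illustrative, and b1 is assumed to already be broadcast to match the reduced output):

// Linear #1 (column parallel): the local matmul and the bias add need no
// communication in between, so a single linear op works.
TensorView* out0 = linear(x, w0, b0);       // stays sharded on the 4h/D dim
// ... GeLU on out0, still sharded ...
// Linear #2 (row parallel): each device only holds a partial sum, so the
// bias can only be added after the allreduce.
TensorView* partial = matmul(gelu_out, w1); // [D, b*s, h] partial results
TensorView* reduced = sum(partial, {0});    // reduction over the device dim
                                            // lowers to an allreduce
TensorView* out1 = add(reduced, b1);        // bias is added exactly once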

Collaborator:
> matmul + allreduce + bias add

Is it crazy to do matmul + add half bias + all_reduce? This may be slightly faster, but we can explore that later.

> it was generating very high error

Does the second linear layer in the PR as is map to pointwise+reduction or aten matmul? If the former, we'll need to fix it soon because pointwise+reduction is likely slower. If the latter, I'll be curious why aten matmul via mul+sum has lower error than aten matmul via matmul.
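
(To spell out the arithmetic behind the half-bias idea: with D ranks, allreduce(partial_r + bias/D) = sum_r partial_r + D * (bias/D) = full matmul result + bias, so the bias is applied exactly once and no separate post-allreduce add is needed.)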

Collaborator (Author):
> matmul + add half bias + all_reduce

That's not a bad idea! When we start benchmarking I will try this.

Currently, the second linear layer uses the nvFuser matmul. I have an uncommitted version of the MLP test that uses aten matmul instead; it generated a max error of ~3, while testValidate had a threshold of ~1e-5. It's just a guess, but I don't think I am setting the data types correctly in the aten baseline.

Collaborator:
> matmul + allreduce + bias add

Also, mainstream models (e.g. Llama) don't seem to use bias, so it might not be relevant.

> Currently, the second linear layer uses the nvFuser matmul.

Sounds good. In the next PR, let's try to use the ATen matmul. This is how the integration will look, and it gives the best performance at this moment. Even if the error is slightly higher than the threshold, I'd accept it, because (1) we know the error is localized to matmul, and (2) checking in the code allows us to investigate in parallel.

csrc/multidevice/utils.cpp (outdated review thread; resolved)
cowanmeg (Collaborator, Author):
!build

cowanmeg (Collaborator, Author):
@wujingyue - I added the MLP test with aten matmul. Note that the tolerance is bumped up a bit to pass validation.

Validation error in output 0 (linear1) on line 583 in file /tests/cpp/test_multidevice_matmul.cpp.
Detected abs error of: 0.122498
absolute tolerance was set to 0.005
and relative tolerance set to 5e-05

Validation error in output 2 (linear2) on line 583 in file tests/cpp/test_multidevice_matmul.cpp.
Detected abs error of: 4.08847
absolute tolerance was set to 2
and relative tolerance set to 0.02

cowanmeg (Collaborator, Author):
!build

@cowanmeg cowanmeg merged commit 1ed0e86 into NVIDIA:main Jun 26, 2024
5 checks passed
// Linear #1
TensorView* matmul1;
if (use_aten_matmul) {
  // TODO: use linear op instead
Collaborator:
cc @Priya2698

@cowanmeg reminded me of a practical limitation: currently, we split the rfactor domain for DID, so w0 is 3D. linear (just like torch.linear) doesn't take a 3D weight. This limitation will eventually go away when we split the leaf domain instead of the rfactor domain, but it will likely exist for the rest of the year.
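
Concretely (shapes illustrative, not code from the PR):

// The unsharded first-layer weight is 2D, which linear() accepts:
TensorView* w0_full = makeContigConcreteTensor({4 * h, h});
// After the DID rfactor split, the per-device view is 3D, with the device
// dimension leading and parallelized with DIDx:
TensorView* w0_shard = makeContigConcreteTensor({D, 4 * h / D, h});
w0_shard->axis(0)->parallelize(ParallelType::DIDx);
// linear(x, w0_shard, b0) is rejected, just as torch.linear rejects a 3D
// weight; hence the matmul / mul+sum fallback until leaf-domain splits land.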
