Tensor parallel MLP #2360
Changes from all commits
@@ -32,7 +32,8 @@
namespace nvfuser {

-class DistributedMatmulTest : public MultiDeviceTest {
+class DistributedMatmulTest : public MultiDeviceTest,
+                              public testing::WithParamInterface<bool> {
 protected:
  DistributedMatmulTest() : num_devices_(communicator_->size()) {}

@@ -404,4 +405,191 @@ TEST_F(DistributedMatmulTest, Matmul_LayoutNT_ReduceScatter) {
          ->heuristic();
  EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval);
}

TEST_P(DistributedMatmulTest, MLP_Layer) {
  bool use_aten_matmul = GetParam();
  std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  auto mesh = DeviceMesh::createForNumDevices(communicator_->size());

  int64_t sb = 64; // sequence * batch
  int64_t h = 128;
  int64_t h4 = 4 * h;

  // TODO: error with dynamic shape
  // C++ exception with description "ext_opt.hasValue() INTERNAL ASSERT FAILED
  // at "csrc/dynamic_transform.cpp":276, Could not evaluate dynamic extent: i3
  // Exception raised from DynamicTransformConcretizationInfo at
  // csrc/dynamic_transform.cpp:276
  TensorView* x = makeContigConcreteTensor({sb, h}, DataType::BFloat16);
  TensorView* w0 = makeContigConcreteTensor(
      {num_devices_, h4 / num_devices_, h}, DataType::BFloat16);
  TensorView* b0 = makeContigConcreteTensor(
      {num_devices_, h4 / num_devices_}, DataType::BFloat16);
  TensorView* w1 = makeContigConcreteTensor(
      {num_devices_, h, h4 / num_devices_}, DataType::BFloat16);
  TensorView* b1 = makeContigConcreteTensor({h}, DataType::BFloat16);
  fusion->addInput(x);
  fusion->addInput(w0);
  fusion->addInput(b0);
  fusion->addInput(w1);
  fusion->addInput(b1);

  // Linear #1
  TensorView* matmul1;
  if (use_aten_matmul) {
    // TODO: use linear op instead
Review comment: cc @Priya2698. @cowanmeg reminded me of a practical limitation: currently, we split rfactor for DID, so w0 is 3D.
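A minimal ATen illustration of that limitation (the helper name and code below are mine, not part of this diff): because the DID split turns w0 into a 3D fusion input of shape [num_devices, h4 / num_devices, h], the unsharded 2D weight has to be viewed as 3D before a per-device slice is taken. The test's shardTensor helper presumably does something along these lines, though it may keep the device axis with extent 1 rather than dropping it.

#include <ATen/ATen.h>

// Hypothetical helper: view the full [h4, h] weight as [D, h4/D, h] and keep
// one device's slice. select() drops the device axis; the real shardTensor
// helper may instead keep it with extent 1.
at::Tensor w0ShardForDevice(
    const at::Tensor& w0_full, int64_t num_devices, int64_t device_id) {
  const int64_t h4 = w0_full.size(0);
  const int64_t h = w0_full.size(1);
  return w0_full.view({num_devices, h4 / num_devices, h})
      .select(/*dim=*/0, /*index=*/device_id);
}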
    TensorView* w0_t = transpose(w0, 2, 1);
    matmul1 = matmul(x, w0_t);
  } else {
    TensorView* linear_int0 = broadcast(x, {true, false, true, false});
    TensorView* linear_int1 = broadcast(w0, {false, true, false, false});
    TensorView* linear_int2 = mul(linear_int0, linear_int1);
    matmul1 = sum(linear_int2, {-1});
    // TODO: linear_int0 has a bcast device axis that the sharding propagation
    // pass misses.
    linear_int0->setDeviceMesh(mesh);
    linear_int0->axis(0)->parallelize(ParallelType::DIDx);
  }
  TensorView* b0_bcast = broadcast(b0, {false, true, false});
  TensorView* linear1 = add(matmul1, b0_bcast);

  TensorView* linear1_ = castOp(DataType::Float, linear1);
  TensorView* gelu = tanh_gelu(linear1_);
  TensorView* gelu_ = castOp(DataType::BFloat16, gelu);

  // Linear #2
  TensorView* local_matmul2;
  if (use_aten_matmul) {
    TensorView* w1_t = transpose(w1, 1, 2);
    local_matmul2 = matmul(gelu_, w1_t);
  } else {
    // segment_set required to ensure the matmul scheduler is called
    gelu_ = segment_set(gelu_);
    TensorView* linear2_int0 = broadcast(gelu_, {false, false, true, false});
    TensorView* linear2_int1 = broadcast(w1, {false, true, false, false});
    TensorView* linear2_int2 = mul(linear2_int0, linear2_int1);
    local_matmul2 = sum(linear2_int2, {-1});
  }

  TensorView* matmul2 = sum(local_matmul2, {0}); // Allreduce
  TensorView* bcast_bias = broadcast(b1, {true, false});
  TensorView* linear2 = add(matmul2, bcast_bias);

  // Dropout
  // Note: Propagation breaks at rand_like because it creates a fresh TV.
  // Temporarily this prevents us from using dropout composite node.
  TensorView* linear2_ = castOp(DataType::Float, linear2);
  constexpr double kProb = 0.1;
  constexpr double kScale = 1.0 / (1.0 - kProb);
  Val* philox_seed = fusion->zeroVal();
  Val* philox_offset = fusion->zeroVal();
  TensorView* rand_vals = rand_like(linear2_, philox_seed, philox_offset);
  TensorView* mask = lt(rand_vals, IrBuilder::create<Val>(1.0 - kProb));
  TensorView* apply_mask = mul(linear2_, mask);
  TensorView* dropout = mul(apply_mask, IrBuilder::create<Val>(kScale));

  fusion->addOutput(linear1);
  fusion->addOutput(gelu);
  fusion->addOutput(linear2);
  fusion->addOutput(dropout);

  // Manually shard inputs: x, w0, b0, w1, b1
  // outputs: linear1, gelu, linear2, dropout
  // TVs where sharding changes: matmul2
  // (TODO) TVs where sharding propagation breaks down:
  //   linear_int0 => broadcasts where a device dim axis is broadcasted.
  //   rand_vals => rand_like creates a fresh new TV.

  // TVs replicated on each device.
  auto tv_inputs = {x, b1, matmul2, linear2, rand_vals, dropout};
  for (auto tv : tv_inputs) {
    tv->setDeviceMesh(mesh);
  }

  // TVs sharded on the outermost dimension.
  auto tvs = {w0, b0, w1, linear1, gelu, gelu_};
  for (auto tv : tvs) {
    tv->setDeviceMesh(mesh);
    tv->axis(0)->parallelize(ParallelType::DIDx);
  }

  const auto options = at::TensorOptions()
                           .dtype(c10::ScalarType::BFloat16)
                           .device(at::kCUDA, communicator_->local_rank());
  auto x_ = at::randn({sb, h}, options);
  auto w0_ = at::randn({h4, h}, options);
  auto b0_ = at::randn({h4}, options);
  auto w1_ = at::randn({h, h4}, options);
  auto b1_ = at::randn({h}, options);

  std::vector<c10::IValue> inputs = {
      x_,
      shardTensor(
          w0_.view({num_devices_, h4 / num_devices_, h}),
          w0,
          communicator_->deviceId()),
      shardTensor(
          b0_.view({num_devices_, h4 / num_devices_}),
          b0,
          communicator_->deviceId()),
      shardTensor(
          w1_.view({h, num_devices_, h4 / num_devices_}).transpose(1, 0),
          w1,
          communicator_->deviceId()),
      b1_};
  at::manual_seed(0);
  auto linear1_aten =
      at::linear(x_.to(at::kDouble), w0_.to(at::kDouble), b0_.to(at::kDouble));
  auto gelu_aten = at::gelu(linear1_aten.to(at::kFloat), "tanh");
  auto linear2_aten = at::linear(
      gelu_aten.to(at::kBFloat16).to(at::kDouble),
      w1_.to(at::kDouble),
      b1_.to(at::kDouble));
  auto dropout_aten = at::dropout(linear2_aten.to(at::kFloat), kProb, true);
  std::vector<at::Tensor> expected_outputs = {
      shardTensor(
          at::transpose(
              linear1_aten.view({sb, num_devices_, h4 / num_devices_}), 1, 0),
          linear1,
          communicator_->deviceId()),
      shardTensor(
          at::transpose(
              gelu_aten.view({sb, num_devices_, h4 / num_devices_}), 1, 0),
          gelu,
          communicator_->deviceId()),
      linear2_aten,
      dropout_aten};

  at::manual_seed(0);
  MultiDeviceExecutor runtime(
      std::move(fusion), *communicator_, executor_params_);
  auto outputs = runtime.runWithInput(inputs);

  // Bump up the tolerance - the second matmul carries
  // the numerical error from the prior matmul.
  auto tolerance_overwrite = ValidationConstants();
  std::array<std::array<double, 2>, 20> relaxed_sum_tol;
  for (auto& arr : relaxed_sum_tol) {
    arr = {128, 3.0};
  }
  tolerance_overwrite.sum_tolerances_float = relaxed_sum_tol;

  testValidate(
      runtime.completeFusion(),
      outputs,
      inputs,
      expected_outputs,
      __LINE__,
      __FILE__,
      "",
      LaunchParams(),
      tolerance_overwrite);
}

INSTANTIATE_TEST_SUITE_P(
    ,
    DistributedMatmulTest,
    testing::Bool(),
    testing::PrintToStringParamName());
} // namespace nvfuser
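For reference, here is a minimal single-process ATen sketch of the tensor-parallel MLP math the test validates (illustrative only; the function and variable names below are mine, not part of the PR): linear #1 is column-parallel, so each device owns h4 / num_devices output rows of w0 and the matching slice of b0; linear #2 is row-parallel, so each device produces a partial [sb, h] result that is summed across devices, which is the allreduce modeled by sum(local_matmul2, {0}); the bias b1 is added once after that sum, and the dropout in the test then acts on this replicated result.

#include <ATen/ATen.h>

// Simulates the per-device math on a single process; the loop over d stands in
// for the devices, and the running sum stands in for the allreduce.
at::Tensor referenceShardedMlp(
    const at::Tensor& x,   // [sb, h], replicated on every device
    const at::Tensor& w0,  // [h4, h], full (unsharded) weight of linear #1
    const at::Tensor& b0,  // [h4]
    const at::Tensor& w1,  // [h, h4], full weight of linear #2
    const at::Tensor& b1,  // [h]
    int64_t num_devices) {
  const int64_t h4 = w0.size(0);
  // Column-parallel: device d owns rows [d*h4/D, (d+1)*h4/D) of w0 and b0.
  auto w0_shards = w0.view({num_devices, h4 / num_devices, w0.size(1)});
  auto b0_shards = b0.view({num_devices, h4 / num_devices});
  // Row-parallel: device d owns the matching h4/D columns of w1.
  auto w1_shards =
      w1.view({w1.size(0), num_devices, h4 / num_devices}).transpose(1, 0);
  at::Tensor allreduce_sum;
  for (int64_t d = 0; d < num_devices; ++d) {
    auto local1 = at::linear(x, w0_shards[d], b0_shards[d]); // [sb, h4/D]
    auto local_gelu = at::gelu(local1, "tanh");
    auto local2 =
        at::matmul(local_gelu, w1_shards[d].transpose(0, 1)); // [sb, h]
    allreduce_sum = d == 0 ? local2 : allreduce_sum + local2;
  }
  return allreduce_sum + b1; // bias of linear #2 is added once, after the sum
}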
Review comment: Can you clarify this? Are you saying the following code would fail if changed to makeContigTensor?

Reply: Correct. This is follow-up item (2).
Review comment: I think it's another instance of #2462. Please revisit when it's fixed.

Reply: We are using FusionExecutorCache, so sadly that did not fix the error. I did narrow down what is causing the error and opened issue #2481.
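To make the makeContigTensor discussion above concrete, here is a hedged sketch reusing the names from the test, and assuming the usual makeContigTensor(ndims, dtype) test helper that declares symbolic extents (in contrast to the concrete extents pinned by makeContigConcreteTensor); per the TODO in the test, the symbolic-shape variant is the one that currently trips the dynamic-extent assert in csrc/dynamic_transform.cpp:

// Concrete extents, as used in this PR: sizes are baked into the fusion inputs.
TensorView* x = makeContigConcreteTensor({sb, h}, DataType::BFloat16);

// Symbolic extents, the variant being discussed; this is what currently fails
// with "Could not evaluate dynamic extent" during concretization.
TensorView* x_dynamic = makeContigTensor(2, DataType::BFloat16);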