Tensor parallel MLP #2360

Merged Jun 26, 2024
Commits (30):
3c3db97 initial MLP commit (cowanmeg, Jun 6, 2024)
6d375ff dropout working (cowanmeg, Jun 6, 2024)
3c37a86 initial commit, working for outermost DID axis (cowanmeg, Jun 7, 2024)
c02ddd4 test and support non-outermost DID (cowanmeg, Jun 7, 2024)
4349452 lint (cowanmeg, Jun 7, 2024)
59f9d88 feedback (cowanmeg, Jun 11, 2024)
dcb0d26 Merge branch 'main' of https://github.com/nvidia/Fuser into sharded_p… (cowanmeg, Jun 11, 2024)
e83fd68 revamp test and fix error (cowanmeg, Jun 13, 2024)
92fdee1 comment fix (cowanmeg, Jun 14, 2024)
a765aa4 propagate shardings in mlp (cowanmeg, Jun 14, 2024)
b488f04 Merge branch 'main' of https://github.com/nvidia/Fuser into mlp (cowanmeg, Jun 14, 2024)
d7cb419 merge. new double buffer error (cowanmeg, Jun 15, 2024)
96935ad temp (cowanmeg, Jun 17, 2024)
820a5c1 Merge branch 'main' of https://github.com/nvidia/Fuser into mlp (cowanmeg, Jun 17, 2024)
a5fb4b7 hack fix matmul tensor ordering (cowanmeg, Jun 18, 2024)
ff7afb5 clean up (cowanmeg, Jun 18, 2024)
3567125 undo (cowanmeg, Jun 18, 2024)
199651f fix expr evaluator error (cowanmeg, Jun 18, 2024)
f541255 tidy (cowanmeg, Jun 18, 2024)
05f239a Split MatmulRole::OPERAND into OPERAND_{A,B} (jacobhinkle, Jun 24, 2024)
9ca55a3 merge matmul changes and cleanup test (cowanmeg, Jun 24, 2024)
1989691 lint (cowanmeg, Jun 24, 2024)
8f90fb0 remove aten for error (cowanmeg, Jun 25, 2024)
c405ee6 Merge branch 'main' of https://github.com/nvidia/Fuser into mlp (cowanmeg, Jun 25, 2024)
81b9138 Communicator rename (cowanmeg, Jun 25, 2024)
2b1a236 Merge branch 'main' of https://github.com/nvidia/Fuser into mlp (cowanmeg, Jun 25, 2024)
a1d6232 Update tests/cpp/test_multidevice_matmul.cpp (cowanmeg, Jun 25, 2024)
e8b7f87 feedback (cowanmeg, Jun 25, 2024)
e8ac854 Add aten matmul mlp (cowanmeg, Jun 25, 2024)
b2f2da7 Update executor.h comment (cowanmeg, Jun 26, 2024)
7 changes: 6 additions & 1 deletion csrc/expr_evaluator.cpp
@@ -155,9 +155,14 @@ void ExpressionEvaluator::bind_(
           id->toString(),
           "is sharded and must have size 1, but input tensor has size ",
           t.size(i));
+      NVF_CHECK(
+          tv->getDeviceMesh().size() > 0,
+          "TV ",
+          tv->toString(),
+          " has an empty DeviceMesh with DID parallelization")
       bind_(
           logical_domain[i]->extent(),
-          (int)tv->getDeviceMesh().vector().size(),
+          (int)tv->getDeviceMesh().size(),
           evaluate_validate);
     } else {
       bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate);
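In practice, the new binding means the following (a minimal sketch, assuming the test-side API used in tests/cpp/test_multidevice_matmul.cpp below; the device count and shapes are illustrative):

// Sketch only; scaffolding (headers, communicator, fixture) as in the tests below.
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

const int64_t d = 2; // assumed device count
auto mesh = DeviceMesh::createForNumDevices(d);

// Logical shape {d, 64}; axis 0 is device-parallel.
TensorView* w = makeContigConcreteTensor({d, 64}, DataType::BFloat16);
w->setDeviceMesh(mesh); // an empty mesh would now trip the new NVF_CHECK
w->axis(0)->parallelize(ParallelType::DIDx);
fusion->addInput(w);

// When inputs are bound, each rank passes a {1, 64} slice; the extent of the
// DIDx axis is bound to mesh.size() (== d), not to the local size of 1.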
5 changes: 2 additions & 3 deletions csrc/multidevice/executor.h
@@ -30,11 +30,10 @@ namespace nvfuser {
      parallel type ParallelType::DIDx
 
   We make the following assumptions on the Fusion:
-  - Only the outmost (non-reduction) axis is allowed to be parallelized
+  - Only one (non-reduction) axis is allowed to be parallelized
     with ParallelType::DIDx. Moreover, this axis cannot be split/merged.
   - We only support 1D device meshes for now
-  - We only support TensorView, not Scalars
-  - We only support static shapes
+  - We only support TensorViews in communication segments.
 
   Summary of the different steps performed by the MultiDeviceExecutor:
   I. At instantiation:
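The relaxed first assumption is easiest to see with a small sketch (hedged: the device count d and the shapes are assumptions; the API is the one used in the tests in this PR):

// Sketch: the DIDx axis no longer has to be outermost.
auto mesh = DeviceMesh::createForNumDevices(d);
TensorView* tv = makeContigConcreteTensor({8, d, 16});
tv->setDeviceMesh(mesh);
tv->axis(1)->parallelize(ParallelType::DIDx); // non-outermost device axis
// Still required: exactly one DIDx axis, not split/merged, on a 1D device mesh.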
16 changes: 11 additions & 5 deletions csrc/multidevice/utils.cpp
@@ -328,6 +328,9 @@ void propagateShardings(Fusion* fusion) {
for (auto expr : fusion->exprs()) {
auto inputs = ir_utils::filterByType<TensorView>(expr->inputs());
auto outputs = ir_utils::filterByType<TensorView>(expr->outputs());
if (inputs.empty()) {
continue;
}
TensorView* input_with_mesh = nullptr;
for (auto tv : inputs) {
NVF_CHECK(
@@ -522,11 +525,14 @@ void unshard(Fusion* fusion) {
 
 std::set<DeviceIdxType> involvedDevices(Expr* expr) {
   std::set<DeviceIdxType> ret;
-  for (const auto& tvs : {expr->inputs(), expr->outputs()}) {
-    for (auto val : tvs) {
-      NVF_ERROR(val->isA<TensorView>(), "Val is not a TensorView");
-      auto tv = val->as<TensorView>();
-      NVF_ERROR(tv->hasDeviceMesh(), "the TensorView has no device mesh");
+  for (const auto& tvs :
+       {ir_utils::filterByType<TensorView>(expr->inputs()),
+        ir_utils::filterByType<TensorView>(expr->outputs())}) {
+    for (auto* tv : tvs) {
+      NVF_ERROR(
+          tv->hasDeviceMesh(),
+          "the TensorView has no device mesh: ",
+          tv->toString());
       auto& mesh = tv->getDeviceMesh().vector();
       std::copy(mesh.begin(), mesh.end(), std::inserter(ret, ret.end()));
     }
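The switch to ir_utils::filterByType matters for expressions that also take non-TensorView inputs; for example, the rand_like call in the MLP test below receives scalar philox seed/offset Vals. A hedged sketch of the intended behavior (rand_vals refers to the TensorView defined in that test, and Val::definition() is assumed to return the defining Expr):

// Scalar inputs are now skipped rather than hitting the old
// "Val is not a TensorView" assert; only TensorViews contribute their mesh
// devices, and a TensorView without a mesh still fails, now printing which one.
std::set<DeviceIdxType> devices = involvedDevices(rand_vals->definition());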
12 changes: 12 additions & 0 deletions csrc/tensor_view.cpp
@@ -1068,6 +1068,10 @@ TensorView* TensorView::cacheBefore(LoadStoreOpType op_type) {

consumer->setDomain(replayed_consumer_pair.first);

if (consumer->hasDeviceMesh()) {
producer->setDeviceMesh(consumer->getDeviceMesh());
}

return producer;
}

@@ -1108,6 +1112,10 @@ TensorView* TensorView::cacheFork() {
IrBuilder::createInContainer<LoadStoreOp>(
container(), LoadStoreOpType::Set, new_output, this);

if (this->hasDeviceMesh()) {
new_output->setDeviceMesh(this->getDeviceMesh());
}

// The new TV becomes an output.
// New TV has global memory type.
// This TV has local memory type.
@@ -1188,6 +1196,10 @@ TensorView* TensorView::cacheAfter(
// Set domain of producer - No Change
TensorView* producer = this;

if (producer->hasDeviceMesh()) {
consumer->setDeviceMesh(producer->getDeviceMesh());
}

// Insert consumer - Cache_After (CA) - after this TV.
// Before: This TV -> [Use Op] -> Next TV
// After: This TV -> [Set Op] -> New CA TV -> [Use Op] -> Next TV
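Taken together, the three cache changes mean the cache TensorView now inherits the device mesh of the TV it caches. A short sketch (tv and mesh are assumed to be set up as in the tests below; cacheAfter is called with its defaults):

// Before this PR, the cache TV had no device mesh; now it carries tv's mesh,
// so later sharding analysis does not see an un-meshed TensorView.
tv->setDeviceMesh(mesh);
TensorView* cached = tv->cacheAfter();
// cached->hasDeviceMesh() now returns true.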
190 changes: 189 additions & 1 deletion tests/cpp/test_multidevice_matmul.cpp
@@ -32,7 +32,8 @@

namespace nvfuser {

-class DistributedMatmulTest : public MultiDeviceTest {
+class DistributedMatmulTest : public MultiDeviceTest,
+                              public testing::WithParamInterface<bool> {
protected:
DistributedMatmulTest() : num_devices_(communicator_->size()) {}

@@ -404,4 +405,191 @@ TEST_F(DistributedMatmulTest, Matmul_LayoutNT_ReduceScatter) {
->heuristic();
EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval);
}

TEST_P(DistributedMatmulTest, MLP_Layer) {
bool use_aten_matmul = GetParam();
std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());
auto mesh = DeviceMesh::createForNumDevices(communicator_->size());

int64_t sb = 64; // sequence * batch
int64_t h = 128;
int64_t h4 = 4 * h;

// TODO: error with dynamic shape
[Review comments on this TODO]
Collaborator: Can you clarify this? Are you saying the following code would fail if changed to makeContigTensor?
Collaborator (author): Correct. This is follow-up item (2).
Collaborator: I think it's another instance of #2462. Please revisit when it's fixed.
Collaborator (author): We are using FusionExecutorCache, so sadly that did not fix the error. I narrowed down what is causing the error and opened issue #2481.

// C++ exception with description "ext_opt.hasValue() INTERNAL ASSERT FAILED
// at "csrc/dynamic_transform.cpp":276, Could not evaluate dynamic extent: i3
// Exception raised from DynamicTransformConcretizationInfo at
// csrc/dynamic_transform.cpp:276
TensorView* x = makeContigConcreteTensor({sb, h}, DataType::BFloat16);
TensorView* w0 = makeContigConcreteTensor(
{num_devices_, h4 / num_devices_, h}, DataType::BFloat16);
TensorView* b0 = makeContigConcreteTensor(
{num_devices_, h4 / num_devices_}, DataType::BFloat16);
TensorView* w1 = makeContigConcreteTensor(
{num_devices_, h, h4 / num_devices_}, DataType::BFloat16);
TensorView* b1 = makeContigConcreteTensor({h}, DataType::BFloat16);
fusion->addInput(x);
fusion->addInput(w0);
fusion->addInput(b0);
fusion->addInput(w1);
fusion->addInput(b1);

// Linear #1
TensorView* matmul1;
if (use_aten_matmul) {
// TODO: use linear op instead
[Review comment on this TODO]
Collaborator: cc @Priya2698
@cowanmeg reminded me of a practical limitation: currently we split the rfactor domain for DID, so w0 is 3D, and linear (just like torch.linear) doesn't take a 3D weight. This limitation will eventually go away once we split the leaf domain instead of the rfactor domain, but it will likely remain for the rest of the year.

TensorView* w0_t = transpose(w0, 2, 1);
matmul1 = matmul(x, w0_t);
} else {
TensorView* linear_int0 = broadcast(x, {true, false, true, false});
TensorView* linear_int1 = broadcast(w0, {false, true, false, false});
TensorView* linear_int2 = mul(linear_int0, linear_int1);
matmul1 = sum(linear_int2, {-1});
// TODO: linear_int0 has a bcast device axis that the sharding propagation
// pass misses.
linear_int0->setDeviceMesh(mesh);
linear_int0->axis(0)->parallelize(ParallelType::DIDx);
}
TensorView* b0_bcast = broadcast(b0, {false, true, false});
TensorView* linear1 = add(matmul1, b0_bcast);

TensorView* linear1_ = castOp(DataType::Float, linear1);
TensorView* gelu = tanh_gelu(linear1_);
TensorView* gelu_ = castOp(DataType::BFloat16, gelu);

// Linear #2
TensorView* local_matmul2;
if (use_aten_matmul) {
TensorView* w1_t = transpose(w1, 1, 2);
local_matmul2 = matmul(gelu_, w1_t);
} else {
// segment_set required to ensure the matmul scheduler is called
gelu_ = segment_set(gelu_);
TensorView* linear2_int0 = broadcast(gelu_, {false, false, true, false});
TensorView* linear2_int1 = broadcast(w1, {false, true, false, false});
TensorView* linear2_int2 = mul(linear2_int0, linear2_int1);
local_matmul2 = sum(linear2_int2, {-1});
}

TensorView* matmul2 = sum(local_matmul2, {0}); // Allreduce
TensorView* bcast_bias = broadcast(b1, {true, false});
TensorView* linear2 = add(matmul2, bcast_bias);

// Dropout
// Note: Propagation breaks at rand_like because it creates a fresh TV.
// Temporarily this prevents us from using dropout composite node.
TensorView* linear2_ = castOp(DataType::Float, linear2);
constexpr double kProb = 0.1;
constexpr double kScale = 1.0 / (1.0 - kProb);
Val* philox_seed = fusion->zeroVal();
Val* philox_offset = fusion->zeroVal();
TensorView* rand_vals = rand_like(linear2_, philox_seed, philox_offset);
TensorView* mask = lt(rand_vals, IrBuilder::create<Val>(1.0 - kProb));
TensorView* apply_mask = mul(linear2_, mask);
TensorView* dropout = mul(apply_mask, IrBuilder::create<Val>(kScale));

fusion->addOutput(linear1);
fusion->addOutput(gelu);
fusion->addOutput(linear2);
fusion->addOutput(dropout);

// Manually shard inputs: x, w0, b0, w1, b1
// outputs: linear1, gelu, linear2, dropout
// TVs where sharding changes: matmul2
// (TODO) TVs where sharding propagation breaks down:
// linear_int0 = broadcasts where a device dim axis is broadcasted.
// rand_vals => rand_like creates a fresh new TV.

// TVs replicated on each device.
auto tv_inputs = {x, b1, matmul2, linear2, rand_vals, dropout};
for (auto tv : tv_inputs) {
tv->setDeviceMesh(mesh);
}

// TVs sharded on the outermost dimension.
auto tvs = {w0, b0, w1, linear1, gelu, gelu_};
for (auto tv : tvs) {
tv->setDeviceMesh(mesh);
tv->axis(0)->parallelize(ParallelType::DIDx);
}

const auto options = at::TensorOptions()
.dtype(c10::ScalarType::BFloat16)
.device(at::kCUDA, communicator_->local_rank());
auto x_ = at::randn({sb, h}, options);
auto w0_ = at::randn({h4, h}, options);
auto b0_ = at::randn({h4}, options);
auto w1_ = at::randn({h, h4}, options);
auto b1_ = at::randn({h}, options);

std::vector<c10::IValue> inputs = {
x_,
shardTensor(
w0_.view({num_devices_, h4 / num_devices_, h}),
w0,
communicator_->deviceId()),
shardTensor(
b0_.view({num_devices_, h4 / num_devices_}),
b0,
communicator_->deviceId()),
shardTensor(
w1_.view({h, num_devices_, h4 / num_devices_}).transpose(1, 0),
w1,
communicator_->deviceId()),
b1_};
at::manual_seed(0);
auto linear1_aten =
at::linear(x_.to(at::kDouble), w0_.to(at::kDouble), b0_.to(at::kDouble));
auto gelu_aten = at::gelu(linear1_aten.to(at::kFloat), "tanh");
auto linear2_aten = at::linear(
gelu_aten.to(at::kBFloat16).to(at::kDouble),
w1_.to(at::kDouble),
b1_.to(at::kDouble));
auto dropout_aten = at::dropout(linear2_aten.to(at::kFloat), kProb, true);
std::vector<at::Tensor> expected_outputs = {
shardTensor(
at::transpose(
linear1_aten.view({sb, num_devices_, h4 / num_devices_}), 1, 0),
linear1,
communicator_->deviceId()),
shardTensor(
at::transpose(
gelu_aten.view({sb, num_devices_, h4 / num_devices_}), 1, 0),
gelu,
communicator_->deviceId()),
linear2_aten,
dropout_aten};

at::manual_seed(0);
MultiDeviceExecutor runtime(
std::move(fusion), *communicator_, executor_params_);
auto outputs = runtime.runWithInput(inputs);

// Bump up the tolerance - the second matmul carries
// the numerical error from the prior matmul
auto tolerance_overwrite = ValidationConstants();
std::array<std::array<double, 2>, 20> relaxed_sum_tol;
for (auto& arr : relaxed_sum_tol) {
arr = {128, 3.0};
}
tolerance_overwrite.sum_tolerances_float = relaxed_sum_tol;

testValidate(
runtime.completeFusion(),
outputs,
inputs,
expected_outputs,
__LINE__,
__FILE__,
"",
LaunchParams(),
tolerance_overwrite);
}

INSTANTIATE_TEST_SUITE_P(
,
DistributedMatmulTest,
testing::Bool(),
testing::PrintToStringParamName());
} // namespace nvfuser