From eb11636e3287e58f8f15fbd80428ef29e36c4efc Mon Sep 17 00:00:00 2001
From: mcowan
Date: Thu, 12 Dec 2024 21:15:06 -0800
Subject: [PATCH] add sp test

---
 tests/cpp/multidevice_transformer.cpp      |  62 ++++--
 tests/cpp/multidevice_transformer.h        |   5 +-
 tests/cpp/test_multidevice_transformer.cpp | 242 ++++++++-------------
 3 files changed, 134 insertions(+), 175 deletions(-)

diff --git a/tests/cpp/multidevice_transformer.cpp b/tests/cpp/multidevice_transformer.cpp
index dc55b0cffc0..56982a6a693 100644
--- a/tests/cpp/multidevice_transformer.cpp
+++ b/tests/cpp/multidevice_transformer.cpp
@@ -398,12 +398,15 @@ std::vector<TensorView*> DistributedTransformer::mha_backwards(
 }
 
 std::unique_ptr<FusionExecutorCache> DistributedTransformer::forward(
-    DataType dtype) {
+    DataType dtype,
+    bool sequence_parallel) {
   auto fusion = std::make_unique<Fusion>();
   FusionGuard fg(fusion.get());
   const auto mesh = DeviceMesh::createForNumDevices(D);
 
-  TensorView* x = makeContigConcreteTensor({B * S, E}, dtype);
+  TensorView* x = sequence_parallel
+      ? makeContigConcreteTensor({D, B * S / D, E}, dtype)
+      : makeContigConcreteTensor({B * S, E}, dtype);
   TensorView* ln0_w = makeContigTensor(1);
   TensorView* ln0_b = makeContigTensor(1);
   TensorView* mha_w0 = makeContigConcreteTensor({D, 3 * E / D, E}, dtype);
@@ -412,10 +415,10 @@ std::unique_ptr<FusionExecutorCache> DistributedTransformer::forward(
   TensorView* mha_b1 = makeContigConcreteTensor({E}, dtype);
   TensorView* ln1_w = makeContigTensor(1);
   TensorView* ln1_b = makeContigTensor(1);
-  TensorView* mlp_w0 = makeContigTensor(3, dtype);
-  TensorView* mlp_b0 = makeContigTensor(2, dtype);
-  TensorView* mlp_w1 = makeContigTensor(3, dtype);
-  TensorView* mlp_b1 = makeContigTensor(1, dtype);
+  TensorView* mlp_w0 = makeContigConcreteTensor({D, 4 * E / D, E}, dtype);
+  TensorView* mlp_b0 = makeContigConcreteTensor({D, 4 * E / D}, dtype);
+  TensorView* mlp_w1 = makeContigConcreteTensor({D, E, 4 * E / D}, dtype);
+  TensorView* mlp_b1 = makeContigConcreteTensor({E}, dtype);
 
   fusion->addInput(x);
   fusion->addInput(ln0_w);
@@ -438,27 +441,48 @@ std::unique_ptr<FusionExecutorCache> DistributedTransformer::forward(
   auto ln_input = castOp(DataType::Float, x);
   auto ln0 = layer_norm(ln_input, norm_shape, ln0_w, ln0_b, eps);
   auto mha_in = castOp(dtype, ln0.output);
-  auto mha_out = mha(mha_in, mha_w0, mha_b0, mha_w1, mha_b1, mesh).output;
-  auto resid0 = add(ln_input, mha_out);
+  auto mha_tvs =
+      mha(mha_in, mha_w0, mha_b0, mha_w1, mha_b1, mesh, sequence_parallel);
+  auto resid0 = add(ln_input, mha_tvs.output);
   auto ln1 = layer_norm(resid0, norm_shape, ln1_w, ln1_b, eps);
   auto mlp_in = castOp(dtype, ln1.output);
-  auto mlp_out = mlp(mlp_in, mlp_w0, mlp_b0, mlp_w1, mlp_b1, mesh).output;
-  auto resid1 = add(resid0, mlp_out);
+  auto mlp_tvs =
+      mlp(mlp_in, mlp_w0, mlp_b0, mlp_w1, mlp_b1, mesh, sequence_parallel);
+  auto resid1 = add(resid0, mlp_tvs.output);
   resid1 = castOp(dtype, resid1);
 
   fusion->addOutput(ln0.output);
-  fusion->addOutput(mha_out);
+  fusion->addOutput(mha_tvs.output);
   fusion->addOutput(ln1.output);
-  fusion->addOutput(mlp_out);
+  fusion->addOutput(mlp_tvs.output);
   fusion->addOutput(resid1);
-
-  for (auto tv : {x, ln0.output, ln1.output, resid1}) {
-    tv->setDeviceMesh(mesh);
-  }
-  shardBetween({mha_in->definition()}, {mha_out->definition()}, mha_w0);
-  shardBetween({mlp_in->definition()}, {mlp_out->definition()}, mlp_w0);
-  shardBetween({x}, {mha_in}, x);
+  x->setDeviceMesh(mesh);
+  if (sequence_parallel) {
+    // Input arrives sharded.
+    x->axis(0)->parallelize(ParallelType::DIDx);
+    // Propagate SP shardings from x through layernorms, dropouts, and residual
+    // adds. Even though mha_in is part of the boundary set, residuals allow
+    // the shardings to propagate up the graph, so we must cut off the
+    // propagation at the outputs of the reduce-scatters (mha and mlp matmul1).
+    shardBetween({x}, {mha_in, mlp_in, mha_tvs.matmul1, mlp_tvs.matmul1}, x);
+    // Propagate TP shardings for MLP and MHA from the sharded weights. We do
+    // not need to shard from mha_b0 or mlp_b0 because they are only consumed
+    // by their respective linear0 expression, which is sharded from *_w0.
+    shardBetween({mha_w0}, {mha_tvs.matmul1}, mha_w0);
+    shardBetween({mha_w1}, {mha_tvs.matmul1}, mha_w1);
+    shardBetween({mlp_w0}, {mlp_tvs.matmul1}, mlp_w0);
+    shardBetween({mlp_w1}, {mlp_tvs.matmul1}, mlp_w1);
+  } else {
+    // TP-only shardings.
+    // Layernorms and residuals are all replicated like x. shardBetween
+    // shards all tvs reachable from x, so the input and output tvs must
+    // be in the boundary set.
+    shardBetween({x}, {mha_in, mha_tvs.output, mlp_in, mlp_tvs.output}, x);
+    // TP-sharded regions within mha and mlp.
+    shardBetween({mha_in}, {mha_tvs.output}, mha_w0);
+    shardBetween({mlp_in}, {mlp_tvs.output}, mlp_w0);
+  }
 
   return std::make_unique<FusionExecutorCache>(std::move(fusion));
 }
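The shardBetween() calls above do all of the sharding propagation, so the choice of boundary sets is what this change is really about. A minimal sketch of the same pattern, using only helpers that already appear in this patch (D, B, S, E as defined by the transformer under test; not part of the change itself):

    // A DIDx-parallelized axis on the input propagates to every tensor
    // reachable from it; propagation is cut off at tensors in the boundary set.
    auto fusion = std::make_unique<Fusion>();
    FusionGuard fg(fusion.get());
    const auto mesh = DeviceMesh::createForNumDevices(D);

    // Sequence-parallel layout: the B * S rows are split into D shards and
    // the leading axis is bound to the device dimension.
    TensorView* x = makeContigConcreteTensor({D, B * S / D, E}, DataType::Half);
    fusion->addInput(x);
    TensorView* y = castOp(DataType::Float, x);
    TensorView* z = add(y, y);
    fusion->addOutput(z);

    x->setDeviceMesh(mesh);
    x->axis(0)->parallelize(ParallelType::DIDx);
    // Copy x's device mesh and DIDx sharding onto everything on the paths
    // from x to the boundary set {z}.
    shardBetween({x}, {z}, x);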
diff --git a/tests/cpp/multidevice_transformer.h b/tests/cpp/multidevice_transformer.h
index c0cad26cb99..ed29dd09e36 100644
--- a/tests/cpp/multidevice_transformer.h
+++ b/tests/cpp/multidevice_transformer.h
@@ -27,7 +27,6 @@ struct MhaResult {
   TensorView* output;
 };
 
-
 class DistributedTransformer {
  public:
   DistributedTransformer(
@@ -46,7 +45,9 @@ class DistributedTransformer {
         kDropoutProb(dropout_prob),
         kSdpaProb(sdpa_dropout_prob) {}
 
-  std::unique_ptr<FusionExecutorCache> forward(DataType dtype);
+  std::unique_ptr<FusionExecutorCache> forward(
+      DataType dtype,
+      bool sequence_parallel = false);
   std::unique_ptr<FusionExecutorCache> backward(DataType dtype);
 
   MlpResult mlp(
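Because the new flag is defaulted to false, every existing tensor-parallel caller stays source-compatible; sequence parallelism is opt-in. A sketch of the two call sites of the updated interface:

    DistributedTransformer model(D, B, E, H, S, kDropoutProb, kSdpaProb);
    // Tensor parallelism only, exactly as before this patch.
    std::unique_ptr<FusionExecutorCache> tp_fec = model.forward(DataType::Half);
    // Tensor parallelism plus sequence parallelism.
    std::unique_ptr<FusionExecutorCache> sp_fec =
        model.forward(DataType::Half, /*sequence_parallel=*/true);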
diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp
index da67ff3fbd6..6ccb217137f 100644
--- a/tests/cpp/test_multidevice_transformer.cpp
+++ b/tests/cpp/test_multidevice_transformer.cpp
@@ -18,28 +18,27 @@
 namespace nvfuser {
-
-
 namespace {
-  // Note: We test on smaller model and input sizes to avoid high error accumulation
-  // for validation.
-  static constexpr int64_t B = 2, E = 768, H = 16, S = 128;
-  // Note: Dropout probabilities are set to 0. Since the dropout mask is sharded it
-  // throws off the seed offset between the sharded nvFuser program and the
-  // unsharded reference.
-  static constexpr double kDropoutProb = 0.0, kSdpaProb = 0.0, kSdpaScale = 1e-3;
-  // Note parameters scaled by kParamScale following weight initialization
-  // recommendations:
-  // https://huggingface.co/docs/transformers/en/model_doc/gpt2#transformers.GPT2Config.initializer_range
-  static constexpr double kParamScale = 0.02;
-}
+// Note: We test on smaller model and input sizes to avoid high error
+// accumulation during validation.
+static constexpr int64_t B = 2, E = 768, H = 16, S = 128;
+// Note: Dropout probabilities are set to 0. Since the dropout mask is sharded,
+// it throws off the seed offset between the sharded nvFuser program and the
+// unsharded reference.
+static constexpr double kDropoutProb = 0.0, kSdpaProb = 0.0, kSdpaScale = 1e-3;
+// Note: parameters are scaled by kParamScale following the weight
+// initialization recommendations in:
+// https://huggingface.co/docs/transformers/en/model_doc/gpt2#transformers.GPT2Config.initializer_range
+static constexpr double kParamScale = 0.02;
+} // namespace
 
 class DistributedTransformerTest
     : public MultiDeviceTest,
       public testing::WithParamInterface<DataType> {
  protected:
   DistributedTransformerTest() : D(communicator_->size()) {
-    model = std::make_unique<DistributedTransformer>(D, B, E, H, S, kDropoutProb, kSdpaProb);
+    model = std::make_unique<DistributedTransformer>(
+        D, B, E, H, S, kDropoutProb, kSdpaProb);
   }
 
   void SetUp() override {
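For concreteness, the per-device shard sizes these constants imply for the sequence-parallel test, assuming D = 2 devices (D is the communicator size at runtime, so this is illustrative only):

    constexpr int64_t kD = 2;
    static_assert(H % kD == 0); // 16 heads split evenly across devices
    static_assert(B * S / kD == 128); // x has logical shape {kD, 128, 768}
    static_assert(3 * E / kD == 1152); // mha_w0 has logical shape {kD, 1152, 768}
    static_assert(4 * E / kD == 1536); // mlp_w0 has logical shape {kD, 1536, 768}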
@@ -735,144 +734,79 @@ TEST_P(DistributedTransformerTest, MHA_Backward) {
       expected_outputs, out, {1e-5, 0.02, 1e-5, .01, .02, 0.2, 0.2, 0.2, 0.02});
 }
 
-// TEST_P(DistributedTransformerTest, Forward_SP) {
-//   if (H % D != 0) {
-//     GTEST_SKIP() << "Requires number of devices=" << D
-//                  << " evenly divide H=" << H;
-//   }
-//   if (D == 1) {
-//     GTEST_SKIP() << "Requires >1 devices, D=" << D;
-//   }
-//   auto dtype = GetParam();
-//   at::ScalarType at_dtype = data_type_to_aten(dtype);
-//   auto fusion = std::make_unique<Fusion>();
-
-//   FusionGuard fg(fusion.get());
-//   const auto mesh = DeviceMesh::createForNumDevices(D);
-
-//   TensorView* x = makeContigConcreteTensor({D, B * S / D, E}, dtype);
-//   TensorView* ln0_w = makeContigTensor(1);
-//   TensorView* ln0_b = makeContigTensor(1);
-//   TensorView* mha_w0 = makeContigConcreteTensor({D, 3 * E / D, E}, dtype);
-//   TensorView* mha_b0 = makeContigConcreteTensor({D, 3 * E / D}, dtype);
-//   TensorView* mha_w1 = makeContigConcreteTensor({D, E, E / D}, dtype);
-//   TensorView* mha_b1 = makeContigConcreteTensor({E}, dtype);
-//   TensorView* ln1_w = makeContigTensor(1);
-//   TensorView* ln1_b = makeContigTensor(1);
-//   TensorView* mlp_w0 = makeContigConcreteTensor({D, 4 * E / D, E}, dtype);
-//   TensorView* mlp_b0 = makeContigConcreteTensor({D, 4 * E / D}, dtype);
-//   TensorView* mlp_w1 = makeContigConcreteTensor({D, E, 4 * E / D}, dtype);
-//   TensorView* mlp_b1 = makeContigConcreteTensor({E}, dtype);
-
-//   fusion->addInput(x);
-//   fusion->addInput(ln0_w);
-//   fusion->addInput(ln0_b);
-//   fusion->addInput(mha_w0);
-//   fusion->addInput(mha_b0);
-//   fusion->addInput(mha_w1);
-//   fusion->addInput(mha_b1);
-//   fusion->addInput(ln1_w);
-//   fusion->addInput(ln1_b);
-//   fusion->addInput(mlp_w0);
-//   fusion->addInput(mlp_b0);
-//   fusion->addInput(mlp_w1);
-//   fusion->addInput(mlp_b1);
-
-//   constexpr float kEps = 1e-5;
-//   auto eps = IrBuilder::create<Val>(kEps);
-//   std::vector<int64_t> norm_shape{E};
-
-//   auto ln_input = castOp(DataType::Float, x);
-//   auto ln0 = layer_norm(ln_input, norm_shape, ln0_w, ln0_b, eps);
-//   auto mha_in = castOp(dtype, ln0.output);
-//   auto mha_tvs = mha(mha_in, mha_w0, mha_b0, mha_w1, mha_b1, mesh, true);
-//   auto resid0 = add(ln_input, mha_tvs.output);
-//   auto ln1 = layer_norm(resid0, norm_shape, ln1_w, ln1_b, eps);
-//   auto mlp_in = castOp(dtype, ln1.output);
-//   auto mlp_tvs = mlp(mlp_in, mlp_w0, mlp_b0, mlp_w1, mlp_b1, mesh, true);
-//   auto resid1 = add(resid0, mlp_tvs.output);
-//   resid1 = castOp(dtype, resid1);
-
-//   fusion->addOutput(ln0.output);
-//   fusion->addOutput(mha_tvs.output);
-//   fusion->addOutput(ln1.output);
-//   fusion->addOutput(mlp_tvs.output);
-//   fusion->addOutput(resid1);
-
-//   x->setDeviceMesh(mesh);
-//   x->axis(0)->parallelize(ParallelType::DIDx);
-//   // Propagate SP shardings from x through layernorms, dropouts, residual adds.
-//   // Even though mha_in is part of the boundary set, residuals allow the
-//   // shardings to propagate up the graph so we must cut off the propagation at
-//   // the outputs of reduce scatters (mha and mlp matmul1)
-//   shardBetween({x}, {mha_in, mlp_in, mha_tvs.matmul1, mlp_tvs.matmul1}, x);
-//   // Propagate TP sharding for MLP and MHA from sharded weights. We do not need
-//   // to shard from mha_b0 or mlp_b0 because they are only consumed by their
-//   // respective linear0 expression which is sharded from *_w0.
-//   shardBetween({mha_w0}, {mha_tvs.matmul1}, mha_w0);
-//   shardBetween({mha_w1}, {mha_tvs.matmul1}, mha_w1);
-//   shardBetween({mlp_w0}, {mlp_tvs.matmul1}, mlp_w0);
-//   shardBetween({mlp_w1}, {mlp_tvs.matmul1}, mlp_w1);
-
-//   const auto options =
-//       at::TensorOptions().dtype(at_dtype).device(communicator_->device());
-//   auto x_ = at::randn({B * S, E}, options);
-//   auto ln0_w_ = at::randn(E, options).to(at::kFloat);
-//   auto ln0_b_ = at::randn(E, options).to(at::kFloat);
-//   auto mha_w0_ = at::randn({3 * E, E}, options) * kParamScale;
-//   auto mha_b0_ = at::randn({3 * E}, options) * kParamScale;
-//   auto mha_w1_ = at::randn({E, E}, options) * kParamScale;
-//   auto mha_b1_ = at::randn({E}, options) * kParamScale;
-//   auto ln1_w_ = at::randn(E, options).to(at::kFloat);
-//   auto ln1_b_ = at::randn(E, options).to(at::kFloat);
-//   auto mlp_w0_ = at::randn({4 * E, E}, options) * kParamScale;
-//   auto mlp_b0_ = at::randn({4 * E}, options) * kParamScale;
-//   auto mlp_w1_ = at::randn({E, 4 * E}, options) * kParamScale;
-//   auto mlp_b1_ = at::randn({E}, options) * kParamScale;
-
-//   at::manual_seed(getATenRandomSeed());
-//   auto x_float_ = x_.to(at::kFloat);
-//   auto ln0_ = at::native_layer_norm(x_float_, norm_shape, ln0_w_, ln0_b_, kEps);
-//   auto ln0_out_ = std::get<0>(ln0_);

-//   auto mha_out_ = reference_mha(
-//       ln0_out_.to(at_dtype), mha_w0_, mha_b0_, mha_w1_, mha_b1_)[3];
-
-//   auto resid0_ = mha_out_ + x_float_;
-//   auto ln1_ = at::native_layer_norm(resid0_, norm_shape, ln1_w_, ln1_b_, kEps);
-//   auto ln1_out_ = std::get<0>(ln1_);
-
-//   auto mlp_out_ = reference_mlp(
-//       ln1_out_.to(at_dtype), mlp_w0_, mlp_b0_, mlp_w1_, mlp_b1_)[3];
-//   auto at_out = (resid0_ + mlp_out_).to(at_dtype);
-
-//   std::vector<c10::IValue> inputs = {
-//       shardTensor(x_, 0, mesh).unsqueeze(0),
-//       ln0_w_,
-//       ln0_b_,
-//       shardTensor(mha_w0_.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}),
-//       shardTensor(mha_b0_.view({3, E}), 1, mesh).view({1, 3 * E / D}),
-//       shardTensor(mha_w1_, 1, mesh).unsqueeze(0),
-//       mha_b1_,
-//       ln1_w_,
-//       ln1_b_,
-//       shardTensor(mlp_w0_, 0, mesh).unsqueeze(0),
-//       shardTensor(mlp_b0_, 0, mesh).unsqueeze(0),
-//       shardTensor(mlp_w1_, 1, mesh).unsqueeze(0),
-//       mlp_b1_};
-
-//   std::vector<at::Tensor> expected_outputs = {
-//       shardTensor(ln0_out_, 0, mesh).unsqueeze(0),
-//       shardTensor(mha_out_, 0, mesh).unsqueeze(0),
-//       shardTensor(ln1_out_, 0, mesh).unsqueeze(0),
-//       shardTensor(mlp_out_, 0, mesh).unsqueeze(0),
-//       shardTensor(at_out, 0, mesh).unsqueeze(0)};
-
-//   FusionExecutorCache fec(std::move(fusion));
-//   at::manual_seed(getATenRandomSeed());
-//   auto outputs = fec.runFusionWithInputs(inputs);
-//   validate(expected_outputs, outputs, {1e-4, 0.02, 0.04, 0.04, 0.04});
-// }
+TEST_P(DistributedTransformerTest, Forward_SP) {
+  if (H % D != 0) {
+    GTEST_SKIP() << "Requires number of devices=" << D
+                 << " to evenly divide H=" << H;
+  }
+  if (D == 1) {
+    GTEST_SKIP() << "Requires >1 devices, D=" << D;
+  }
+  auto dtype = GetParam();
+  at::ScalarType at_dtype = data_type_to_aten(dtype);
+  const auto mesh = DeviceMesh::createForNumDevices(D);
+  constexpr float kEps = 1e-5;
+  std::vector<int64_t> norm_shape{E};
+
+  const auto options =
+      at::TensorOptions().dtype(at_dtype).device(communicator_->device());
+  auto x_ = at::randn({B * S, E}, options);
+  auto ln0_w_ = at::randn(E, options).to(at::kFloat);
+  auto ln0_b_ = at::randn(E, options).to(at::kFloat);
+  auto mha_w0_ = at::randn({3 * E, E}, options) * kParamScale;
+  auto mha_b0_ = at::randn({3 * E}, options) * kParamScale;
+  auto mha_w1_ = at::randn({E, E}, options) * kParamScale;
+  auto mha_b1_ = at::randn({E}, options) * kParamScale;
+  auto ln1_w_ = at::randn(E, options).to(at::kFloat);
+  auto ln1_b_ = at::randn(E, options).to(at::kFloat);
+  auto mlp_w0_ = at::randn({4 * E, E}, options) * kParamScale;
+  auto mlp_b0_ = at::randn({4 * E}, options) * kParamScale;
+  auto mlp_w1_ = at::randn({E, 4 * E}, options) * kParamScale;
+  auto mlp_b1_ = at::randn({E}, options) * kParamScale;
+
+  at::manual_seed(getATenRandomSeed());
+  auto x_float_ = x_.to(at::kFloat);
+  auto ln0_ = at::native_layer_norm(x_float_, norm_shape, ln0_w_, ln0_b_, kEps);
+  auto ln0_out_ = std::get<0>(ln0_);
+
+  auto mha_out_ = reference_mha(
+      ln0_out_.to(at_dtype), mha_w0_, mha_b0_, mha_w1_, mha_b1_)[3];
+
+  auto resid0_ = mha_out_ + x_float_;
+  auto ln1_ = at::native_layer_norm(resid0_, norm_shape, ln1_w_, ln1_b_, kEps);
+  auto ln1_out_ = std::get<0>(ln1_);
+
+  auto mlp_out_ = reference_mlp(
+      ln1_out_.to(at_dtype), mlp_w0_, mlp_b0_, mlp_w1_, mlp_b1_)[3];
+  auto at_out = (resid0_ + mlp_out_).to(at_dtype);
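In the input list below, mha_w0_ stacks the Q, K, and V projections along dim 0 as {3 * E, E}, and each device needs E / D rows of each projection rather than one contiguous 3 * E / D block; hence the view, shard, view sequence. A sketch of the equivalent manual slice, where d is a hypothetical rank index (shardTensor derives it from the communicator, assuming it takes the d-th contiguous slice along the given dim, which matches its use in this file):

    // {3 * E, E} -> {3, E, E} -> rows [d * E/D, (d + 1) * E/D) of each of
    // Q, K, V -> flatten back, with a leading size-1 device axis.
    int64_t d = 0; // this rank's position in the mesh (hypothetical)
    auto mha_w0_shard =
        mha_w0_.view({3, E, E})
            .narrow(/*dim=*/1, /*start=*/d * (E / D), /*length=*/E / D)
            .reshape({1, 3 * E / D, E});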
+
+  std::vector<c10::IValue> inputs = {
+      shardTensor(x_, 0, mesh).unsqueeze(0),
+      ln0_w_,
+      ln0_b_,
+      shardTensor(mha_w0_.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}),
+      shardTensor(mha_b0_.view({3, E}), 1, mesh).view({1, 3 * E / D}),
+      shardTensor(mha_w1_, 1, mesh).unsqueeze(0),
+      mha_b1_,
+      ln1_w_,
+      ln1_b_,
+      shardTensor(mlp_w0_, 0, mesh).unsqueeze(0),
+      shardTensor(mlp_b0_, 0, mesh).unsqueeze(0),
+      shardTensor(mlp_w1_, 1, mesh).unsqueeze(0),
+      mlp_b1_};
+
+  std::vector<at::Tensor> expected_outputs = {
+      shardTensor(ln0_out_, 0, mesh).unsqueeze(0),
+      shardTensor(mha_out_, 0, mesh).unsqueeze(0),
+      shardTensor(ln1_out_, 0, mesh).unsqueeze(0),
+      shardTensor(mlp_out_, 0, mesh).unsqueeze(0),
+      shardTensor(at_out, 0, mesh).unsqueeze(0)};
+
+  auto fec = model->forward(dtype, /*sequence_parallel=*/true);
+  at::manual_seed(getATenRandomSeed());
+  auto outputs = fec->runFusionWithInputs(inputs);
+  validate(expected_outputs, outputs, {1e-4, 0.02, 0.04, 0.04, 0.04});
+}
 
 TEST_P(DistributedTransformerTest, Forward) {
   if (H % D != 0) {
@@ -1083,4 +1017,4 @@ INSTANTIATE_TEST_SUITE_P(
     testing::Values(DataType::Half, DataType::BFloat16),
     testing::PrintToStringParamName());
 
-} // namespace nvfuser
\ No newline at end of file
+} // namespace nvfuser
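Note on validation in Forward_SP: the reference outputs are computed unsharded and then cut down so that each rank checks only the rows it owns; under sequence parallelism every compared activation is sharded along the flattened B * S dimension. For a hypothetical rank index d, the expected residual output reduces to:

    // Shape {1, B * S / D, E}, matching shardTensor(at_out, 0, mesh).unsqueeze(0)
    // under the same contiguous-slice assumption as above.
    auto expected_resid1_shard =
        at_out.narrow(/*dim=*/0, /*start=*/d * (B * S / D), /*length=*/B * S / D)
            .unsqueeze(0);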