Commit: add sp test

cowanmeg committed Dec 13, 2024
1 parent 9852150 commit eb11636
Showing 3 changed files with 134 additions and 175 deletions.
62 changes: 43 additions & 19 deletions tests/cpp/multidevice_transformer.cpp
@@ -398,12 +398,15 @@ std::vector<TensorView*> DistributedTransformer::mha_backwards(
}

std::unique_ptr<FusionExecutorCache> DistributedTransformer::forward(
DataType dtype) {
DataType dtype,
bool sequence_parallel) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());
const auto mesh = DeviceMesh::createForNumDevices(D);

TensorView* x = makeContigConcreteTensor({B * S, E}, dtype);
TensorView* x = sequence_parallel
? makeContigConcreteTensor({D, B * S / D, E}, dtype)
: makeContigConcreteTensor({B * S, E}, dtype);
TensorView* ln0_w = makeContigTensor(1);
TensorView* ln0_b = makeContigTensor(1);
TensorView* mha_w0 = makeContigConcreteTensor({D, 3 * E / D, E}, dtype);
@@ -412,10 +415,10 @@ std::unique_ptr<FusionExecutorCache> DistributedTransformer::forward(
TensorView* mha_b1 = makeContigConcreteTensor({E}, dtype);
TensorView* ln1_w = makeContigTensor(1);
TensorView* ln1_b = makeContigTensor(1);
TensorView* mlp_w0 = makeContigTensor(3, dtype);
TensorView* mlp_b0 = makeContigTensor(2, dtype);
TensorView* mlp_w1 = makeContigTensor(3, dtype);
TensorView* mlp_b1 = makeContigTensor(1, dtype);
TensorView* mlp_w0 = makeContigConcreteTensor({D, 4 * E / D, E}, dtype);
TensorView* mlp_b0 = makeContigConcreteTensor({D, 4 * E / D}, dtype);
TensorView* mlp_w1 = makeContigConcreteTensor({D, E, 4 * E / D}, dtype);
TensorView* mlp_b1 = makeContigConcreteTensor({E}, dtype);

fusion->addInput(x);
fusion->addInput(ln0_w);
@@ -438,27 +441,48 @@ std::unique_ptr<FusionExecutorCache> DistributedTransformer::forward(
auto ln_input = castOp(DataType::Float, x);
auto ln0 = layer_norm(ln_input, norm_shape, ln0_w, ln0_b, eps);
auto mha_in = castOp(dtype, ln0.output);
auto mha_out = mha(mha_in, mha_w0, mha_b0, mha_w1, mha_b1, mesh).output;
auto resid0 = add(ln_input, mha_out);
auto mha_tvs =
mha(mha_in, mha_w0, mha_b0, mha_w1, mha_b1, mesh, sequence_parallel);
auto resid0 = add(ln_input, mha_tvs.output);
auto ln1 = layer_norm(resid0, norm_shape, ln1_w, ln1_b, eps);
auto mlp_in = castOp(dtype, ln1.output);
auto mlp_out = mlp(mlp_in, mlp_w0, mlp_b0, mlp_w1, mlp_b1, mesh).output;
auto resid1 = add(resid0, mlp_out);
auto mlp_tvs =
mlp(mlp_in, mlp_w0, mlp_b0, mlp_w1, mlp_b1, mesh, sequence_parallel);
auto resid1 = add(resid0, mlp_tvs.output);
resid1 = castOp(dtype, resid1);

fusion->addOutput(ln0.output);
fusion->addOutput(mha_out);
fusion->addOutput(mha_tvs.output);
fusion->addOutput(ln1.output);
fusion->addOutput(mlp_out);
fusion->addOutput(mlp_tvs.output);
fusion->addOutput(resid1);

for (auto tv : {x, ln0.output, ln1.output, resid1}) {
tv->setDeviceMesh(mesh);
}

shardBetween({mha_in->definition()}, {mha_out->definition()}, mha_w0);
shardBetween({mlp_in->definition()}, {mlp_out->definition()}, mlp_w0);
shardBetween({x}, {mha_in}, x);
x->setDeviceMesh(mesh);
if (sequence_parallel) {
// The input arrives already sharded along the sequence dimension.
x->axis(0)->parallelize(ParallelType::DIDx);
// Propagate SP shardings from x through the layernorms, dropouts, and
// residual adds. Even though mha_in is part of the boundary set, the
// residual connections let shardings propagate back up the graph, so
// propagation must be cut off at the outputs of the reduce-scatters (mha
// and mlp matmul1).
shardBetween({x}, {mha_in, mlp_in, mha_tvs.matmul1, mlp_tvs.matmul1}, x);
// Propagate TP shardings for MLP and MHA from the sharded weights. Sharding
// from mha_b0 or mlp_b0 is unnecessary because each is consumed only by its
// respective linear0 expression, which is already sharded via *_w0.
shardBetween({mha_w0}, {mha_tvs.matmul1}, mha_w0);
shardBetween({mha_w1}, {mha_tvs.matmul1}, mha_w1);
shardBetween({mlp_w0}, {mlp_tvs.matmul1}, mlp_w0);
shardBetween({mlp_w1}, {mlp_tvs.matmul1}, mlp_w1);
} else {
// TP-only shardings.
// The layernorms and residuals are all replicated like x. shardBetween
// shards every tv reachable from x, so the input and output tvs must be
// included in the boundary set.
shardBetween({x}, {mha_in, mha_tvs.output, mlp_in, mlp_tvs.output}, x);
// TP sharded regions within mha and mlp
shardBetween({mha_in}, {mha_tvs.output}, mha_w0);
shardBetween({mlp_in}, {mlp_tvs.output}, mlp_w0);
}

return std::make_unique<FusionExecutorCache>(std::move(fusion));
}
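
For context, a minimal sketch of how a caller might exercise the new sequence-parallel path (illustrative only, not part of the commit; inputs is assumed to be pre-sharded as in the Forward_SP test below):

    // Build the model the same way the test fixture does.
    DistributedTransformer model(
        D, B, E, H, S, /*dropout_prob=*/0.0, /*sdpa_dropout_prob=*/0.0);
    // Request the sequence-parallel forward fusion.
    std::unique_ptr<FusionExecutorCache> fec =
        model.forward(DataType::BFloat16, /*sequence_parallel=*/true);
    // x must arrive as {D, B * S / D, E} with axis 0 parallelized on DIDx.
    auto outputs = fec->runFusionWithInputs(inputs);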
5 changes: 3 additions & 2 deletions tests/cpp/multidevice_transformer.h
@@ -27,7 +27,6 @@ struct MhaResult {
TensorView* output;
};


class DistributedTransformer {
public:
DistributedTransformer(
@@ -46,7 +45,9 @@ class DistributedTransformer {
kDropoutProb(dropout_prob),
kSdpaProb(sdpa_dropout_prob) {}

std::unique_ptr<FusionExecutorCache> forward(DataType dtype);
std::unique_ptr<FusionExecutorCache> forward(
DataType dtype,
bool sequence_parallel = false);
std::unique_ptr<FusionExecutorCache> backward(DataType dtype);

MlpResult mlp(
242 changes: 88 additions & 154 deletions tests/cpp/test_multidevice_transformer.cpp
@@ -18,28 +18,27 @@

namespace nvfuser {



namespace {
// Note: We test on smaller model and input sizes to avoid high error accumulation
// for validation.
static constexpr int64_t B = 2, E = 768, H = 16, S = 128;
// Note: Dropout probabilities are set to 0. Since the dropout mask is sharded it
// throws off the seed offset between the sharded nvFuser program and the
// unsharded reference.
static constexpr double kDropoutProb = 0.0, kSdpaProb = 0.0, kSdpaScale = 1e-3;
// Note parameters scaled by kParamScale following weight initialization
// recommendations:
// https://huggingface.co/docs/transformers/en/model_doc/gpt2#transformers.GPT2Config.initializer_range
static constexpr double kParamScale = 0.02;
}
// Note: We test on smaller model and input sizes to avoid high error
// accumulation for validation.
static constexpr int64_t B = 2, E = 768, H = 16, S = 128;
// Note: Dropout probabilities are set to 0. Since the dropout mask is sharded
// it throws off the seed offset between the sharded nvFuser program and the
// unsharded reference.
static constexpr double kDropoutProb = 0.0, kSdpaProb = 0.0, kSdpaScale = 1e-3;
// Note parameters scaled by kParamScale following weight initialization
// recommendations:
// https://huggingface.co/docs/transformers/en/model_doc/gpt2#transformers.GPT2Config.initializer_range
static constexpr double kParamScale = 0.02;
} // namespace
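
As a quick sanity check on the constants above (illustrative arithmetic only, assuming a hypothetical device count D = 2):

    // Each attention head spans E / H = 768 / 16 = 48 channels. With D = 2,
    // the sharded parameter shapes declared in forward() become:
    //   mha_w0: {D, 3 * E / D, E} = {2, 1152, 768}
    //   mlp_w0: {D, 4 * E / D, E} = {2, 1536, 768}
    static_assert(E / H == 48 && 3 * E / 2 == 1152 && 4 * E / 2 == 1536);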

class DistributedTransformerTest
: public MultiDeviceTest,
public testing::WithParamInterface<DataType> {
protected:
DistributedTransformerTest() : D(communicator_->size()) {
model = std::make_unique<DistributedTransformer>(D, B, E, H, S, kDropoutProb, kSdpaProb);
model = std::make_unique<DistributedTransformer>(
D, B, E, H, S, kDropoutProb, kSdpaProb);
}

void SetUp() override {
@@ -735,144 +734,79 @@ TEST_P(DistributedTransformerTest, MHA_Backward) {
expected_outputs, out, {1e-5, 0.02, 1e-5, .01, .02, 0.2, 0.2, 0.2, 0.02});
}

// TEST_P(DistributedTransformerTest, Forward_SP) {
// if (H % D != 0) {
// GTEST_SKIP() << "Requires number of devices=" << D
// << " evenly divide H=" << H;
// }
// if (D == 1) {
// GTEST_SKIP() << "Requires >1 devices, D=" << D;
// }
// auto dtype = GetParam();
// at::ScalarType at_dtype = data_type_to_aten(dtype);
// auto fusion = std::make_unique<Fusion>();

// FusionGuard fg(fusion.get());
// const auto mesh = DeviceMesh::createForNumDevices(D);

// TensorView* x = makeContigConcreteTensor({D, B * S / D, E}, dtype);
// TensorView* ln0_w = makeContigTensor(1);
// TensorView* ln0_b = makeContigTensor(1);
// TensorView* mha_w0 = makeContigConcreteTensor({D, 3 * E / D, E}, dtype);
// TensorView* mha_b0 = makeContigConcreteTensor({D, 3 * E / D}, dtype);
// TensorView* mha_w1 = makeContigConcreteTensor({D, E, E / D}, dtype);
// TensorView* mha_b1 = makeContigConcreteTensor({E}, dtype);
// TensorView* ln1_w = makeContigTensor(1);
// TensorView* ln1_b = makeContigTensor(1);
// TensorView* mlp_w0 = makeContigConcreteTensor({D, 4 * E / D, E}, dtype);
// TensorView* mlp_b0 = makeContigConcreteTensor({D, 4 * E / D}, dtype);
// TensorView* mlp_w1 = makeContigConcreteTensor({D, E, 4 * E / D}, dtype);
// TensorView* mlp_b1 = makeContigConcreteTensor({E}, dtype);

// fusion->addInput(x);
// fusion->addInput(ln0_w);
// fusion->addInput(ln0_b);
// fusion->addInput(mha_w0);
// fusion->addInput(mha_b0);
// fusion->addInput(mha_w1);
// fusion->addInput(mha_b1);
// fusion->addInput(ln1_w);
// fusion->addInput(ln1_b);
// fusion->addInput(mlp_w0);
// fusion->addInput(mlp_b0);
// fusion->addInput(mlp_w1);
// fusion->addInput(mlp_b1);

// constexpr float kEps = 1e-5;
// auto eps = IrBuilder::create<Val>(kEps);
// std::vector<int64_t> norm_shape{E};

// auto ln_input = castOp(DataType::Float, x);
// auto ln0 = layer_norm(ln_input, norm_shape, ln0_w, ln0_b, eps);
// auto mha_in = castOp(dtype, ln0.output);
// auto mha_tvs = mha(mha_in, mha_w0, mha_b0, mha_w1, mha_b1, mesh, true);
// auto resid0 = add(ln_input, mha_tvs.output);
// auto ln1 = layer_norm(resid0, norm_shape, ln1_w, ln1_b, eps);
// auto mlp_in = castOp(dtype, ln1.output);
// auto mlp_tvs = mlp(mlp_in, mlp_w0, mlp_b0, mlp_w1, mlp_b1, mesh, true);
// auto resid1 = add(resid0, mlp_tvs.output);
// resid1 = castOp(dtype, resid1);

// fusion->addOutput(ln0.output);
// fusion->addOutput(mha_tvs.output);
// fusion->addOutput(ln1.output);
// fusion->addOutput(mlp_tvs.output);
// fusion->addOutput(resid1);

// x->setDeviceMesh(mesh);
// x->axis(0)->parallelize(ParallelType::DIDx);
// // Propagate SP shardings from x through layernorms, dropouts, residual adds.
// // Even though mha_in is part of the boundary set, residuals allow the
// // shardings to propagate up the graph so we must cut off the propagation at
// // the outputs of reduce scatters (mha and mlp matmul1)
// shardBetween({x}, {mha_in, mlp_in, mha_tvs.matmul1, mlp_tvs.matmul1}, x);
// // Propagate TP sharding for MLP and MHA from sharded weights. We do not need
// // to shard from mha_b0 or mlp_b0 because they are only consumed by their
// // respective linear0 expression which is sharded from *_w0.
// shardBetween({mha_w0}, {mha_tvs.matmul1}, mha_w0);
// shardBetween({mha_w1}, {mha_tvs.matmul1}, mha_w1);
// shardBetween({mlp_w0}, {mlp_tvs.matmul1}, mlp_w0);
// shardBetween({mlp_w1}, {mlp_tvs.matmul1}, mlp_w1);

// const auto options =
// at::TensorOptions().dtype(at_dtype).device(communicator_->device());
// auto x_ = at::randn({B * S, E}, options);
// auto ln0_w_ = at::randn(E, options).to(at::kFloat);
// auto ln0_b_ = at::randn(E, options).to(at::kFloat);
// auto mha_w0_ = at::randn({3 * E, E}, options) * kParamScale;
// auto mha_b0_ = at::randn({3 * E}, options) * kParamScale;
// auto mha_w1_ = at::randn({E, E}, options) * kParamScale;
// auto mha_b1_ = at::randn({E}, options) * kParamScale;
// auto ln1_w_ = at::randn(E, options).to(at::kFloat);
// auto ln1_b_ = at::randn(E, options).to(at::kFloat);
// auto mlp_w0_ = at::randn({4 * E, E}, options) * kParamScale;
// auto mlp_b0_ = at::randn({4 * E}, options) * kParamScale;
// auto mlp_w1_ = at::randn({E, 4 * E}, options) * kParamScale;
// auto mlp_b1_ = at::randn({E}, options) * kParamScale;

// at::manual_seed(getATenRandomSeed());
// auto x_float_ = x_.to(at::kFloat);
// auto ln0_ = at::native_layer_norm(x_float_, norm_shape, ln0_w_, ln0_b_, kEps);
// auto ln0_out_ = std::get<0>(ln0_);

// auto mha_out_ = reference_mha(
// ln0_out_.to(at_dtype), mha_w0_, mha_b0_, mha_w1_, mha_b1_)[3];

// auto resid0_ = mha_out_ + x_float_;
// auto ln1_ = at::native_layer_norm(resid0_, norm_shape, ln1_w_, ln1_b_, kEps);
// auto ln1_out_ = std::get<0>(ln1_);

// auto mlp_out_ = reference_mlp(
// ln1_out_.to(at_dtype), mlp_w0_, mlp_b0_, mlp_w1_, mlp_b1_)[3];
// auto at_out = (resid0_ + mlp_out_).to(at_dtype);

// std::vector<c10::IValue> inputs = {
// shardTensor(x_, 0, mesh).unsqueeze(0),
// ln0_w_,
// ln0_b_,
// shardTensor(mha_w0_.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}),
// shardTensor(mha_b0_.view({3, E}), 1, mesh).view({1, 3 * E / D}),
// shardTensor(mha_w1_, 1, mesh).unsqueeze(0),
// mha_b1_,
// ln1_w_,
// ln1_b_,
// shardTensor(mlp_w0_, 0, mesh).unsqueeze(0),
// shardTensor(mlp_b0_, 0, mesh).unsqueeze(0),
// shardTensor(mlp_w1_, 1, mesh).unsqueeze(0),
// mlp_b1_};

// std::vector<at::Tensor> expected_outputs = {
// shardTensor(ln0_out_, 0, mesh).unsqueeze(0),
// shardTensor(mha_out_, 0, mesh).unsqueeze(0),
// shardTensor(ln1_out_, 0, mesh).unsqueeze(0),
// shardTensor(mlp_out_, 0, mesh).unsqueeze(0),
// shardTensor(at_out, 0, mesh).unsqueeze(0)};

// FusionExecutorCache fec(std::move(fusion));
// at::manual_seed(getATenRandomSeed());
// auto outputs = fec.runFusionWithInputs(inputs);
// validate(expected_outputs, outputs, {1e-4, 0.02, 0.04, 0.04, 0.04});
// }
TEST_P(DistributedTransformerTest, Forward_SP) {
if (H % D != 0) {
GTEST_SKIP() << "Requires number of devices=" << D
<< " evenly divide H=" << H;
}
if (D == 1) {
GTEST_SKIP() << "Requires >1 devices, D=" << D;
}
auto dtype = GetParam();
at::ScalarType at_dtype = data_type_to_aten(dtype);
const auto mesh = DeviceMesh::createForNumDevices(D);
constexpr float kEps = 1e-5;
std::vector<int64_t> norm_shape{E};

const auto options =
at::TensorOptions().dtype(at_dtype).device(communicator_->device());
auto x_ = at::randn({B * S, E}, options);
auto ln0_w_ = at::randn(E, options).to(at::kFloat);
auto ln0_b_ = at::randn(E, options).to(at::kFloat);
auto mha_w0_ = at::randn({3 * E, E}, options) * kParamScale;
auto mha_b0_ = at::randn({3 * E}, options) * kParamScale;
auto mha_w1_ = at::randn({E, E}, options) * kParamScale;
auto mha_b1_ = at::randn({E}, options) * kParamScale;
auto ln1_w_ = at::randn(E, options).to(at::kFloat);
auto ln1_b_ = at::randn(E, options).to(at::kFloat);
auto mlp_w0_ = at::randn({4 * E, E}, options) * kParamScale;
auto mlp_b0_ = at::randn({4 * E}, options) * kParamScale;
auto mlp_w1_ = at::randn({E, 4 * E}, options) * kParamScale;
auto mlp_b1_ = at::randn({E}, options) * kParamScale;

at::manual_seed(getATenRandomSeed());
auto x_float_ = x_.to(at::kFloat);
auto ln0_ = at::native_layer_norm(x_float_, norm_shape, ln0_w_, ln0_b_, kEps);
auto ln0_out_ = std::get<0>(ln0_);

auto mha_out_ = reference_mha(
ln0_out_.to(at_dtype), mha_w0_, mha_b0_, mha_w1_, mha_b1_)[3];

auto resid0_ = mha_out_ + x_float_;
auto ln1_ = at::native_layer_norm(resid0_, norm_shape, ln1_w_, ln1_b_, kEps);
auto ln1_out_ = std::get<0>(ln1_);

auto mlp_out_ = reference_mlp(
ln1_out_.to(at_dtype), mlp_w0_, mlp_b0_, mlp_w1_, mlp_b1_)[3];
auto at_out = (resid0_ + mlp_out_).to(at_dtype);

std::vector<c10::IValue> inputs = {
shardTensor(x_, 0, mesh).unsqueeze(0),
ln0_w_,
ln0_b_,
shardTensor(mha_w0_.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}),
shardTensor(mha_b0_.view({3, E}), 1, mesh).view({1, 3 * E / D}),
shardTensor(mha_w1_, 1, mesh).unsqueeze(0),
mha_b1_,
ln1_w_,
ln1_b_,
shardTensor(mlp_w0_, 0, mesh).unsqueeze(0),
shardTensor(mlp_b0_, 0, mesh).unsqueeze(0),
shardTensor(mlp_w1_, 1, mesh).unsqueeze(0),
mlp_b1_};

std::vector<at::Tensor> expected_outputs = {
shardTensor(ln0_out_, 0, mesh).unsqueeze(0),
shardTensor(mha_out_, 0, mesh).unsqueeze(0),
shardTensor(ln1_out_, 0, mesh).unsqueeze(0),
shardTensor(mlp_out_, 0, mesh).unsqueeze(0),
shardTensor(at_out, 0, mesh).unsqueeze(0)};

auto fec = model->forward(dtype, true);
at::manual_seed(getATenRandomSeed());
auto outputs = fec->runFusionWithInputs(inputs);
validate(expected_outputs, outputs, {1e-4, 0.02, 0.04, 0.04, 0.04});
}
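
A subtlety in the inputs above: mha_w0_ stacks the Q, K, and V projection weights along its first dimension, so a per-device shard must slice each of the three projections separately rather than taking one contiguous block of rows. A hedged sketch of what the view/shard/view sequence accomplishes (shardTensor semantics assumed from its use in these tests):

    // mha_w0_ is {3 * E, E} with Q, K, V stacked along dim 0.
    auto qkv = mha_w0_.view({3, E, E});       // expose the Q/K/V axis
    auto local = shardTensor(qkv, 1, mesh);   // {3, E / D, E}: this rank's
                                              // slice of each of Q, K, V
    auto arg = local.view({1, 3 * E / D, E}); // leading 1 is the device dim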

TEST_P(DistributedTransformerTest, Forward) {
if (H % D != 0) {
@@ -1083,4 +1017,4 @@ INSTANTIATE_TEST_SUITE_P(
testing::Values(DataType::Half, DataType::BFloat16),
testing::PrintToStringParamName());

} // namespace nvfuser
} // namespace nvfuser
