diff --git a/CMakeLists.txt b/CMakeLists.txt
index 040d8129b72..6ff2230fe03 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -644,6 +644,7 @@ if(BUILD_TEST)
   set(MULTIDEVICE_TEST_SRCS)
   list(APPEND MULTIDEVICE_TEST_SRCS
     ${NVFUSER_ROOT}/tests/cpp/multidevice.cpp
+    ${NVFUSER_ROOT}/tests/cpp/multidevice_transformer.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_overlap.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communications.cpp
     ${NVFUSER_ROOT}/tests/cpp/test_multidevice_communicator.cpp
@@ -776,6 +777,43 @@ if(BUILD_NVFUSER_BENCHMARK)
       -Werror -Wno-deprecated-copy
     )
   endif()
+
+  # multidevice transformer benchmark
+  if (NVFUSER_DISTRIBUTED)
+    set(MULTIDEVICE_BENCHMARK_SRCS)
+    list(APPEND MULTIDEVICE_BENCHMARK_SRCS
+      ${NVFUSER_ROOT}/benchmarks/cpp/transformer.cpp
+      ${NVFUSER_ROOT}/tests/cpp/multidevice_transformer.cpp
+      ${NVFUSER_ROOT}/tests/cpp/utils.cpp
+    )
+    add_executable(nvfuser_multidevice_bench ${MULTIDEVICE_BENCHMARK_SRCS})
+    set_target_properties(nvfuser_multidevice_bench PROPERTIES
+      C_STANDARD ${NVFUSER_C_STANDARD}
+      CUDA_STANDARD ${NVFUSER_CUDA_STANDARD}
+      CXX_STANDARD ${NVFUSER_CPP_STANDARD}
+      CXX_STANDARD_REQUIRED ON
+      CXX_VISIBILITY_PRESET hidden
+      POSITION_INDEPENDENT_CODE Yes
+      VISIBILITY_INLINES_HIDDEN Yes
+    )
+    target_include_directories(nvfuser_multidevice_bench SYSTEM PRIVATE
+      ${CMAKE_SOURCE_DIR}/third_party/benchmark/include
+      ${CMAKE_SOURCE_DIR}/third_party/flatbuffers/include
+      ${CMAKE_SOURCE_DIR}/third_party/googletest/googletest/include
+    )
+    target_include_directories(nvfuser_multidevice_bench PUBLIC ${NVFUSER_ROOT})
+    target_link_libraries(nvfuser_multidevice_bench PRIVATE
+      benchmark::benchmark
+      codegen_internal
+    )
+    add_dependencies(nvfuser_multidevice_bench flatc build_flatbuffer_config)
+    if(NOT MSVC)
+      target_compile_options(nvfuser_multidevice_bench PRIVATE
+        -Wall -Wno-unused-function
+        -Werror -Wno-deprecated-copy
+      )
+    endif()
+  endif()
 endif()

 # --- generate runtime files
diff --git a/benchmarks/cpp/transformer.cpp b/benchmarks/cpp/transformer.cpp
new file mode 100644
index 00000000000..b27f9ec43ff
--- /dev/null
+++ b/benchmarks/cpp/transformer.cpp
@@ -0,0 +1,231 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+using namespace nvfuser;
+
+constexpr int64_t B = 1, E = 12288, H = 96, S = 2048;
+constexpr double kParamScale = 0.02;
+constexpr int64_t warmup_itrs = 10, num_itrs = 10;
+
+at::Tensor shardTensor(
+    at::Tensor tensor,
+    int64_t axis,
+    const DeviceMesh& mesh,
+    Communicator* communicator_) {
+  const auto device_id = communicator_->deviceId();
+  auto i = mesh.idxOf(device_id);
+  auto extent = tensor.size(axis);
+  auto nslices = mesh.size();
+  NVF_CHECK(
+      extent % nslices == 0, "Sharded axis must be evenly divisible by mesh");
+  auto stride = extent / nslices;
+  i = (i < 0) ? 0 : i;
+  auto slice = tensor.slice(axis, i * stride, (i + 1) * stride).contiguous();
+  // Temporary until https://github.com/NVIDIA/Fuser/issues/2563. Adds DIDx
+  // axis in front representing the sharded extent of the tensor.
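+  // Illustrative example (assuming D = 8 devices): mha_w1_ of shape
+  // [E, E] = [12288, 12288] sharded along axis 1 yields a local slice of
+  // shape [12288, 1536]; the unsqueeze below turns it into [1, 12288, 1536]
+  // so it lines up with the [DIDx(D), E, E/D] TensorView declared in
+  // DistributedTransformer.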
+ if (stride > 1) { + slice = slice.unsqueeze(0); + } + return slice; +} + +void forward_transformer(Communicator* communicator_, bool profile) { + int64_t D = communicator_->size(); + auto dtype = DataType::BFloat16; + at::ScalarType at_dtype = data_type_to_aten(dtype); + const auto mesh = DeviceMesh::createForNumDevices(D); + const auto options = + at::TensorOptions().dtype(at_dtype).device(communicator_->device()); + + auto x_ = at::randn({B * S, E}, options).to(at::kFloat); + auto ln0_w_ = at::randn(E, options).to(at::kFloat); + auto ln0_b_ = at::randn(E, options).to(at::kFloat); + auto mha_w0_ = at::randn({3 * E, E}, options) * kParamScale; + auto mha_b0_ = at::randn({3 * E}, options) * kParamScale; + auto mha_w1_ = at::randn({E, E}, options) * kParamScale; + auto mha_b1_ = at::randn({E}, options) * kParamScale; + auto ln1_w_ = at::randn(E, options).to(at::kFloat); + auto ln1_b_ = at::randn(E, options).to(at::kFloat); + auto mlp_w0_ = at::randn({4 * E, E}, options) * kParamScale; + auto mlp_b0_ = at::randn({4 * E}, options) * kParamScale; + auto mlp_w1_ = at::randn({E, 4 * E}, options) * kParamScale; + auto mlp_b1_ = at::randn({E}, options) * kParamScale; + + std::vector inputs = { + x_, + ln0_w_, + ln0_b_, + shardTensor(mha_w0_.view({3, E, E}), 1, mesh, communicator_) + .view({1, 3 * E / D, E}), + shardTensor(mha_b0_.view({3, E}), 1, mesh, communicator_) + .view({1, 3 * E / D}), + shardTensor(mha_w1_, 1, mesh, communicator_), + mha_b1_, + ln1_w_, + ln1_b_, + shardTensor(mlp_w0_, 0, mesh, communicator_), + shardTensor(mlp_b0_, 0, mesh, communicator_), + shardTensor(mlp_w1_, 1, mesh, communicator_), + mlp_b1_}; + + DistributedTransformer model = DistributedTransformer(D, B, E, H, S); + auto fec = model.forward(dtype); + cudaSetDevice(communicator_->deviceId()); + + auto start = std::chrono::high_resolution_clock::now(); + for (auto i : c10::irange(num_itrs + warmup_itrs)) { + if (i == warmup_itrs) { + start = std::chrono::high_resolution_clock::now(); + if (profile) { + cudaProfilerStart(); + } + } + if (i >= warmup_itrs && profile) { + nvtxRangePush(("Iteration" + std::to_string(i)).c_str()); + } + auto outputs = fec->runFusionWithInputs(inputs); + cudaDeviceSynchronize(); + // cudaDeviceSynchronize is not blocking until kernels are finished on all + // devices except 0 + // TODO: are we not waiting until all kernels are appended to the stream? 
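+    // Note: streaming a single element of outputs[0] below forces a
+    // device-to-host copy, which blocks the host until the kernels producing
+    // that output on this device have finished.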
+    std::cout << outputs[0][0][0] << std::endl;
+
+    if (i >= warmup_itrs && profile) {
+      nvtxRangePop();
+    }
+  }
+  auto end = std::chrono::high_resolution_clock::now();
+
+  double forward_time =
+      std::chrono::duration_cast(end - start)
+          .count() /
+      (double)num_itrs / 1000.0;
+  std::cout << communicator_->deviceId() << ": Average forward time "
+            << forward_time << "ms" << std::endl;
+}
+
+void backward_transformer(Communicator* communicator_, bool profile) {
+  auto dtype = DataType::BFloat16;
+  at::ScalarType at_dtype = data_type_to_aten(dtype);
+  int64_t D = communicator_->size();
+  const auto mesh = DeviceMesh::createForNumDevices(D);
+
+  const auto options =
+      at::TensorOptions().dtype(at_dtype).device(communicator_->device());
+  auto x_ = at::randn({B * S, E}, options).to(at::kFloat);
+  auto ln0_w_ = at::randn(E, options).to(at::kFloat);
+  auto ln0_b_ = at::randn(E, options).to(at::kFloat);
+  auto mha_w0_ = at::randn({3 * E, E}, options) * kParamScale;
+  auto mha_b0_ = at::randn({3 * E}, options) * kParamScale;
+  auto mha_w1_ = at::randn({E, E}, options) * kParamScale;
+  auto mha_b1_ = at::randn({E}, options) * kParamScale;
+  auto ln1_w_ = at::randn(E, options).to(at::kFloat);
+  auto ln1_b_ = at::randn(E, options).to(at::kFloat);
+  auto mlp_w0_ = at::randn({4 * E, E}, options) * kParamScale;
+  auto mlp_b0_ = at::randn({4 * E}, options) * kParamScale;
+  auto grad_ = at::randn({B * S, E}, options).to(at::kFloat) * kParamScale;
+  auto mlp_w1_ = at::randn({E, 4 * E}, options) * kParamScale;
+  auto mlp_b1_ = at::randn({E}, options) * kParamScale;
+
+  // Recomputed tensors
+  auto mlp_dropout_mask = at::rand({B * S, E}, options).lt(1.0 - 0.1);
+  auto mha_dropout_mask = at::rand({B * S, E}, options).lt(1.0 - 0.1);
+  auto sdpa_output = at::randn({B, H, S, E / H}, options);
+  auto sdpa_logsum_exp = at::randn({B, H, S}, options).to(at::kFloat);
+  auto sdpa_seed = at::scalar_tensor(1, at::kLong);
+  auto sdpa_offset = at::scalar_tensor(1, at::kLong);
+  auto ln0_mean = at::randn({B * S, 1}, options).to(at::kFloat);
+  auto ln0_rstd = at::randn({B * S, 1}, options).to(at::kFloat);
+  auto ln1_mean = at::randn({B * S, 1}, options).to(at::kFloat);
+  auto ln1_rstd = at::randn({B * S, 1}, options).to(at::kFloat);
+  auto mha_linear1 = at::rand({B * S, E}, options).to(at::kFloat);
+
+  std::vector inputs = {
+      x_,
+      grad_,
+      shardTensor(mha_w0_.view({3, E, E}), 1, mesh, communicator_)
+          .view({1, 3 * E / D, E}),
+      shardTensor(mha_b0_.view({3, E}), 1, mesh, communicator_)
+          .view({1, 3 * E / D}),
+      shardTensor(mha_w1_, 1, mesh, communicator_),
+      shardTensor(mlp_w0_, 0, mesh, communicator_),
+      shardTensor(mlp_b0_, 0, mesh, communicator_),
+      shardTensor(mlp_w1_, 1, mesh, communicator_),
+      mlp_b1_,
+      mlp_dropout_mask,
+      mha_dropout_mask,
+      shardTensor(sdpa_output, 1, mesh, communicator_),
+      shardTensor(sdpa_logsum_exp, 1, mesh, communicator_),
+      sdpa_seed,
+      sdpa_offset,
+      ln1_w_,
+      ln1_b_,
+      ln1_mean,
+      ln1_rstd,
+      ln0_w_,
+      ln0_b_,
+      ln0_mean,
+      ln0_rstd,
+      mha_linear1};
+
+  DistributedTransformer model = DistributedTransformer(D, B, E, H, S);
+  auto fec = model.backward(dtype);
+  std::vector outputs;
+
+  cudaSetDevice(communicator_->deviceId());
+  auto start = std::chrono::high_resolution_clock::now();
+  for (auto i : c10::irange(num_itrs + warmup_itrs)) {
+    if (i == warmup_itrs) {
+      start = std::chrono::high_resolution_clock::now();
+    }
+    if (i >= warmup_itrs && profile) {
+      nvtxRangePush(("Iteration" + std::to_string(i)).c_str());
+    }
+    outputs = fec->runFusionWithInputs(inputs);
+    cudaDeviceSynchronize();
+    // cudaDeviceSynchronize is not blocking until kernels are finished on all
+    // devices except 0
+    // TODO: are we not waiting until all kernels are appended to the stream?
+    std::cout << outputs[0][0][0][0] << std::endl;
+
+    if (i >= warmup_itrs && profile) {
+      nvtxRangePop();
+    }
+  }
+  auto end = std::chrono::high_resolution_clock::now();
+  if (profile) {
+    cudaProfilerStop();
+  }
+
+  double backward_time =
+      std::chrono::duration_cast(end - start)
+          .count() /
+      (double)num_itrs / 1000.0;
+  std::cout << communicator_->deviceId() << ": Average backward time "
+            << backward_time << "ms" << std::endl;
+}
+
+int main(int argc, char** argv) {
+  // Any extra command-line argument is used as a flag to enable profiling.
+  bool profile = argc > 1;
+  auto communicator_ = &Communicator::getInstance();
+  forward_transformer(communicator_, profile);
+  communicator_->barrier();
+  backward_transformer(communicator_, profile);
+}
diff --git a/tests/cpp/multidevice_transformer.cpp b/tests/cpp/multidevice_transformer.cpp
new file mode 100644
index 00000000000..56982a6a693
--- /dev/null
+++ b/tests/cpp/multidevice_transformer.cpp
@@ -0,0 +1,665 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#include
+
+#include
+#include
+#include
+#include
+
+namespace nvfuser {
+namespace {
+// TODO: These linear_backwards helper functions can be merged once
+// we do not have logically split rfactor domain.
+struct LinearBackwardsResult {
+  TensorView* grad_x;
+  TensorView* grad_w;
+  TensorView* grad_b;
+};
+
+// x format: [i0, i1] dtype
+// weight format: [DID(D), i2/D, i1] dtype
+// grad format: [DID(D), i0, i2/D] float or dtype
+// outputs: grad_x [i0, i1] dtype
+//          grad_w [DID, i2/D, i1] dtype
+//          grad_b [DID, i2/D] dtype
+LinearBackwardsResult linear_backwards(
+    TensorView* x,
+    TensorView* w,
+    TensorView* grad) {
+  DataType dtype = w->dtype();
+  TensorView* grad_f = maybeCastOp(DataType::Float, grad);
+  TensorView* grad_q = maybeCastOp(dtype, grad);
+  TensorView* grad_x_partials = matmul(grad_q, w);
+  TensorView* grad_x = sum(grad_x_partials, {0}); // allreduce
+  TensorView* grad_q_t = transpose(grad_q, 1, 2);
+  TensorView* grad_w = matmul(grad_q_t, x);
+  TensorView* grad_b = sum(grad_f, {1});
+  grad_b = castOp(dtype, grad_b);
+
+  return {grad_x, grad_w, grad_b};
+}
+
+// x format: [DID, i0, i1/D] dtype
+// weight format: [DID, i2, i1/D] dtype
+// grad format: [i0, i2] float
+// outputs: grad_x [DID, i0, i1/D] dtype
+//          grad_w [DID, i2, i1/D] dtype
+//          grad_b [i2] dtype
+LinearBackwardsResult sharded_linear_backwards(
+    TensorView* x,
+    TensorView* w,
+    TensorView* grad) {
+  DataType dtype = w->dtype();
+  TensorView* grad_q = castOp(dtype, grad);
+  TensorView* grad_x = matmul(grad_q, w);
+  TensorView* grad_t = transpose(grad_q, 0, 1);
+  TensorView* grad_w = matmul(grad_t, x);
+  TensorView* grad_b = sum(grad, {0});
+  grad_b = castOp(dtype, grad_b);
+
+  return {grad_x, grad_w, grad_b};
+}
+
+// Forward layer_norm with cached mean_bcast and invstd tensors to avoid
+// recomputing Welford. For use in backwards pass.
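+// Computes weight * (x - mean_bcast) * invstd + bias, i.e. the usual affine
+// layer normalization, but with the mean and inverse standard deviation read
+// from the statistics saved by the forward pass instead of being recomputed.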
+TensorView* layer_norm_with_cached_statistics( + TensorView* x, + TensorView* mean_bcast, + TensorView* invstd, + const std::vector& norm_shape, + TensorView* weight, + TensorView* bias) { + const int64_t kNumberOfDims = + (int64_t)TensorDomain::noReductions(x->getLogicalDomain()).size(); + const int64_t kOuterNumDims = kNumberOfDims - norm_shape.size(); + std::vector outer_broadcast_mask(kNumberOfDims, false); + for (const auto idx : c10::irange(kOuterNumDims)) { + outer_broadcast_mask[idx] = true; + } + + auto x_sub_mean = sub(x, mean_bcast); + auto y = mul(x_sub_mean, invstd); + + auto weight_bcast = broadcast(weight, outer_broadcast_mask); + y = mul(y, weight_bcast); + auto bias_bcast = broadcast(bias, outer_broadcast_mask); + return add(y, bias_bcast); +} +} // namespace + +MlpResult DistributedTransformer::mlp( + TensorView* x, + TensorView* w0, + TensorView* b0, + TensorView* w1, + TensorView* b1, + const DeviceMesh& mesh, + bool sequence_parallel) { + const DataType dtype = w0->dtype(); + + if (sequence_parallel) { + // Input arrives sharded and must be allgathered back + x->setDeviceMesh(mesh); + x->axis(0)->parallelize(ParallelType::DIDx); + x = set(x); // allgather + x->axis(0)->parallelize(ParallelType::Serial); + // Reshape back to 2D. This is uncessary except to keep + // the shapes of linear0 the same for TP and TP+SP. + auto D = w0->axis(0)->extent()->value().as(); + x = reshape(x, {D, B * S / D, E}, {B * S, E}); + } + // Linear 0 + TensorView* linear0 = linear(x, w0, b0); + // GeLU + TensorView* gelu = tanh_gelu(castOp(DataType::Float, linear0)); + gelu = castOp(dtype, gelu); + // Linear 1 + TensorView* local_matmul1 = matmul(gelu, transpose(w1, 1, 2)); + if (sequence_parallel) { + // Remove after https://github.com/NVIDIA/Fuser/issues/2563 + // Reshape to explicitly pull the sharded axis into the logical domain + auto D = w0->axis(0)->extent()->value().as(); + local_matmul1 = reshape(local_matmul1, {D, B * S, E}, {D, D, B * S / D, E}); + } + TensorView* matmul1 = sum(local_matmul1, {0}); // Allreduce or Reduce scatter + std::vector bcast_mask(matmul1->nDims() - 1, true); + bcast_mask[matmul1->nDims() - 2] = false; + TensorView* linear1 = add(matmul1, broadcast(b1, bcast_mask)); + // Dropout + Val* prob = IrBuilder::create(1.0 - kDropoutProb); + Val* scale = IrBuilder::create(1.0 / (1.0 - kDropoutProb)); + TensorView* dropout_result = dropout(linear1, prob, scale).output; + + // Tensor parallel shardings + for (auto* tv : {w0, b0, w1}) { + tv->setDeviceMesh(mesh); + tv->axis(0)->parallelize(ParallelType::DIDx); + } + for (auto* tv : {x, b1}) { + tv->setDeviceMesh(mesh); + } + + // Sequence parallel shardings + if (sequence_parallel) { + matmul1->setDeviceMesh(mesh); + matmul1->axis(1)->parallelize(ParallelType::DIDx); + } + + return {linear0, gelu, matmul1, linear1, dropout_result}; +} + +MhaResult DistributedTransformer::mha( + TensorView* x, + TensorView* w0, + TensorView* b0, + TensorView* w1, + TensorView* b1, + const DeviceMesh& mesh, + bool sequence_parallel) { + const auto D = w0->axis(0)->extent()->value().as(); + auto dtype = w0->dtype(); + + if (sequence_parallel) { + // Input arrives sharded and must be allgathered back + x->setDeviceMesh(mesh); + x->axis(0)->parallelize(ParallelType::DIDx); + x = set(x); // allgather + x->axis(0)->parallelize(ParallelType::Serial); + // Reshape is uncessary, it is here to keep shapes with TP and TP+SP the + // same for validation. 
+ x = reshape(x, {D, B * S / D, E}, {B * S, E}); + } + + TensorView* linear0 = linear(x, w0, b0); + // Forming the q,k,v vectors: + TensorView* qkv_cat = + reshape(linear0, {D, B * S, 3 * E / D}, {D, B, S, 3 * E / D}); + std::vector qkv = chunk(qkv_cat, 3, -1); + for (auto i : c10::irange(3)) { + qkv[i] = reshape(qkv[i], {D, B, S, E / D}, {D, B, S, H / D, E / H}); + qkv[i] = transpose(qkv[i], 2, 3); + } + // SDPA + SdpfaFwdResult sdpa = sdpfa_fwd( + qkv[0], + qkv[1], + qkv[2], + IrBuilder::create(kSdpaProb), + IrBuilder::create(true), + IrBuilder::create(kSdpaScale)); + TensorView* sdpa_output = sdpa.output; + // Linear 1 + TensorView* sdpa_transpose = transpose(sdpa_output, 2, 3); + TensorView* sdpa_reshape = + reshape(sdpa_transpose, {D, B, S, H / D, E / H}, {D, B * S, E / D}); + TensorView* local_matmul1 = matmul(sdpa_reshape, transpose(w1, 1, 2)); + if (sequence_parallel) { + // Remove after https://github.com/NVIDIA/Fuser/issues/2563 + // Reshape to explicitly pull the sharded axis into the logical domain + auto D = w0->axis(0)->extent()->value().as(); + local_matmul1 = reshape(local_matmul1, {D, B * S, E}, {D, D, B * S / D, E}); + } + TensorView* matmul1 = sum(local_matmul1, {0}); // allreduce + std::vector bcast_mask(matmul1->nDims() - 1, true); + bcast_mask[matmul1->nDims() - 2] = false; + TensorView* linear1 = add(matmul1, broadcast(b1, bcast_mask)); + // Dropout + Val* prob = IrBuilder::create(1.0 - kDropoutProb); + Val* scale = IrBuilder::create(1.0 / (1.0 - kDropoutProb)); + TensorView* dropout_result = dropout(linear1, prob, scale).output; + + // Tensor parallel shardings + for (auto tv : {x, b1}) { + tv->setDeviceMesh(mesh); + } + for (auto tv : {w0, b0, w1}) { + tv->setDeviceMesh(mesh); + tv->axis(0)->parallelize(ParallelType::DIDx); + } + // Sequence parallel sharding. 
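+  // As in mlp above, sharding matmul1 along the sequence chunks lets the
+  // reduction over the local partial results lower to a reduce-scatter
+  // rather than an allreduce.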
+ if (sequence_parallel) { + matmul1->setDeviceMesh(mesh); + matmul1->axis(1)->parallelize(ParallelType::DIDx); + } + + return {linear0, sdpa_output, matmul1, linear1, dropout_result}; +} + +std::vector DistributedTransformer::mlp_backwards( + TensorView* grad, + TensorView* x, + TensorView* mask, + TensorView* w0, + TensorView* w1, + TensorView* linear0, + const DeviceMesh& mesh) { + DataType dtype = w0->dtype(); + + // Activation recomputation: Always recompute gelu + TensorView* gelu = castOp(dtype, tanh_gelu(castOp(DataType::Float, linear0))); + + // Backwards pass + const double kScale = 1.0 / (1.0 - kDropoutProb); + Val* dropout_scale = IrBuilder::create(kScale); + TensorView* dropout_grad = dropout_backward(grad, mask, dropout_scale); + auto linear1_grads = sharded_linear_backwards(gelu, w1, dropout_grad); + TensorView* matmul1_grad_x_ = castOp(DataType::Float, linear1_grads.grad_x); + TensorView* gelu_grad = tanh_gelu_backward(matmul1_grad_x_, linear0); + auto linear0_grads = linear_backwards(x, w0, gelu_grad); + + // Manaul sharding annotations + for (auto tv : + {x, + grad, + mask, + dropout_grad, + linear1_grads.grad_b, + linear0_grads.grad_x}) { + tv->setDeviceMesh(mesh); + } + + for (auto tv : + {w0, + w1, + linear0, + linear1_grads.grad_x, + linear1_grads.grad_w, + gelu_grad, + linear0_grads.grad_w, + linear0_grads.grad_b}) { + tv->setDeviceMesh(mesh); + tv->axis(0)->parallelize(ParallelType::DIDx); + } + + std::vector outputs = { + dropout_grad, + linear1_grads.grad_w, + linear1_grads.grad_b, + gelu_grad, + linear0_grads.grad_w, + linear0_grads.grad_b, + linear0_grads.grad_x}; + return outputs; +} + +std::vector DistributedTransformer::mha_backwards( + TensorView* x, + TensorView* w0, + TensorView* w1, + TensorView* mask, + TensorView* sdpa_output, + TensorView* sdpa_log_sumexp, + TensorView* sdpa_seed, + TensorView* sdpa_offset, + TensorView* grad, + TensorView* linear0, + const DeviceMesh& mesh) { + DataType dtype = w0->dtype(); + const auto D = w0->axis(0)->extent()->value().as(); + // Reform qkv from linear0 output + TensorView* qkv_cat = reshape( + castOp(DataType::Float, linear0), + {D, B * S, 3 * E / D}, + {D, B, S, 3 * E / D}); + std::vector qkv = chunk(qkv_cat, 3, -1); + for (auto i : c10::irange(3)) { + qkv[i] = reshape(qkv[i], {D, B, S, E / D}, {D, B, S, H / D, E / H}); + qkv[i] = transpose(qkv[i], 2, 3); + qkv[i] = castOp(dtype, qkv[i]); + qkv[i]->setDeviceMesh(mesh); + qkv[i]->axis(0)->parallelize(ParallelType::DIDx); + } + + // dropout backwards + const double kScale = 1.0 / (1.0 - kDropoutProb); + auto dropout_scale = IrBuilder::create(kScale); + TensorView* dropout_grad = dropout_backward(grad, mask, dropout_scale); + + // linear1 backwards + TensorView* sdpa_output_reshape = + transpose(sdpa_output, 2, 3); // D, B, S, H/D, E/H + sdpa_output_reshape = + reshape(sdpa_output_reshape, {D, B, S, H / D, E / H}, {D, B * S, E / D}); + auto linear1_grads = + sharded_linear_backwards(sdpa_output_reshape, w1, dropout_grad); + + // SDPA backwards + TensorView* linear1_x_grad = + reshape(linear1_grads.grad_x, {D, B * S, E / D}, {D, B, S, H / D, E / H}); + linear1_x_grad = transpose(linear1_x_grad, 2, 3); // D, B, H/D, S, E/H + // Explicitly shard inputs before SDPA backward node + for (auto tv : {linear1_x_grad, sdpa_output, sdpa_log_sumexp}) { + tv->setDeviceMesh(mesh); + tv->axis(0)->parallelize(ParallelType::DIDx); + } + auto sdpa_grad = sdpfa_bwd( + linear1_x_grad, + qkv[0], + qkv[1], + qkv[2], + sdpa_output, + sdpa_log_sumexp, + 
/*dropout_p=*/IrBuilder::create(kSdpaProb), + /*is_causal=*/IrBuilder::create(true), + sdpa_seed, + sdpa_offset, + /*scale=*/IrBuilder::create(kSdpaScale)); + + TensorView* q_grad = transpose(sdpa_grad.grad_query, 2, 3); + q_grad = reshape(q_grad, {D, B, S, H / D, E / H}, {D, B * S, E / D}); + TensorView* v_grad = transpose(sdpa_grad.grad_value, 2, 3); + v_grad = reshape(v_grad, {D, B, S, H / D, E / H}, {D, B * S, E / D}); + TensorView* k_grad = transpose(sdpa_grad.grad_key, 2, 3); + k_grad = reshape(k_grad, {D, B, S, H / D, E / H}, {D, B * S, E / D}); + TensorView* kqv_grad = cat({k_grad, q_grad, v_grad}, -1); + auto linear0_grads = linear_backwards(x, w0, kqv_grad); + + for (auto tv : + {x, + mask, + grad, + dropout_grad, + linear1_grads.grad_b, + linear0_grads.grad_x}) { + tv->setDeviceMesh(mesh); + } + for (auto tv : + {w0, + w1, + sdpa_output, + sdpa_log_sumexp, + linear0, + linear1_grads.grad_x, + linear1_grads.grad_w, + linear0_grads.grad_w, + linear0_grads.grad_b, + sdpa_grad.grad_query, + sdpa_grad.grad_key, + sdpa_grad.grad_value}) { + tv->setDeviceMesh(mesh); + tv->axis(0)->parallelize(ParallelType::DIDx); + } + return { + dropout_grad, + linear1_grads.grad_w, + linear1_grads.grad_b, + sdpa_grad.grad_query, + sdpa_grad.grad_key, + sdpa_grad.grad_value, + linear0_grads.grad_w, + linear0_grads.grad_b, + linear0_grads.grad_x}; +} + +std::unique_ptr DistributedTransformer::forward( + DataType dtype, + bool sequence_parallel) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + const auto mesh = DeviceMesh::createForNumDevices(D); + + TensorView* x = sequence_parallel + ? makeContigConcreteTensor({D, B * S / D, E}, dtype) + : makeContigConcreteTensor({B * S, E}, dtype); + TensorView* ln0_w = makeContigTensor(1); + TensorView* ln0_b = makeContigTensor(1); + TensorView* mha_w0 = makeContigConcreteTensor({D, 3 * E / D, E}, dtype); + TensorView* mha_b0 = makeContigConcreteTensor({D, 3 * E / D}, dtype); + TensorView* mha_w1 = makeContigConcreteTensor({D, E, E / D}, dtype); + TensorView* mha_b1 = makeContigConcreteTensor({E}, dtype); + TensorView* ln1_w = makeContigTensor(1); + TensorView* ln1_b = makeContigTensor(1); + TensorView* mlp_w0 = makeContigConcreteTensor({D, 4 * E / D, E}, dtype); + TensorView* mlp_b0 = makeContigConcreteTensor({D, 4 * E / D}, dtype); + TensorView* mlp_w1 = makeContigConcreteTensor({D, E, 4 * E / D}, dtype); + TensorView* mlp_b1 = makeContigConcreteTensor({E}, dtype); + + fusion->addInput(x); + fusion->addInput(ln0_w); + fusion->addInput(ln0_b); + fusion->addInput(mha_w0); + fusion->addInput(mha_b0); + fusion->addInput(mha_w1); + fusion->addInput(mha_b1); + fusion->addInput(ln1_w); + fusion->addInput(ln1_b); + fusion->addInput(mlp_w0); + fusion->addInput(mlp_b0); + fusion->addInput(mlp_w1); + fusion->addInput(mlp_b1); + + constexpr float kEps = 1e-5; + auto eps = IrBuilder::create(kEps); + std::vector norm_shape{E}; + + auto ln_input = castOp(DataType::Float, x); + auto ln0 = layer_norm(ln_input, norm_shape, ln0_w, ln0_b, eps); + auto mha_in = castOp(dtype, ln0.output); + auto mha_tvs = + mha(mha_in, mha_w0, mha_b0, mha_w1, mha_b1, mesh, sequence_parallel); + auto resid0 = add(ln_input, mha_tvs.output); + auto ln1 = layer_norm(resid0, norm_shape, ln1_w, ln1_b, eps); + auto mlp_in = castOp(dtype, ln1.output); + auto mlp_tvs = + mlp(mlp_in, mlp_w0, mlp_b0, mlp_w1, mlp_b1, mesh, sequence_parallel); + auto resid1 = add(resid0, mlp_tvs.output); + resid1 = castOp(dtype, resid1); + + fusion->addOutput(ln0.output); + fusion->addOutput(mha_tvs.output); 
+ fusion->addOutput(ln1.output); + fusion->addOutput(mlp_tvs.output); + fusion->addOutput(resid1); + + x->setDeviceMesh(mesh); + if (sequence_parallel) { + // Input arrives sharded + x->axis(0)->parallelize(ParallelType::DIDx); + // Propagate SP shardings from x through layernorms, dropouts, residual + // adds. Even though mha_in is part of the boundary set, residuals allow the + // shardings to propagate up the graph so we must cut off the propagation at + // the outputs of reduce scatters (mha and mlp matmul1) + shardBetween({x}, {mha_in, mlp_in, mha_tvs.matmul1, mlp_tvs.matmul1}, x); + // Propagate TP sharding for MLP and MHA from sharded weights. We do not + // need to shard from mha_b0 or mlp_b0 because they are only consumed by + // their respective linear0 expression which is sharded from *_w0. + shardBetween({mha_w0}, {mha_tvs.matmul1}, mha_w0); + shardBetween({mha_w1}, {mha_tvs.matmul1}, mha_w1); + shardBetween({mlp_w0}, {mlp_tvs.matmul1}, mlp_w0); + shardBetween({mlp_w1}, {mlp_tvs.matmul1}, mlp_w1); + } else { + // TP only shardings + // Layernorm, residuals, are all replicated like x. shardBetween + // shards all tvs reachable from x, so the input and output tvs must + // be in the boundary set. + shardBetween({x}, {mha_in, mha_tvs.output, mlp_in, mlp_tvs.output}, x); + // TP sharded regions within mha and mlp + shardBetween({mha_in}, {mha_tvs.output}, mha_w0); + shardBetween({mlp_in}, {mlp_tvs.output}, mlp_w0); + } + + return std::make_unique(std::move(fusion)); +} + +std::unique_ptr DistributedTransformer::backward( + DataType dtype) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + const auto mesh = DeviceMesh::createForNumDevices(D); + std::vector norm_shape{E}; + + TensorView* x = makeContigConcreteTensor({B * S, E}, dtype); + TensorView* grad = makeContigTensor(2, dtype); + TensorView* mha_w0 = makeContigConcreteTensor({D, 3 * E / D, E}, dtype); + TensorView* mha_w1 = makeContigConcreteTensor({D, E, E / D}, dtype); + TensorView* mlp_w0 = makeContigTensor(3, dtype); + TensorView* mlp_w1 = makeContigTensor(3, dtype); + TensorView* mha_mask = makeContigTensor(2, DataType::Bool); + TensorView* mlp_mask = makeContigTensor(2, DataType::Bool); + TensorView* mha_sdpa_out = makeConcreteTensor({D, B, H / D, S, E / H}, dtype); + TensorView* mha_sdpa_log_sumexp = + makeContigConcreteTensor({D, B, H / D, S}, DataType::Float); + TensorView* mha_sdpa_seed = makeSymbolicTensor({}, DataType::Int); + TensorView* mha_sdpa_offset = makeSymbolicTensor({}, DataType::Int); + TensorView* ln1_w = makeContigTensor(1); + TensorView* ln1_b = makeContigTensor(1); + TensorView* ln1_mean = makeConcreteTensor({B * S, 1}); + TensorView* ln1_rstd = makeConcreteTensor({B * S, 1}); + TensorView* ln0_w = makeContigTensor(1); + TensorView* ln0_b = makeContigTensor(1); + TensorView* ln0_mean = makeConcreteTensor({B * S, 1}); + TensorView* ln0_rstd = makeConcreteTensor({B * S, 1}); + TensorView* mha_linear0 = makeContigTensor(3, dtype); + TensorView* mha_linear1 = makeContigTensor(2); + TensorView* mlp_linear0 = makeContigTensor(3, dtype); + + fusion->addInput(x); + fusion->addInput(grad); + fusion->addInput(mha_w0); + fusion->addInput(mha_w1); + fusion->addInput(mlp_w0); + fusion->addInput(mlp_w1); + fusion->addInput(mlp_mask); + fusion->addInput(mha_mask); + fusion->addInput(mha_sdpa_out); + fusion->addInput(mha_sdpa_log_sumexp); + fusion->addInput(mha_sdpa_seed); + fusion->addInput(mha_sdpa_offset); + fusion->addInput(ln1_w); + fusion->addInput(ln1_b); + fusion->addInput(ln1_mean); + 
fusion->addInput(ln1_rstd); + fusion->addInput(ln0_w); + fusion->addInput(ln0_b); + fusion->addInput(ln0_mean); + fusion->addInput(ln0_rstd); + fusion->addInput(mha_linear0); + fusion->addInput(mha_linear1); + fusion->addInput(mlp_linear0); + + // Activation recomputation: mlp gelu, dropouts, and + // partially recompute layer norms using cached statistics. + auto ln0_in = castOp(DataType::Float, x); + auto ln0 = layer_norm_with_cached_statistics( + ln0_in, ln0_mean, ln0_rstd, norm_shape, ln0_w, ln0_b); + auto mha_in = castOp(dtype, ln0); + + Val* dropout_scale = IrBuilder::create(1.0 / (1.0 - kDropoutProb)); + // Use input mha_mask to implement dropout + auto mha_out = mul(mha_linear1, mha_mask); + mha_out = mul(mha_out, dropout_scale); + auto resid0 = add(ln0_in, mha_out); + auto ln1 = layer_norm_with_cached_statistics( + resid0, ln1_mean, ln1_rstd, norm_shape, ln1_w, ln1_b); + auto mlp_in = castOp(dtype, ln1); + + // Backwards + auto grad_float = castOp(DataType::Float, grad); + auto mlp_grads = mlp_backwards( + grad_float, mlp_in, mlp_mask, mlp_w0, mlp_w1, mlp_linear0, mesh); + auto ln1_grads = layer_norm_backward( + castOp(DataType::Float, mlp_grads[6]), + resid0, + norm_shape, + ln1_mean, + ln1_rstd, + ln1_w, + ln1_b, + {true, true, true}); + auto resid1_grad = add(ln1_grads.grad_input, grad_float); + auto mha_grads = mha_backwards( + mha_in, + mha_w0, + mha_w1, + mha_mask, + mha_sdpa_out, + mha_sdpa_log_sumexp, + mha_sdpa_seed, + mha_sdpa_offset, + resid1_grad, + mha_linear0, + mesh); + auto ln0_grads = layer_norm_backward( + castOp(DataType::Float, mha_grads[8]), + ln0_in, + norm_shape, + ln0_mean, + ln0_rstd, + ln0_w, + ln0_b, + {true, true, true}); + auto dx = add(ln0_grads.grad_input, resid1_grad); + dx = castOp(dtype, dx); + + fusion->addOutput(mlp_grads[1]); // mlp linear1 weight grad + fusion->addOutput(mlp_grads[2]); // mlp linear1 bias grad + fusion->addOutput(mlp_grads[4]); // mlp linear0 weight grad + fusion->addOutput(mlp_grads[5]); // mlp linear0 bias grad + fusion->addOutput(ln1_grads.grad_weight); + fusion->addOutput(ln1_grads.grad_bias); + fusion->addOutput(mha_grads[1]); // mha linear1 weight grad + fusion->addOutput(mha_grads[2]); // mha linear1 bias grad + fusion->addOutput(mha_grads[6]); // mha linear0 weight grad + fusion->addOutput(mha_grads[7]); // mha linear0 bias grad + fusion->addOutput(ln0_grads.grad_weight); + fusion->addOutput(ln0_grads.grad_bias); + fusion->addOutput(dx); // transformer grad input + + // Sharding annotations for input and output TVs not sharded + // by mlp_backward or mha_backward + for (auto* tv : + {ln0_w, + ln0_b, + ln0_mean, + ln0_rstd, + ln1_w, + ln1_b, + ln1_mean, + ln1_rstd, + ln1_grads.grad_weight, + ln1_grads.grad_bias, + ln0_grads.grad_weight, + ln0_grads.grad_bias, + ln0_grads.grad_input}) { + tv->setDeviceMesh(mesh); + } + + // Sharded inputs to outputs + shardBetween( + {mha_w0, mha_w1, mha_sdpa_out}, + {mha_grads[1], mha_grads[6], mha_grads[7]}, + mha_w0); + shardBetween( + {mlp_w0, mlp_w1}, {mlp_grads[1], mlp_grads[4], mlp_grads[5]}, mlp_w0); + + // Unsharded inputs to outputs + shardBetween( + {x, + grad, + mha_mask, + mlp_mask, + mha_linear1, + ln0_mean, + ln0_w, + ln0_b, + ln1_mean, + ln1_w, + ln1_b}, + {mlp_grads[2], + ln1_grads.grad_weight, + ln1_grads.grad_bias, + mha_grads[2], + ln0_grads.grad_weight, + ln0_grads.grad_bias, + dx}, + x); + + return std::make_unique(std::move(fusion)); +} +} // namespace nvfuser diff --git a/tests/cpp/multidevice_transformer.h b/tests/cpp/multidevice_transformer.h new file mode 
100644 index 00000000000..ed29dd09e36 --- /dev/null +++ b/tests/cpp/multidevice_transformer.h @@ -0,0 +1,98 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include + +#include + +namespace nvfuser { + +struct MlpResult { + TensorView* linear0; + TensorView* gelu; + TensorView* matmul1; + TensorView* linear1; + TensorView* output; +}; + +struct MhaResult { + TensorView* linear0; + TensorView* sdpa; + TensorView* matmul1; + TensorView* linear1; + TensorView* output; +}; + +class DistributedTransformer { + public: + DistributedTransformer( + int64_t num_devices, + int64_t batch_size, + int64_t embedding_size, + int64_t number_heads, + int64_t sequence_length, + double dropout_prob = 0.1, + double sdpa_dropout_prob = 0.1) + : D(num_devices), + B(batch_size), + E(embedding_size), + H(number_heads), + S(sequence_length), + kDropoutProb(dropout_prob), + kSdpaProb(sdpa_dropout_prob) {} + + std::unique_ptr forward( + DataType dtype, + bool sequence_parallel = false); + std::unique_ptr backward(DataType dtype); + + MlpResult mlp( + TensorView* x, + TensorView* w0, + TensorView* b0, + TensorView* w1, + TensorView* b1, + const DeviceMesh& mesh, + bool sequence_parallel = false); + + MhaResult mha( + TensorView* x, + TensorView* w0, + TensorView* b0, + TensorView* w1, + TensorView* b1, + const DeviceMesh& mesh, + bool sequence_parallel = false); + + std::vector mlp_backwards( + TensorView* grad, + TensorView* x, + TensorView* mask, + TensorView* w0, + TensorView* w1, + TensorView* linear0, + const DeviceMesh& mesh); + + std::vector mha_backwards( + TensorView* x, + TensorView* w0, + TensorView* w1, + TensorView* mask, + TensorView* sdpa_output, + TensorView* sdpa_log_sumexp, + TensorView* sdpa_seed, + TensorView* sdpa_offset, + TensorView* grad, + TensorView* linear0, + const DeviceMesh& mesh); + + const int64_t D, B, E, H, S; + const double kDropoutProb; + const double kSdpaProb; + static constexpr double kSdpaScale = 1e-3; +}; +} // namespace nvfuser diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp index 0f39ae6f6e5..6ccb217137f 100644 --- a/tests/cpp/test_multidevice_transformer.cpp +++ b/tests/cpp/test_multidevice_transformer.cpp @@ -13,25 +13,33 @@ #include #include #include +#include #include namespace nvfuser { -constexpr int64_t B = 2, E = 768, H = 16, S = 128; +namespace { +// Note: We test on smaller model and input sizes to avoid high error +// accumulation for validation. +static constexpr int64_t B = 2, E = 768, H = 16, S = 128; +// Note: Dropout probabilities are set to 0. Since the dropout mask is sharded +// it throws off the seed offset between the sharded nvFuser program and the +// unsharded reference. +static constexpr double kDropoutProb = 0.0, kSdpaProb = 0.0, kSdpaScale = 1e-3; // Note parameters scaled by kParamScale following weight initialization // recommendations: // https://huggingface.co/docs/transformers/en/model_doc/gpt2#transformers.GPT2Config.initializer_range -// Note: Sdpa probability is set to 0. Since the dropout mask is sharded it -// throws off the seed offset between the sharded nvFuser program and the -// unsharded reference. 
-constexpr double kDropoutProb = 0.0, kParamScale = 0.02, kSdpaProb = 0.0, - kSdpaScale = 1e-3; +static constexpr double kParamScale = 0.02; +} // namespace class DistributedTransformerTest : public MultiDeviceTest, public testing::WithParamInterface { protected: - DistributedTransformerTest() : D(communicator_->size()) {} + DistributedTransformerTest() : D(communicator_->size()) { + model = std::make_unique( + D, B, E, H, S, kDropoutProb, kSdpaProb); + } void SetUp() override { MultiDeviceTest::SetUp(); @@ -41,6 +49,7 @@ class DistributedTransformerTest } const int64_t D; // number of devices + std::unique_ptr model; }; namespace { @@ -271,405 +280,6 @@ std::vector reference_mha_backwards( linear0}; return tensors; } - -struct MlpResult { - TensorView* linear0; - TensorView* gelu; - TensorView* matmul1; - TensorView* linear1; - TensorView* output; -}; - -MlpResult mlp( - TensorView* x, - TensorView* w0, - TensorView* b0, - TensorView* w1, - TensorView* b1, - const DeviceMesh& mesh, - bool sequence_parallel = false) { - const DataType dtype = w0->dtype(); - - if (sequence_parallel) { - // Input arrives sharded and must be allgathered back - x->setDeviceMesh(mesh); - x->axis(0)->parallelize(ParallelType::DIDx); - x = set(x); // allgather - x->axis(0)->parallelize(ParallelType::Serial); - // Reshape back to 2D. This is uncessary except to keep - // the shapes of linear0 the same for TP and TP+SP. - auto D = w0->axis(0)->extent()->value().as(); - x = reshape(x, {D, B * S / D, E}, {B * S, E}); - } - // Linear 0 - TensorView* linear0 = linear(x, w0, b0); - // GeLU - TensorView* gelu = tanh_gelu(castOp(DataType::Float, linear0)); - gelu = castOp(dtype, gelu); - // Linear 1 - TensorView* local_matmul1 = matmul(gelu, transpose(w1, 1, 2)); - if (sequence_parallel) { - // Remove after https://github.com/NVIDIA/Fuser/issues/2563 - // Reshape to explicitly pull the sharded axis into the logical domain - auto D = w0->axis(0)->extent()->value().as(); - local_matmul1 = reshape(local_matmul1, {D, B * S, E}, {D, D, B * S / D, E}); - } - TensorView* matmul1 = sum(local_matmul1, {0}); // Allreduce or Reduce scatter - std::vector bcast_mask(matmul1->nDims() - 1, true); - bcast_mask[matmul1->nDims() - 2] = false; - TensorView* linear1 = add(matmul1, broadcast(b1, bcast_mask)); - // Dropout - Val* prob = IrBuilder::create(1.0 - kDropoutProb); - Val* scale = IrBuilder::create(1.0 / (1.0 - kDropoutProb)); - TensorView* dropout_result = dropout(linear1, prob, scale).output; - - // Tensor parallel shardings - for (auto* tv : {w0, b0, w1}) { - tv->setDeviceMesh(mesh); - tv->axis(0)->parallelize(ParallelType::DIDx); - } - for (auto* tv : {x, b1}) { - tv->setDeviceMesh(mesh); - } - - // Sequence parallel shardings - if (sequence_parallel) { - matmul1->setDeviceMesh(mesh); - matmul1->axis(1)->parallelize(ParallelType::DIDx); - } - - return {linear0, gelu, matmul1, linear1, dropout_result}; -} - -struct MhaResult { - TensorView* linear0; - TensorView* sdpa; - TensorView* matmul1; - TensorView* linear1; - TensorView* output; -}; - -MhaResult mha( - TensorView* x, - TensorView* w0, - TensorView* b0, - TensorView* w1, - TensorView* b1, - const DeviceMesh& mesh, - bool sequence_parallel = false) { - const auto D = w0->axis(0)->extent()->value().as(); - auto dtype = w0->dtype(); - - if (sequence_parallel) { - // Input arrives sharded and must be allgathered back - x->setDeviceMesh(mesh); - x->axis(0)->parallelize(ParallelType::DIDx); - x = set(x); // allgather - x->axis(0)->parallelize(ParallelType::Serial); - // Reshape is 
uncessary, it is here to keep shapes with TP and TP+SP the - // same for validation. - x = reshape(x, {D, B * S / D, E}, {B * S, E}); - } - - TensorView* linear0 = linear(x, w0, b0); - // Forming the q,k,v vectors: - TensorView* qkv_cat = - reshape(linear0, {D, B * S, 3 * E / D}, {D, B, S, 3 * E / D}); - std::vector qkv = chunk(qkv_cat, 3, -1); - for (auto i : c10::irange(3)) { - qkv[i] = reshape(qkv[i], {D, B, S, E / D}, {D, B, S, H / D, E / H}); - qkv[i] = transpose(qkv[i], 2, 3); - } - // SDPA - SdpfaFwdResult sdpa = sdpfa_fwd( - qkv[0], - qkv[1], - qkv[2], - IrBuilder::create(kSdpaProb), - IrBuilder::create(true), - IrBuilder::create(kSdpaScale)); - TensorView* sdpa_output = sdpa.output; - // Linear 1 - TensorView* sdpa_transpose = transpose(sdpa_output, 2, 3); - TensorView* sdpa_reshape = - reshape(sdpa_transpose, {D, B, S, H / D, E / H}, {D, B * S, E / D}); - TensorView* local_matmul1 = matmul(sdpa_reshape, transpose(w1, 1, 2)); - if (sequence_parallel) { - // Remove after https://github.com/NVIDIA/Fuser/issues/2563 - // Reshape to explicitly pull the sharded axis into the logical domain - auto D = w0->axis(0)->extent()->value().as(); - local_matmul1 = reshape(local_matmul1, {D, B * S, E}, {D, D, B * S / D, E}); - } - TensorView* matmul1 = sum(local_matmul1, {0}); // allreduce - std::vector bcast_mask(matmul1->nDims() - 1, true); - bcast_mask[matmul1->nDims() - 2] = false; - TensorView* linear1 = add(matmul1, broadcast(b1, bcast_mask)); - // Dropout - Val* prob = IrBuilder::create(1.0 - kDropoutProb); - Val* scale = IrBuilder::create(1.0 / (1.0 - kDropoutProb)); - TensorView* dropout_result = dropout(linear1, prob, scale).output; - - // Tensor parallel shardings - for (auto tv : {x, b1}) { - tv->setDeviceMesh(mesh); - } - for (auto tv : {w0, b0, w1}) { - tv->setDeviceMesh(mesh); - tv->axis(0)->parallelize(ParallelType::DIDx); - } - // Sequence parallel sharding. - if (sequence_parallel) { - matmul1->setDeviceMesh(mesh); - matmul1->axis(1)->parallelize(ParallelType::DIDx); - } - - return {linear0, sdpa_output, matmul1, linear1, dropout_result}; -} - -// TODO: These linear_backwards helper functions can be merged once -// we do not have logically split rfactor domain. 
-struct LinearBackwardsResult { - TensorView* grad_x; - TensorView* grad_w; - TensorView* grad_b; -}; - -// x format: [i0, i1] dtype -// weight format: [DID(D), i2/D, i1] dtype -// grad format: [DID(D) i0, i2/D] float or dtype -// outputs: grad_x [i0, i1] dtype -// grad_w [DID i2/D, i1] dtype -// grad_b [DID i2/2] dtype -LinearBackwardsResult linear_backwards( - TensorView* x, - TensorView* w, - TensorView* grad) { - DataType dtype = w->dtype(); - TensorView* grad_f = maybeCastOp(DataType::Float, grad); - TensorView* grad_q = maybeCastOp(dtype, grad); - TensorView* grad_x_partials = matmul(grad_q, w); - TensorView* grad_x = sum(grad_x_partials, {0}); // allreduce - TensorView* grad_q_t = transpose(grad_q, 1, 2); - TensorView* grad_w = matmul(grad_q_t, x); - TensorView* grad_b = sum(grad_f, {1}); - grad_b = castOp(dtype, grad_b); - - return {grad_x, grad_w, grad_b}; -} - -// x format: [DID, i0, i1/D] dtype -// weight format: [DID, i2, i1/D] dtype -// grad format: [i0, i2] float -// outputs: grad_x [DID i0, i1/D] dtype -// grad_w [DID, i2, i1/D] dtype -// grad_b [i2] dtype -LinearBackwardsResult sharded_linear_backwards( - TensorView* x, - TensorView* w, - TensorView* grad) { - DataType dtype = w->dtype(); - TensorView* grad_q = castOp(dtype, grad); - TensorView* grad_x = matmul(grad_q, w); - TensorView* grad_t = transpose(grad_q, 0, 1); - TensorView* grad_w = matmul(grad_t, x); - TensorView* grad_b = sum(grad, {0}); - grad_b = castOp(dtype, grad_b); - - return {grad_x, grad_w, grad_b}; -} - -// Forward layer_norm with cached mean_bcast and invstd tensors to avoid -// recomputing Welford. For use in backwards pass. -TensorView* layer_norm_with_cached_statistics( - TensorView* x, - TensorView* mean_bcast, - TensorView* invstd, - const std::vector& norm_shape, - TensorView* weight, - TensorView* bias) { - const int64_t kNumberOfDims = - (int64_t)TensorDomain::noReductions(x->getLogicalDomain()).size(); - const int64_t kOuterNumDims = kNumberOfDims - norm_shape.size(); - std::vector outer_broadcast_mask(kNumberOfDims, false); - for (const auto idx : c10::irange(kOuterNumDims)) { - outer_broadcast_mask[idx] = true; - } - - auto x_sub_mean = sub(x, mean_bcast); - auto y = mul(x_sub_mean, invstd); - - auto weight_bcast = broadcast(weight, outer_broadcast_mask); - y = mul(y, weight_bcast); - auto bias_bcast = broadcast(bias, outer_broadcast_mask); - return add(y, bias_bcast); -} - -// Backwards MLP block. 
-std::vector mlp_backwards( - TensorView* grad, - TensorView* x, - TensorView* mask, - TensorView* w0, - TensorView* w1, - TensorView* linear0, - const DeviceMesh& mesh) { - DataType dtype = w0->dtype(); - - // Activation recomputation: Always recompute gelu - TensorView* gelu = castOp(dtype, tanh_gelu(castOp(DataType::Float, linear0))); - - // Backwards pass - constexpr double kScale = 1.0 / (1.0 - kDropoutProb); - Val* dropout_scale = IrBuilder::create(kScale); - TensorView* dropout_grad = dropout_backward(grad, mask, dropout_scale); - auto linear1_grads = sharded_linear_backwards(gelu, w1, dropout_grad); - TensorView* matmul1_grad_x_ = castOp(DataType::Float, linear1_grads.grad_x); - TensorView* gelu_grad = tanh_gelu_backward(matmul1_grad_x_, linear0); - auto linear0_grads = linear_backwards(x, w0, gelu_grad); - - // Manaul sharding annotations - for (auto tv : - {x, - grad, - mask, - dropout_grad, - linear1_grads.grad_b, - linear0_grads.grad_x}) { - tv->setDeviceMesh(mesh); - } - - for (auto tv : - {w0, - w1, - linear0, - linear1_grads.grad_x, - linear1_grads.grad_w, - gelu_grad, - linear0_grads.grad_w, - linear0_grads.grad_b}) { - tv->setDeviceMesh(mesh); - tv->axis(0)->parallelize(ParallelType::DIDx); - } - - std::vector outputs = { - dropout_grad, - linear1_grads.grad_w, - linear1_grads.grad_b, - gelu_grad, - linear0_grads.grad_w, - linear0_grads.grad_b, - linear0_grads.grad_x}; - return outputs; -} - -std::vector mha_backwards( - TensorView* x, - TensorView* w0, - TensorView* w1, - TensorView* mask, - TensorView* sdpa_output, - TensorView* sdpa_log_sumexp, - TensorView* sdpa_seed, - TensorView* sdpa_offset, - TensorView* grad, - TensorView* linear0, - const DeviceMesh& mesh) { - DataType dtype = w0->dtype(); - const auto D = w0->axis(0)->extent()->value().as(); - // Reform qkv from linear0 output - TensorView* qkv_cat = reshape( - castOp(DataType::Float, linear0), - {D, B * S, 3 * E / D}, - {D, B, S, 3 * E / D}); - std::vector qkv = chunk(qkv_cat, 3, -1); - for (auto i : c10::irange(3)) { - qkv[i] = reshape(qkv[i], {D, B, S, E / D}, {D, B, S, H / D, E / H}); - qkv[i] = transpose(qkv[i], 2, 3); - qkv[i] = castOp(dtype, qkv[i]); - qkv[i]->setDeviceMesh(mesh); - qkv[i]->axis(0)->parallelize(ParallelType::DIDx); - } - - // dropout backwards - constexpr double kScale = 1.0 / (1.0 - kDropoutProb); - auto dropout_scale = IrBuilder::create(kScale); - TensorView* dropout_grad = dropout_backward(grad, mask, dropout_scale); - - // linear1 backwards - TensorView* sdpa_output_reshape = - transpose(sdpa_output, 2, 3); // D, B, S, H/D, E/H - sdpa_output_reshape = - reshape(sdpa_output_reshape, {D, B, S, H / D, E / H}, {D, B * S, E / D}); - auto linear1_grads = - sharded_linear_backwards(sdpa_output_reshape, w1, dropout_grad); - - // SDPA backwards - TensorView* linear1_x_grad = - reshape(linear1_grads.grad_x, {D, B * S, E / D}, {D, B, S, H / D, E / H}); - linear1_x_grad = transpose(linear1_x_grad, 2, 3); // D, B, H/D, S, E/H - // Explicitly shard inputs before SDPA backward node - for (auto tv : {linear1_x_grad, sdpa_output, sdpa_log_sumexp}) { - tv->setDeviceMesh(mesh); - tv->axis(0)->parallelize(ParallelType::DIDx); - } - auto sdpa_grad = sdpfa_bwd( - linear1_x_grad, - qkv[0], - qkv[1], - qkv[2], - sdpa_output, - sdpa_log_sumexp, - /*dropout_p=*/IrBuilder::create(kSdpaProb), - /*is_causal=*/IrBuilder::create(true), - sdpa_seed, - sdpa_offset, - /*scale=*/IrBuilder::create(kSdpaScale)); - - TensorView* q_grad = transpose(sdpa_grad.grad_query, 2, 3); - q_grad = reshape(q_grad, {D, B, S, H / D, E 
/ H}, {D, B * S, E / D}); - TensorView* v_grad = transpose(sdpa_grad.grad_value, 2, 3); - v_grad = reshape(v_grad, {D, B, S, H / D, E / H}, {D, B * S, E / D}); - TensorView* k_grad = transpose(sdpa_grad.grad_key, 2, 3); - k_grad = reshape(k_grad, {D, B, S, H / D, E / H}, {D, B * S, E / D}); - TensorView* kqv_grad = cat({k_grad, q_grad, v_grad}, -1); - auto linear0_grads = linear_backwards(x, w0, kqv_grad); - - for (auto tv : - {x, - mask, - grad, - dropout_grad, - linear1_grads.grad_b, - linear0_grads.grad_x}) { - tv->setDeviceMesh(mesh); - } - for (auto tv : - {w0, - w1, - sdpa_output, - sdpa_log_sumexp, - linear0, - linear1_grads.grad_x, - linear1_grads.grad_w, - linear0_grads.grad_w, - linear0_grads.grad_b, - sdpa_grad.grad_query, - sdpa_grad.grad_key, - sdpa_grad.grad_value}) { - tv->setDeviceMesh(mesh); - tv->axis(0)->parallelize(ParallelType::DIDx); - } - return { - dropout_grad, - linear1_grads.grad_w, - linear1_grads.grad_b, - sdpa_grad.grad_query, - sdpa_grad.grad_key, - sdpa_grad.grad_value, - linear0_grads.grad_w, - linear0_grads.grad_b, - linear0_grads.grad_x}; -} } // namespace TEST_P(DistributedTransformerTest, MLP_Layer) { @@ -695,7 +305,7 @@ TEST_P(DistributedTransformerTest, MLP_Layer) { fusion->addInput(tvw1); fusion->addInput(tvb1); - auto tvsout = mlp(tvx, tvw0, tvb0, tvw1, tvb1, mesh); + auto tvsout = model->mlp(tvx, tvw0, tvb0, tvw1, tvb1, mesh); fusion->addOutput(tvsout.linear0); fusion->addOutput(tvsout.gelu); @@ -768,7 +378,7 @@ TEST_P(DistributedTransformerTest, Sequence_Parallel_MLP_Layer) { // Note only the sequence (S) dimension that is sharded // but to avoid DID parallelizations of inner logical axes // B*S is sharded. - auto tvsout = mlp(x, w0, b0, w1, b1, mesh, true); + auto tvsout = model->mlp(x, w0, b0, w1, b1, mesh, true); fusion->addInput(x); fusion->addInput(w0); @@ -842,7 +452,7 @@ TEST_P(DistributedTransformerTest, MultiheadAttention) { fusion->addInput(tvw1); fusion->addInput(tvb1); - auto tv_outs = mha(tvx, tvw0, tvb0, tvw1, tvb1, mesh); + auto tv_outs = model->mha(tvx, tvw0, tvb0, tvw1, tvb1, mesh); fusion->addOutput(tv_outs.linear0); fusion->addOutput(tv_outs.sdpa); @@ -907,7 +517,7 @@ TEST_P(DistributedTransformerTest, MultiheadAttention_SP) { fusion->addInput(tvw1); fusion->addInput(tvb1); - auto tv_outs = mha(tvx, tvw0, tvb0, tvw1, tvb1, mesh, true); + auto tv_outs = model->mha(tvx, tvw0, tvb0, tvw1, tvb1, mesh, true); fusion->addOutput(tv_outs.linear0); fusion->addOutput(tv_outs.sdpa); @@ -974,7 +584,7 @@ TEST_P(DistributedTransformerTest, MLP_Backward) { fusion->addInput(linear0); std::vector tv_outs = - mlp_backwards(grad, x, mask, w0, w1, linear0, mesh); + model->mlp_backwards(grad, x, mask, w0, w1, linear0, mesh); for (TensorView* tv : tv_outs) { fusion->addOutput(tv); @@ -1056,7 +666,7 @@ TEST_P(DistributedTransformerTest, MHA_Backward) { fusion->addInput(tvsdpa_offset); fusion->addInput(linear0); - auto tvouts = mha_backwards( + auto tvouts = model->mha_backwards( tvx, tvw0, tvw1, @@ -1134,74 +744,10 @@ TEST_P(DistributedTransformerTest, Forward_SP) { } auto dtype = GetParam(); at::ScalarType at_dtype = data_type_to_aten(dtype); - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); const auto mesh = DeviceMesh::createForNumDevices(D); - - TensorView* x = makeContigConcreteTensor({D, B * S / D, E}, dtype); - TensorView* ln0_w = makeContigTensor(1); - TensorView* ln0_b = makeContigTensor(1); - TensorView* mha_w0 = makeContigConcreteTensor({D, 3 * E / D, E}, dtype); - TensorView* mha_b0 = makeContigConcreteTensor({D, 3 * E / 
D}, dtype); - TensorView* mha_w1 = makeContigConcreteTensor({D, E, E / D}, dtype); - TensorView* mha_b1 = makeContigConcreteTensor({E}, dtype); - TensorView* ln1_w = makeContigTensor(1); - TensorView* ln1_b = makeContigTensor(1); - TensorView* mlp_w0 = makeContigConcreteTensor({D, 4 * E / D, E}, dtype); - TensorView* mlp_b0 = makeContigConcreteTensor({D, 4 * E / D}, dtype); - TensorView* mlp_w1 = makeContigConcreteTensor({D, E, 4 * E / D}, dtype); - TensorView* mlp_b1 = makeContigConcreteTensor({E}, dtype); - - fusion->addInput(x); - fusion->addInput(ln0_w); - fusion->addInput(ln0_b); - fusion->addInput(mha_w0); - fusion->addInput(mha_b0); - fusion->addInput(mha_w1); - fusion->addInput(mha_b1); - fusion->addInput(ln1_w); - fusion->addInput(ln1_b); - fusion->addInput(mlp_w0); - fusion->addInput(mlp_b0); - fusion->addInput(mlp_w1); - fusion->addInput(mlp_b1); - constexpr float kEps = 1e-5; - auto eps = IrBuilder::create(kEps); std::vector norm_shape{E}; - auto ln_input = castOp(DataType::Float, x); - auto ln0 = layer_norm(ln_input, norm_shape, ln0_w, ln0_b, eps); - auto mha_in = castOp(dtype, ln0.output); - auto mha_tvs = mha(mha_in, mha_w0, mha_b0, mha_w1, mha_b1, mesh, true); - auto resid0 = add(ln_input, mha_tvs.output); - auto ln1 = layer_norm(resid0, norm_shape, ln1_w, ln1_b, eps); - auto mlp_in = castOp(dtype, ln1.output); - auto mlp_tvs = mlp(mlp_in, mlp_w0, mlp_b0, mlp_w1, mlp_b1, mesh, true); - auto resid1 = add(resid0, mlp_tvs.output); - resid1 = castOp(dtype, resid1); - - fusion->addOutput(ln0.output); - fusion->addOutput(mha_tvs.output); - fusion->addOutput(ln1.output); - fusion->addOutput(mlp_tvs.output); - fusion->addOutput(resid1); - - x->setDeviceMesh(mesh); - x->axis(0)->parallelize(ParallelType::DIDx); - // Propagate SP shardings from x through layernorms, dropouts, residual adds. - // Even though mha_in is part of the boundary set, residuals allow the - // shardings to propagate up the graph so we must cut off the propagation at - // the outputs of reduce scatters (mha and mlp matmul1) - shardBetween({x}, {mha_in, mlp_in, mha_tvs.matmul1, mlp_tvs.matmul1}, x); - // Propagate TP sharding for MLP and MHA from sharded weights. We do not need - // to shard from mha_b0 or mlp_b0 because they are only consumed by their - // respective linear0 expression which is sharded from *_w0. 
- shardBetween({mha_w0}, {mha_tvs.matmul1}, mha_w0); - shardBetween({mha_w1}, {mha_tvs.matmul1}, mha_w1); - shardBetween({mlp_w0}, {mlp_tvs.matmul1}, mlp_w0); - shardBetween({mlp_w1}, {mlp_tvs.matmul1}, mlp_w1); - const auto options = at::TensorOptions().dtype(at_dtype).device(communicator_->device()); auto x_ = at::randn({B * S, E}, options); @@ -1256,9 +802,9 @@ TEST_P(DistributedTransformerTest, Forward_SP) { shardTensor(mlp_out_, 0, mesh).unsqueeze(0), shardTensor(at_out, 0, mesh).unsqueeze(0)}; - FusionExecutorCache fec(std::move(fusion)); + auto fec = model->forward(dtype, true); at::manual_seed(getATenRandomSeed()); - auto outputs = fec.runFusionWithInputs(inputs); + auto outputs = fec->runFusionWithInputs(inputs); validate(expected_outputs, outputs, {1e-4, 0.02, 0.04, 0.04, 0.04}); } @@ -1269,67 +815,10 @@ TEST_P(DistributedTransformerTest, Forward) { } auto dtype = GetParam(); at::ScalarType at_dtype = data_type_to_aten(dtype); - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); const auto mesh = DeviceMesh::createForNumDevices(D); - - TensorView* x = makeContigConcreteTensor({B * S, E}, dtype); - TensorView* ln0_w = makeContigTensor(1); - TensorView* ln0_b = makeContigTensor(1); - TensorView* mha_w0 = makeContigConcreteTensor({D, 3 * E / D, E}, dtype); - TensorView* mha_b0 = makeContigConcreteTensor({D, 3 * E / D}, dtype); - TensorView* mha_w1 = makeContigConcreteTensor({D, E, E / D}, dtype); - TensorView* mha_b1 = makeContigConcreteTensor({E}, dtype); - TensorView* ln1_w = makeContigTensor(1); - TensorView* ln1_b = makeContigTensor(1); - TensorView* mlp_w0 = makeContigTensor(3, dtype); - TensorView* mlp_b0 = makeContigTensor(2, dtype); - TensorView* mlp_w1 = makeContigTensor(3, dtype); - TensorView* mlp_b1 = makeContigTensor(1, dtype); - - fusion->addInput(x); - fusion->addInput(ln0_w); - fusion->addInput(ln0_b); - fusion->addInput(mha_w0); - fusion->addInput(mha_b0); - fusion->addInput(mha_w1); - fusion->addInput(mha_b1); - fusion->addInput(ln1_w); - fusion->addInput(ln1_b); - fusion->addInput(mlp_w0); - fusion->addInput(mlp_b0); - fusion->addInput(mlp_w1); - fusion->addInput(mlp_b1); - constexpr float kEps = 1e-5; - auto eps = IrBuilder::create(kEps); std::vector norm_shape{E}; - auto ln_input = castOp(DataType::Float, x); - auto ln0 = layer_norm(ln_input, norm_shape, ln0_w, ln0_b, eps); - auto mha_in = castOp(dtype, ln0.output); - auto mha_out = mha(mha_in, mha_w0, mha_b0, mha_w1, mha_b1, mesh).output; - auto resid0 = add(ln_input, mha_out); - auto ln1 = layer_norm(resid0, norm_shape, ln1_w, ln1_b, eps); - auto mlp_in = castOp(dtype, ln1.output); - auto mlp_out = mlp(mlp_in, mlp_w0, mlp_b0, mlp_w1, mlp_b1, mesh).output; - auto resid1 = add(resid0, mlp_out); - resid1 = castOp(dtype, resid1); - - fusion->addOutput(ln0.output); - fusion->addOutput(mha_out); - fusion->addOutput(ln1.output); - fusion->addOutput(mlp_out); - fusion->addOutput(resid1); - - for (auto tv : {x, ln0.output, ln1.output, resid1}) { - tv->setDeviceMesh(mesh); - } - - shardBetween({mha_in->definition()}, {mha_out->definition()}, mha_w0); - shardBetween({mlp_in->definition()}, {mlp_out->definition()}, mlp_w0); - shardBetween({x}, {mha_in}, x); - const auto options = at::TensorOptions().dtype(at_dtype).device(communicator_->device()); auto x_ = at::randn({B * S, E}, options); @@ -1380,9 +869,9 @@ TEST_P(DistributedTransformerTest, Forward) { std::vector expected_outputs = { ln0_out_, mha_out_, ln1_out_, mlp_out_, at_out}; - FusionExecutorCache executor_cache(std::move(fusion)); + auto 
executor_cache = model->forward(dtype); at::manual_seed(getATenRandomSeed()); - auto outputs = executor_cache.runFusionWithInputs(inputs); + auto outputs = executor_cache->runFusionWithInputs(inputs); validate(expected_outputs, outputs, {1e-4, 0.02, 0.04, 0.04, 0.04}); } @@ -1399,172 +888,6 @@ TEST_P(DistributedTransformerTest, Backward) { constexpr float kEps = 1e-5; std::vector norm_shape{E}; - TensorView* x = makeContigConcreteTensor({B * S, E}, dtype); - TensorView* grad = makeContigTensor(2, dtype); - TensorView* mha_w0 = makeContigConcreteTensor({D, 3 * E / D, E}, dtype); - TensorView* mha_w1 = makeContigConcreteTensor({D, E, E / D}, dtype); - TensorView* mlp_w0 = makeContigTensor(3, dtype); - TensorView* mlp_w1 = makeContigTensor(3, dtype); - TensorView* mha_mask = makeContigTensor(2, DataType::Bool); - TensorView* mlp_mask = makeContigTensor(2, DataType::Bool); - TensorView* mha_sdpa_out = makeConcreteTensor({D, B, H / D, S, E / H}, dtype); - TensorView* mha_sdpa_log_sumexp = - makeContigConcreteTensor({D, B, H / D, S}, DataType::Float); - TensorView* mha_sdpa_seed = makeSymbolicTensor({}, DataType::Int); - TensorView* mha_sdpa_offset = makeSymbolicTensor({}, DataType::Int); - TensorView* ln1_w = makeContigTensor(1); - TensorView* ln1_b = makeContigTensor(1); - TensorView* ln1_mean = makeConcreteTensor({B * S, 1}); - TensorView* ln1_rstd = makeConcreteTensor({B * S, 1}); - TensorView* ln0_w = makeContigTensor(1); - TensorView* ln0_b = makeContigTensor(1); - TensorView* ln0_mean = makeConcreteTensor({B * S, 1}); - TensorView* ln0_rstd = makeConcreteTensor({B * S, 1}); - TensorView* mha_linear0 = makeContigTensor(3, dtype); - TensorView* mha_linear1 = makeContigTensor(2); - TensorView* mlp_linear0 = makeContigTensor(3, dtype); - - fusion->addInput(x); - fusion->addInput(grad); - fusion->addInput(mha_w0); - fusion->addInput(mha_w1); - fusion->addInput(mlp_w0); - fusion->addInput(mlp_w1); - fusion->addInput(mlp_mask); - fusion->addInput(mha_mask); - fusion->addInput(mha_sdpa_out); - fusion->addInput(mha_sdpa_log_sumexp); - fusion->addInput(mha_sdpa_seed); - fusion->addInput(mha_sdpa_offset); - fusion->addInput(ln1_w); - fusion->addInput(ln1_b); - fusion->addInput(ln1_mean); - fusion->addInput(ln1_rstd); - fusion->addInput(ln0_w); - fusion->addInput(ln0_b); - fusion->addInput(ln0_mean); - fusion->addInput(ln0_rstd); - fusion->addInput(mha_linear0); - fusion->addInput(mha_linear1); - fusion->addInput(mlp_linear0); - - // Activation recomputation: mlp gelu, dropouts, and - // partially recompute layer norms using cached statistics. 
- auto ln0_in = castOp(DataType::Float, x); - auto ln0 = layer_norm_with_cached_statistics( - ln0_in, ln0_mean, ln0_rstd, norm_shape, ln0_w, ln0_b); - auto mha_in = castOp(dtype, ln0); - - Val* dropout_scale = IrBuilder::create(1.0 / (1.0 - kDropoutProb)); - // Use input mha_mask to implement dropout - auto mha_out = mul(mha_linear1, mha_mask); - mha_out = mul(mha_out, dropout_scale); - auto resid0 = add(ln0_in, mha_out); - auto ln1 = layer_norm_with_cached_statistics( - resid0, ln1_mean, ln1_rstd, norm_shape, ln1_w, ln1_b); - auto mlp_in = castOp(dtype, ln1); - - // Backwards - auto grad_float = castOp(DataType::Float, grad); - auto mlp_grads = mlp_backwards( - grad_float, mlp_in, mlp_mask, mlp_w0, mlp_w1, mlp_linear0, mesh); - auto ln1_grads = layer_norm_backward( - castOp(DataType::Float, mlp_grads[6]), - resid0, - norm_shape, - ln1_mean, - ln1_rstd, - ln1_w, - ln1_b, - {true, true, true}); - auto resid1_grad = add(ln1_grads.grad_input, grad_float); - auto mha_grads = mha_backwards( - mha_in, - mha_w0, - mha_w1, - mha_mask, - mha_sdpa_out, - mha_sdpa_log_sumexp, - mha_sdpa_seed, - mha_sdpa_offset, - resid1_grad, - mha_linear0, - mesh); - auto ln0_grads = layer_norm_backward( - castOp(DataType::Float, mha_grads[8]), - ln0_in, - norm_shape, - ln0_mean, - ln0_rstd, - ln0_w, - ln0_b, - {true, true, true}); - auto dx = add(ln0_grads.grad_input, resid1_grad); - dx = castOp(dtype, dx); - - fusion->addOutput(mlp_grads[1]); // mlp linear1 weight grad - fusion->addOutput(mlp_grads[2]); // mlp linear1 bias grad - fusion->addOutput(mlp_grads[4]); // mlp linear0 weight grad - fusion->addOutput(mlp_grads[5]); // mlp linear0 bias grad - fusion->addOutput(ln1_grads.grad_weight); - fusion->addOutput(ln1_grads.grad_bias); - fusion->addOutput(mha_grads[1]); // mha linear1 weight grad - fusion->addOutput(mha_grads[2]); // mha linear1 bias grad - fusion->addOutput(mha_grads[6]); // mha linear0 weight grad - fusion->addOutput(mha_grads[7]); // mha linear0 bias grad - fusion->addOutput(ln0_grads.grad_weight); - fusion->addOutput(ln0_grads.grad_bias); - fusion->addOutput(dx); // transformer grad input - - // Sharding annotations for input and output TVs not sharded - // by mlp_backward or mha_backward - for (auto* tv : - {ln0_w, - ln0_b, - ln0_mean, - ln0_rstd, - ln1_w, - ln1_b, - ln1_mean, - ln1_rstd, - ln1_grads.grad_weight, - ln1_grads.grad_bias, - ln0_grads.grad_weight, - ln0_grads.grad_bias, - ln0_grads.grad_input}) { - tv->setDeviceMesh(mesh); - } - - // Sharded inputs to outputs - shardBetween( - {mha_w0, mha_w1, mha_sdpa_out}, - {mha_grads[1], mha_grads[6], mha_grads[7]}, - mha_w0); - shardBetween( - {mlp_w0, mlp_w1}, {mlp_grads[1], mlp_grads[4], mlp_grads[5]}, mlp_w0); - - // Unsharded inputs to outputs - shardBetween( - {x, - grad, - mha_mask, - mlp_mask, - mha_linear1, - ln0_mean, - ln0_w, - ln0_b, - ln1_mean, - ln1_w, - ln1_b}, - {mlp_grads[2], - ln1_grads.grad_weight, - ln1_grads.grad_bias, - mha_grads[2], - ln0_grads.grad_weight, - ln0_grads.grad_bias, - dx}, - x); - const auto options = at::TensorOptions().dtype(at_dtype).device(communicator_->device()); auto x_ = at::randn({B * S, E}, options); @@ -1667,9 +990,9 @@ TEST_P(DistributedTransformerTest, Backward) { shardTensor(mlp_out_[0], 1, mesh).unsqueeze(0) // mlp linear1 }; - FusionExecutorCache executor_cache(std::move(fusion)); + auto executor_cache = model->backward(dtype); at::manual_seed(getATenRandomSeed()); - auto outputs = executor_cache.runFusionWithInputs(inputs); + auto outputs = executor_cache->runFusionWithInputs(inputs); 
validate( expected_outputs, outputs,