From 5d60dd551ea0d6fc82b59ad2ea18087ff5bad806 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Wed, 27 Nov 2024 15:16:49 -0800
Subject: [PATCH] Try to reuse getShardedAxis

---
 csrc/multidevice/lower_communication.cpp      |   2 +-
 csrc/multidevice/utils.cpp                    | 134 +++++++++---------
 csrc/multidevice/utils.h                      |   6 +-
 tests/cpp/multidevice.cpp                     |  11 +-
 .../test_multidevice_lower_communication.cpp  |  10 +-
 tests/cpp/test_multidevice_transformer.cpp    | 124 ++++++++--------
 6 files changed, 140 insertions(+), 147 deletions(-)

diff --git a/csrc/multidevice/lower_communication.cpp b/csrc/multidevice/lower_communication.cpp
index c8068b5a113..a8fca521c4a 100644
--- a/csrc/multidevice/lower_communication.cpp
+++ b/csrc/multidevice/lower_communication.cpp
@@ -196,7 +196,7 @@ void lowerToReduceScatter(
     std::vector<Communication*>& comms) {
   const DeviceMesh& mesh = input_tv->getDeviceMesh();
   auto reduction_axis = output_tv->getReductionAxis().value();
-  auto scattered_axis = getShardedAxis(output_tv);
+  auto scattered_axis = getShardedAxis(output_tv, ParallelType::DIDx);
   // The output tensor is sharded on scattered_axis and needs to be mapped
   // back onto the input. The input has a reduced axis, so the scattered axis
   // is adjusted to account for this. Ex: [DIDx(i0), i1] -> [r0, DIDx(i1)] The
diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp
index 334f8271aa2..0209a6b5d5f 100644
--- a/csrc/multidevice/utils.cpp
+++ b/csrc/multidevice/utils.cpp
@@ -121,48 +121,77 @@ bool isSharded(const TensorView* tv) {
   return is_sharded;
 }

-std::vector<int64_t> unshardedSizes(
-    const TensorView* tv,
-    c10::IntArrayRef sizes) {
-  std::vector<int64_t> unsharded_sizes = sizes.vec();
-
-  for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
-    const ParallelType parallel_type = alloc_id->getParallelType();
+namespace {
+// Collect device-parallel IterDomains in `domain` and return them as a
+// ParallelType-to-IterDomain map.
+std::unordered_map<ParallelType, IterDomain*> mapDeviceParallelTypeToId(
+    const std::vector<IterDomain*>& domain) {
+  std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id;
+  parallel_type_to_id.reserve(kParallelTypeDIDs.size());
+  for (IterDomain* id : domain) {
+    const ParallelType parallel_type = id->getParallelType();
     if (!isParallelTypeDeviceDim(parallel_type)) {
       continue;
     }
-    const auto inputs = IterVisitor::getInputsTo(
-        {alloc_id},
-        {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()});
     NVF_ERROR(
-        !inputs.empty(),
-        "IterVisitor::getInputsTo shouldn't return empty unless `of` is empty.");
-    NVF_ERROR(
-        inputs.size() == 1,
-        "Failed to find the single logical input to ",
-        alloc_id,
-        ". This is likely because there's a Merge expression from logical to allocation, which isn't supported. Inputs are: ",
-        toDelimitedString(inputs));
-
-    const auto iter = std::find(
-        tv->getLogicalDomain().begin(),
-        tv->getLogicalDomain().end(),
-        inputs[0]);
-    NVF_ERROR(
-        iter != tv->getLogicalDomain().end(),
-        "The found input IterDomain isn't logical. This is likely because logical doesn't dominate allocation: ",
-        inputs[0]);
-
-    // Count the number of non-reduction IterDomains before `iter`. Reduction
-    // IterDomains are not materialized in the at::Tensor's shape.
-    const auto index = std::count_if(
-        tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool {
-          return !id->isReduction();
-        });
-    unsharded_sizes.at(index) *= tv->getDeviceMesh().size(parallel_type);
+        parallel_type_to_id.try_emplace(parallel_type, id).second,
+        "Found multiple loop IterDomains with the same parallel type (",
+        parallel_type,
+        "): ",
+        toDelimitedString(domain));
+  }
+  return parallel_type_to_id;
+}
+} // namespace
+
+int64_t getShardedAxis(const TensorView* tv, const ParallelType parallel_type) {
+  std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id =
+      mapDeviceParallelTypeToId(tv->getMaybeAllocationDomain());
+  IterDomain* alloc_id = getOrDefault(parallel_type_to_id, parallel_type);
+  if (alloc_id == nullptr) {
+    return -1;
+  }
+  const auto inputs = IterVisitor::getInputsTo(
+      {alloc_id},
+      {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()});
+  NVF_ERROR(
+      !inputs.empty(),
+      "IterVisitor::getInputsTo shouldn't return empty unless `of` is empty.");
+  NVF_ERROR(
+      inputs.size() == 1,
+      "Failed to find the single logical input to ",
+      alloc_id,
+      ". This is likely because there's a Merge expression from logical to allocation, which isn't supported. Inputs are: ",
+      toDelimitedString(inputs));
+
+  const auto iter = std::find(
+      tv->getLogicalDomain().begin(), tv->getLogicalDomain().end(), inputs[0]);
+  NVF_ERROR(
+      iter != tv->getLogicalDomain().end(),
+      "The found input IterDomain isn't logical. This is likely because logical doesn't dominate allocation: ",
+      inputs[0]);
+
+  // Count the number of non-reduction IterDomains before `iter`. Reduction
+  // IterDomains are not materialized in the at::Tensor's shape.
+  return std::count_if(
+      tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool {
+        return !id->isReduction();
+      });
+}
+
+std::vector<int64_t> unshardedSizes(
+    const TensorView* tv,
+    c10::IntArrayRef sizes) {
+  std::vector<int64_t> unsharded_sizes = sizes.vec();
+  for (ParallelType parallel_type : kParallelTypeDIDs) {
+    const int64_t sharded_axis = getShardedAxis(tv, parallel_type);
+    if (sharded_axis == -1) {
+      continue;
+    }
+    unsharded_sizes.at(sharded_axis) *= tv->getDeviceMesh().size(parallel_type);
+  }
   return unsharded_sizes;
 }

@@ -174,27 +203,6 @@ int64_t numDeviceDims(const TensorView* tv) {
 }

 namespace {
-// Collect device-parallel IterDomains in `loop_domain` and return them as a
-// ParallelType-to-IterDomain map.
-std::unordered_map<ParallelType, IterDomain*> mapParallelTypeToId(
-    const std::vector<IterDomain*>& loop_domain) {
-  std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id;
-  parallel_type_to_id.reserve(kParallelTypeDIDs.size());
-  for (IterDomain* loop_id : loop_domain) {
-    const ParallelType parallel_type = loop_id->getParallelType();
-    if (!isParallelTypeDeviceDim(parallel_type)) {
-      continue;
-    }
-
-    NVF_ERROR(
-        parallel_type_to_id.try_emplace(parallel_type, loop_id).second,
-        "Found multiple loop IterDomains with the same parallel type (",
-        parallel_type,
-        "): ",
-        toDelimitedString(loop_domain));
-  }
-  return parallel_type_to_id;
-}

 std::vector<Val*> getInputsInTargetDomain(
     IterDomain* loop_id,
@@ -294,9 +302,9 @@ bool haveDifferentShardings(
   // 3. Check if the two loop IterDomains are almost-exactly mapped in the
   // IdModel.
   std::unordered_map<ParallelType, IterDomain*> p_parallel_type_to_id =
-      mapParallelTypeToId(producer->getLoopDomain());
+      mapDeviceParallelTypeToId(producer->getLoopDomain());
   std::unordered_map<ParallelType, IterDomain*> c_parallel_type_to_id =
-      mapParallelTypeToId(consumer->getLoopDomain());
+      mapDeviceParallelTypeToId(consumer->getLoopDomain());

   for (const auto parallel_type : kParallelTypeDIDs) {
     IterDomain* p_loop_id = getOrDefault(p_parallel_type_to_id, parallel_type);
@@ -502,16 +510,6 @@ std::set<DeviceIdxType> involvedDevices(Expr* expr) {
   return ret;
 }

-int64_t getShardedAxis(TensorView* tv) {
-  auto ids = TensorDomain::noReductions(tv->getMaybeAllocationDomain());
-  for (const auto i : c10::irange(ids.size())) {
-    if (ids[i]->getParallelType() == ParallelType::DIDx) {
-      return static_cast<int64_t>(i);
-    }
-  }
-  return -1;
-}
-
 void reorderDIDToFront(TensorView* tv) {
   // new position to old position
   std::unordered_map<int64_t, int64_t> order_map;
diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h
index 50f0223a034..439ecfd28c2 100644
--- a/csrc/multidevice/utils.h
+++ b/csrc/multidevice/utils.h
@@ -123,9 +123,9 @@ int64_t requestedNumberOfDevices(Fusion*);
 void unshard(Fusion*);
 void unshard(TensorView*);

-// Returns the index of the a sharded axis if none return -1.
-// TODO: Assumes no merges/splits on sharded axis.
-int64_t getShardedAxis(TensorView*);
+// Returns the index of the sharded logical axis corresponding to
+// `parallel_type`. If `tv` isn't sharded on the parallel type, returns -1.
+int64_t getShardedAxis(const TensorView* tv, ParallelType parallel_type);

 // Reorders a TensorView so that the DID-parallelized axes are in front.
 void reorderDIDToFront(TensorView*);
diff --git a/tests/cpp/multidevice.cpp b/tests/cpp/multidevice.cpp
index bab5cdccc5e..92c13c1c460 100644
--- a/tests/cpp/multidevice.cpp
+++ b/tests/cpp/multidevice.cpp
@@ -128,7 +128,8 @@ at::Tensor MultiDeviceTest::shardTensor(at::Tensor tensor, TensorView* tv) {
     return tensor;
   }
   NVF_ERROR(tv->hasDeviceMesh(), "`tv` has no DeviceMesh: ", tv);
-  return shardTensor(tensor, getShardedAxis(tv), tv->getDeviceMesh());
+  return shardTensor(
+      tensor, getShardedAxis(tv, ParallelType::DIDx), tv->getDeviceMesh());
 }

 at::Tensor MultiDeviceTest::shardTensor(
@@ -144,13 +145,7 @@ at::Tensor MultiDeviceTest::shardTensor(
   auto stride = extent / nslices;
   // TODO: returning slice 0 temporarily when device is not in the mesh.
   i = (i < 0) ? 0 : i;
-  auto slice = tensor.slice(axis, i * stride, (i + 1) * stride).contiguous();
-  // Temporary until https://github.com/NVIDIA/Fuser/issues/2563. Adds DIDx
-  // axis in front representing the sharded extent of the tensor.
-  if (stride > 1) {
-    slice = slice.unsqueeze(0);
-  }
-  return slice;
+  return tensor.slice(axis, i * stride, (i + 1) * stride).contiguous();
 }

 } // namespace nvfuser
diff --git a/tests/cpp/test_multidevice_lower_communication.cpp b/tests/cpp/test_multidevice_lower_communication.cpp
index 73d5a124357..e316ec0f219 100644
--- a/tests/cpp/test_multidevice_lower_communication.cpp
+++ b/tests/cpp/test_multidevice_lower_communication.cpp
@@ -202,7 +202,7 @@ TEST_F(LowerCollectiveTest, Allgather) {
   EXPECT_TRUE(at::equal(out_tensor, unsharded_tensor));
 }

-TEST_F(LowerCollectiveTest, Allgather_SplitLoop) {
+TEST_F(LowerCollectiveTest, Allgather_LoopSplit) {
   auto fusion = std::make_unique<Fusion>();
   FusionGuard fg(fusion.get());

@@ -224,12 +224,8 @@
   at::Tensor unsharded_tensor =
       at::randn({num_devices * kTensorSize}, at::kFloat);
-  at::Tensor in_tensor = unsharded_tensor
-                             .slice(
-                                 0,
-                                 communicator_->deviceId() * kTensorSize,
-                                 (communicator_->deviceId() + 1) * kTensorSize)
-                             .to(communicator_->device());
+  at::Tensor in_tensor =
+      shardTensor(unsharded_tensor, in).to(communicator_->device());

   FusionExecutorCache fec(std::move(fusion));
   at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0];
diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp
index 2ef33dcdf8f..0f39ae6f6e5 100644
--- a/tests/cpp/test_multidevice_transformer.cpp
+++ b/tests/cpp/test_multidevice_transformer.cpp
@@ -720,14 +720,14 @@ TEST_P(DistributedTransformerTest, MLP_Layer) {
   std::vector<c10::IValue> inputs = {
       x,
-      shardTensor(w0, 0, mesh),
-      shardTensor(b0, 0, mesh),
-      shardTensor(w1, 1, mesh),
+      shardTensor(w0, 0, mesh).unsqueeze(0),
+      shardTensor(b0, 0, mesh).unsqueeze(0),
+      shardTensor(w1, 1, mesh).unsqueeze(0),
       b1};

   std::vector<at::Tensor> expected_outputs = {
-      shardTensor(reference_outs[0], 1, mesh),
-      shardTensor(reference_outs[1], 1, mesh),
+      shardTensor(reference_outs[0], 1, mesh).unsqueeze(0),
+      shardTensor(reference_outs[1], 1, mesh).unsqueeze(0),
       reference_outs[2],
       reference_outs[3]};

@@ -801,17 +801,17 @@ TEST_P(DistributedTransformerTest, Sequence_Parallel_MLP_Layer) {
   auto mask_ = reference_outs[4];

   std::vector<c10::IValue> inputs = {
-      shardTensor(x_, 0, mesh),
-      shardTensor(w0_, 0, mesh),
-      shardTensor(b0_, 0, mesh),
-      shardTensor(w1_, 1, mesh),
+      shardTensor(x_, 0, mesh).unsqueeze(0),
+      shardTensor(w0_, 0, mesh).unsqueeze(0),
+      shardTensor(b0_, 0, mesh).unsqueeze(0),
+      shardTensor(w1_, 1, mesh).unsqueeze(0),
       b1_};

   std::vector<at::Tensor> expected_outputs = {
-      shardTensor(reference_outs[0], 1, mesh),
-      shardTensor(reference_outs[1], 1, mesh),
-      shardTensor(reference_outs[2], 0, mesh),
-      shardTensor(reference_outs[3], 0, mesh)};
+      shardTensor(reference_outs[0], 1, mesh).unsqueeze(0),
+      shardTensor(reference_outs[1], 1, mesh).unsqueeze(0),
+      shardTensor(reference_outs[2], 0, mesh).unsqueeze(0),
+      shardTensor(reference_outs[3], 0, mesh).unsqueeze(0)};

   FusionExecutorCache executor_cache(std::move(fusion));
   at::manual_seed(getATenRandomSeed());
@@ -866,12 +866,12 @@ TEST_P(DistributedTransformerTest, MultiheadAttention) {
       x,
       shardTensor(w0.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}),
       shardTensor(b0.view({3, E}), 1, mesh).view({1, 3 * E / D}),
-      shardTensor(w1, 1, mesh),
+      shardTensor(w1, 1, mesh).unsqueeze(0),
       b1};
   std::vector<at::Tensor> expected_outputs = {
       shardTensor(reference_outs[0].view({B * S, 3, E}), 2, mesh)
          .view({1, B * S, 3 * E / D}),
-      shardTensor(reference_outs[1], 1, mesh),
+      shardTensor(reference_outs[1], 1, mesh).unsqueeze(0),
       reference_outs[2],
      reference_outs[3]};

@@ -929,17 +929,17 @@ TEST_P(DistributedTransformerTest, MultiheadAttention_SP) {
   at::manual_seed(getATenRandomSeed());
   auto reference_outs = reference_mha(x, w0, b0, w1, b1);
   std::vector<c10::IValue> inputs = {
-      shardTensor(x, 0, mesh),
+      shardTensor(x, 0, mesh).unsqueeze(0),
       shardTensor(w0.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}),
       shardTensor(b0.view({3, E}), 1, mesh).view({1, 3 * E / D}),
-      shardTensor(w1, 1, mesh),
+      shardTensor(w1, 1, mesh).unsqueeze(0),
       b1};
   std::vector<at::Tensor> expected_outputs = {
       shardTensor(reference_outs[0].view({B * S, 3, E}), 2, mesh)
          .view({1, B * S, 3 * E / D}),
-      shardTensor(reference_outs[1], 1, mesh),
-      shardTensor(reference_outs[2], 0, mesh),
-      shardTensor(reference_outs[3], 0, mesh)};
+      shardTensor(reference_outs[1], 1, mesh).unsqueeze(0),
+      shardTensor(reference_outs[2], 0, mesh).unsqueeze(0),
+      shardTensor(reference_outs[3], 0, mesh).unsqueeze(0)};

   FusionExecutorCache fec(std::move(fusion));
   at::manual_seed(getATenRandomSeed());
@@ -1003,16 +1003,16 @@ TEST_P(DistributedTransformerTest, MLP_Backward) {
       grad_,
       x_,
       mask_,
-      shardTensor(mlp_w0_, 0, mesh),
-      shardTensor(mlp_w1_, 1, mesh),
-      shardTensor(linear0_, 1, mesh)};
+      shardTensor(mlp_w0_, 0, mesh).unsqueeze(0),
+      shardTensor(mlp_w1_, 1, mesh).unsqueeze(0),
+      shardTensor(linear0_, 1, mesh).unsqueeze(0)};
   std::vector<at::Tensor> expected_outputs = {
       outs[0], // dropout grad
-      shardTensor(outs[1], 1, mesh), // linear1 weight grad
+      shardTensor(outs[1], 1, mesh).unsqueeze(0), // linear1 weight grad
       outs[2], // linear1 bias grad
-      shardTensor(outs[3], 1, mesh), // gelu grad
-      shardTensor(outs[4], 0, mesh), // linear0 weight grad
-      shardTensor(outs[5], 0, mesh), // linear0 bias grad
+      shardTensor(outs[3], 1, mesh).unsqueeze(0), // gelu grad
+      shardTensor(outs[4], 0, mesh).unsqueeze(0), // linear0 weight grad
+      shardTensor(outs[5], 0, mesh).unsqueeze(0), // linear0 bias grad
       outs[6]}; // linear0 grad x

   FusionExecutorCache executor_cache(std::move(fusion));
@@ -1094,22 +1094,23 @@ TEST_P(DistributedTransformerTest, MHA_Backward) {
   std::vector<c10::IValue> inputs = {
       x,
       shardTensor(w0.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}),
-      shardTensor(w1, 1, mesh),
+      shardTensor(w1, 1, mesh).unsqueeze(0),
       grad,
       mask,
-      shardTensor(reference_outs[0], 1, mesh), // sdpa.output
-      shardTensor(reference_outs[1], 1, mesh), // sdpa.log_sumexp
+      shardTensor(reference_outs[0], 1, mesh).unsqueeze(0), // sdpa.output
+      shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), // sdpa.log_sumexp
       reference_outs[2], // sdpa.seed
       reference_outs[3], // sdpa.offset
-      shardTensor(reference_outs[13], 1, mesh) // linear0
+      shardTensor(reference_outs[13], 1, mesh).unsqueeze(0) // linear0
   };
   std::vector<at::Tensor> expected_outputs = {
       reference_outs[4], // dropout grad
-      shardTensor(reference_outs[5], 1, mesh), // linear1 weight grad
+      shardTensor(reference_outs[5], 1, mesh)
+          .unsqueeze(0), // linear1 weight grad
       reference_outs[6], // linear1 bias grad
-      shardTensor(reference_outs[7], 1, mesh), // q grad
-      shardTensor(reference_outs[8], 1, mesh), // k grad
-      shardTensor(reference_outs[9], 1, mesh), // v grad
+      shardTensor(reference_outs[7], 1, mesh).unsqueeze(0), // q grad
+      shardTensor(reference_outs[8], 1, mesh).unsqueeze(0), // k grad
+      shardTensor(reference_outs[9], 1, mesh).unsqueeze(0), // v grad
       shardTensor(reference_outs[10].view({3, E, E}), 1, mesh)
           .view({1, 3 * E / D, E}), // linear0 weight grad
       shardTensor(reference_outs[11].view({3, E}), 1, mesh)
@@ -1234,26 +1235,26 @@ TEST_P(DistributedTransformerTest, Forward_SP) {
   auto at_out = (resid0_ +
      mlp_out_).to(at_dtype);

   std::vector<c10::IValue> inputs = {
-      shardTensor(x_, 0, mesh),
+      shardTensor(x_, 0, mesh).unsqueeze(0),
       ln0_w_,
       ln0_b_,
       shardTensor(mha_w0_.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}),
       shardTensor(mha_b0_.view({3, E}), 1, mesh).view({1, 3 * E / D}),
-      shardTensor(mha_w1_, 1, mesh),
+      shardTensor(mha_w1_, 1, mesh).unsqueeze(0),
       mha_b1_,
       ln1_w_,
       ln1_b_,
-      shardTensor(mlp_w0_, 0, mesh),
-      shardTensor(mlp_b0_, 0, mesh),
-      shardTensor(mlp_w1_, 1, mesh),
+      shardTensor(mlp_w0_, 0, mesh).unsqueeze(0),
+      shardTensor(mlp_b0_, 0, mesh).unsqueeze(0),
+      shardTensor(mlp_w1_, 1, mesh).unsqueeze(0),
       mlp_b1_};

   std::vector<at::Tensor> expected_outputs = {
-      shardTensor(ln0_out_, 0, mesh),
-      shardTensor(mha_out_, 0, mesh),
-      shardTensor(ln1_out_, 0, mesh),
-      shardTensor(mlp_out_, 0, mesh),
-      shardTensor(at_out, 0, mesh)};
+      shardTensor(ln0_out_, 0, mesh).unsqueeze(0),
+      shardTensor(mha_out_, 0, mesh).unsqueeze(0),
+      shardTensor(ln1_out_, 0, mesh).unsqueeze(0),
+      shardTensor(mlp_out_, 0, mesh).unsqueeze(0),
+      shardTensor(at_out, 0, mesh).unsqueeze(0)};

   FusionExecutorCache fec(std::move(fusion));
   at::manual_seed(getATenRandomSeed());
@@ -1367,13 +1368,13 @@ TEST_P(DistributedTransformerTest, Forward) {
       ln0_b_,
       shardTensor(mha_w0_.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}),
       shardTensor(mha_b0_.view({3, E}), 1, mesh).view({1, 3 * E / D}),
-      shardTensor(mha_w1_, 1, mesh),
+      shardTensor(mha_w1_, 1, mesh).unsqueeze(0),
       mha_b1_,
       ln1_w_,
       ln1_b_,
-      shardTensor(mlp_w0_, 0, mesh),
-      shardTensor(mlp_b0_, 0, mesh),
-      shardTensor(mlp_w1_, 1, mesh),
+      shardTensor(mlp_w0_, 0, mesh).unsqueeze(0),
+      shardTensor(mlp_b0_, 0, mesh).unsqueeze(0),
+      shardTensor(mlp_w1_, 1, mesh).unsqueeze(0),
       mlp_b1_};

   std::vector<at::Tensor> expected_outputs = {
@@ -1620,13 +1621,16 @@ TEST_P(DistributedTransformerTest, Backward) {
   auto dx_ = (ln0_x_grad_ + resid1_grad_).to(at_dtype);

   auto expected_outputs = {
-      shardTensor(mlp_grads_[1], 1, mesh), // mlp_linear1_weight_grad
+      shardTensor(mlp_grads_[1], 1, mesh)
+          .unsqueeze(0), // mlp_linear1_weight_grad
       mlp_grads_[2], // mlp_linear1_bias_grad
-      shardTensor(mlp_grads_[4], 0, mesh), // mlp_linear0_weight_grad
-      shardTensor(mlp_grads_[5], 0, mesh), // mlp_linear0_bias_grad
+      shardTensor(mlp_grads_[4], 0, mesh)
+          .unsqueeze(0), // mlp_linear0_weight_grad
+      shardTensor(mlp_grads_[5], 0, mesh).unsqueeze(0), // mlp_linear0_bias_grad
       ln1_w_grad_,
       ln1_b_grad_,
-      shardTensor(mha_grads_[5], 1, mesh), // mha linear1 weight grad
+      shardTensor(mha_grads_[5], 1, mesh)
+          .unsqueeze(0), // mha linear1 weight grad
       mha_grads_[6], // mha linear1 bias grad
       shardTensor(
          mha_grads_[10].view({3, E, E}), 1, mesh) // failing starting here
@@ -1641,13 +1645,13 @@
       x_,
       grad_,
       shardTensor(mha_w0_.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}),
-      shardTensor(mha_w1_, 1, mesh),
-      shardTensor(mlp_w0_, 0, mesh),
-      shardTensor(mlp_w1_, 1, mesh),
+      shardTensor(mha_w1_, 1, mesh).unsqueeze(0),
+      shardTensor(mlp_w0_, 0, mesh).unsqueeze(0),
+      shardTensor(mlp_w1_, 1, mesh).unsqueeze(0),
       mlp_out_[4], // mlp dropout mask
       mha_out_[4], // mha dropout mask
-      shardTensor(mha_grads_[0], 1, mesh), // sdpa output
-      shardTensor(mha_grads_[1], 1, mesh), // sdpa logsum_exp
+      shardTensor(mha_grads_[0], 1, mesh).unsqueeze(0), // sdpa output
+      shardTensor(mha_grads_[1], 1, mesh).unsqueeze(0), // sdpa logsum_exp
       mha_grads_[2], // sdpa seed
       mha_grads_[3], // sdpa offset
       ln1_w_,
       ln1_b_,
       ln1_mean_,
       ln1_rstd_,
       ln0_w_,
       ln0_b_,
       ln0_mean_,
       ln0_rstd_,
-      shardTensor(mha_out_[0], 1, mesh), // mha linear0
+      shardTensor(mha_out_[0], 1, mesh).unsqueeze(0), // mha linear0
       mha_out_[2].to(at::kFloat), // mha linear1
-      shardTensor(mlp_out_[0], 1, mesh) // mlp linear1
+      shardTensor(mlp_out_[0], 1, mesh).unsqueeze(0) // mlp linear1
   };

   FusionExecutorCache executor_cache(std::move(fusion));
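
Not part of the patch: a minimal sketch of how the reworked `getShardedAxis` is
meant to be called once the `ParallelType` argument is explicit. It assumes
nvFuser's usual helpers (`makeContigTensor`, `DeviceMesh::createForNumDevices`)
and a hypothetical 1-D mesh of four devices; include paths, names, and shapes
are illustrative approximations, not code from this change.

  // Approximate includes, following the nvFuser source tree layout.
  // #include <fusion.h>
  // #include <multidevice/device_mesh.h>
  // #include <multidevice/utils.h>
  // #include <tests/cpp/utils.h> // makeContigTensor

  void shardedAxisSketch() {
    using namespace nvfuser;

    Fusion fusion;
    FusionGuard fg(&fusion);

    // Logical domain [i0, i1]; shard i0 across a 4-device mesh on DIDx.
    TensorView* tv = makeContigTensor(2);
    fusion.addInput(tv);
    tv->setDeviceMesh(DeviceMesh::createForNumDevices(4));
    tv->axis(0)->parallelize(ParallelType::DIDx);

    // The caller now names the parallel type explicitly; -1 means
    // "not sharded on that parallel type".
    NVF_ERROR(getShardedAxis(tv, ParallelType::DIDx) == 0);
    NVF_ERROR(getShardedAxis(tv, ParallelType::DIDy) == -1);

    // unshardedSizes() folds the mesh extent back into each sharded axis,
    // so a per-device shape of {2, 8} corresponds to an unsharded {8, 8}.
  }

Because unshardedSizes now iterates over kParallelTypeDIDs and queries
getShardedAxis per type, a tensor sharded on multiple device dimensions is
handled uniformly instead of special-casing DIDx.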