From 07be8b9b3a2279d86c8287b1411fbc46262831aa Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" <qasdfgtyuiop@gmail.com> Date: Mon, 9 Dec 2024 14:24:41 -0800 Subject: [PATCH 1/3] Update all reduction primitives to take a block_dim parameter (#3536) Stacked on https://github.com/NVIDIA/Fuser/pull/3541 This PR is mostly mechanical. It just updates all the reduction primitives with a new parameter `block_dim`, which is just `blockDim` for a kernel without warp specialization. For a kernel with warp specialization, we use a custom value for `block_dim`, which is the block dimension of the compute warps. For better performance, it is beneficial to know at compile time whether we are using `blockDim` or a custom value. To achieve this, I added a new helper class: ```C++ struct DefaultBlockDim { const uint32_t x, y, z; __device__ DefaultBlockDim() : x(blockDim.x), y(blockDim.y), z(blockDim.z) {} __device__ operator dim3() const { return blockDim; } }; ``` When a function takes `block_dim`, its type will be a template type `BlockDimT`, which is either `DefaultBlockDim` or `dim3`, where `DefaultBlockDim` means we are using `blockDim`, and `dim3` means we are using the provided custom value. With this PR, I would consider all our reductions and broadcasts compatible with warp specialization, and I expect them to just work. In the future, we may want to use TMA + warp specialization for reduction and persistent kernels. --- csrc/codegen.cpp | 34 ++++ csrc/parallel_dimension_map.h | 5 + runtime/basic_type_traits.cu | 2 + runtime/block_reduction.cu | 123 +++++++------ runtime/block_sync_default.cu | 22 ++- runtime/block_welford_outer.cu | 16 +- runtime/broadcast.cu | 20 ++- runtime/fused_reduction.cu | 244 +++++++++++++++++++++----- runtime/fused_welford_impl.cu | 137 +++++++++++---- runtime/fused_welford_impl_outer.cu | 33 +++- runtime/grid_broadcast.cu | 17 +- runtime/grid_reduction.cu | 159 ++++++++++++----- runtime/grid_sync.cu | 44 +++-- runtime/warp.cu | 43 +++-- runtime/welford.cu | 79 ++++++--- tests/cpp/test_circular_buffering.cpp | 13 -- tests/cpp/test_gpu2.cpp | 9 +- 17 files changed, 733 insertions(+), 267 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 030de26c84d..0bfb72f1a7a 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1214,6 +1214,20 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { } } + std::string genComputeBlockDim() { + std::stringstream ss; + const auto& pdim_map = kernel_->summary().parallel_dimension_map; + if (!pdim_map.hasWarpSpecialization()) { + ss << "DefaultBlockDim()"; + } else { + ss << "dim3(" + << genInlineOrOne(pdim_map.getRawCompute(ParallelType::TIDx)) << ", " + << genInlineOrOne(pdim_map.getRawCompute(ParallelType::TIDy)) << ", " + << genInlineOrOne(pdim_map.getRawCompute(ParallelType::TIDz)) << ")"; + } + return ss.str(); + } + std::string genReductionOp(BinaryOpType op_type, DataType data_type) { std::stringstream lambda; lambda << "[](" << data_type << " &a, " << data_type << " b) " @@ -1252,6 +1266,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genStaticCast(genPtrType(data_type), "shared_mem")); NVF_ERROR(stmt->predicate() != nullptr && stmt->predicate()->hasValue()); func_args.arg(genInline(stmt->predicate())); + func_args.arg(genComputeBlockDim()); indent() << genCall("broadcast::blockBroadcast", template_args, func_args) << ";\n"; @@ -1284,6 +1299,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { NVF_ERROR(read_pred != nullptr && read_pred->hasValue());
func_args.arg(genInline(read_pred)); func_args.arg(genStaticCast(output->dtype(), genInline(init))); + func_args.arg(genComputeBlockDim()); ArgumentBuilder template_args; if (reduction_dims.first->getParallelType() == ParallelType::TIDx && @@ -1349,6 +1365,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genInline(write_pred)); } func_args.arg(genCall(data_type, genInline(init))); + func_args.arg(genComputeBlockDim()); indent() << genCall("blockReduce", template_args, func_args) << ";\n"; } @@ -1578,6 +1595,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genInline(wop->writePredicate())); } func_args.arg(genStaticCast(data_type, 0)); + func_args.arg(genComputeBlockDim()); indent() << genCall("blockWelford", template_args, func_args) << ";\n"; } @@ -1781,6 +1799,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genCall(data_type, genInline(grop->init()))); func_args.arg(genInline(grop->entrance_index())); func_args.arg(genInline(grop->entrances())); + func_args.arg(genComputeBlockDim()); addProfileArguments(func_args, grop); @@ -1915,6 +1934,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(read_pred).arg(write_pred); // init_val func_args.arg(genCall("LocalTuple", data_type, genInline(grop->init()))); + // block_dim + func_args.arg(genComputeBlockDim()); // reduction_op func_args.arg(genReductionOp(op_type, out->dtype())); @@ -1971,6 +1992,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { } // Init val func_args.arg(genCall(data_type, genInline(grop->initVal(0)))); + // block_dim + func_args.arg(genComputeBlockDim()); addProfileArguments(func_args, grop); @@ -2059,6 +2082,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genInline(grouped_grop->entrance_index())); func_args.arg(genInline(grouped_grop->entrances())); + func_args.arg(genComputeBlockDim()); addProfileArguments(func_args, grouped_grop); @@ -2271,6 +2295,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genCall("ConstRefTuple", types, inputs)); func_args.arg(genCall("VolatilePtrTuple", types, work_bufs)); func_args.arg(genCall("LocalTuple", types, init_vals)); + func_args.arg(genComputeBlockDim()); // global_sync_buffer const auto sync_buffer = @@ -2407,6 +2432,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genCall("LocalTuple", data_types, init_args[0])); func_args.arg(genCall("LocalTuple", data_types, init_args[1])); func_args.arg(genCall("LocalTuple", index_types, init_args[2])); + // block_dim + func_args.arg(genComputeBlockDim()); // work buffer func_args.arg(genCall("VolatilePtrTuple", data_types, work_bufs[0])); func_args.arg(genCall("VolatilePtrTuple", data_types, work_bufs[1])); @@ -2498,6 +2525,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genVariableNameConvertAlignedArray(input.get(1))); func_args.arg(genVariableNameConvertAlignedArray(input.get(2))) .append("[0]"); + // block_dim + func_args.arg(genComputeBlockDim()); // global buf for (const auto i : c10::irange(3)) { @@ -2652,6 +2681,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genStaticCast(data_type, 0)); func_args.arg(genInline(gwop->entrance_index())); func_args.arg(genInline(gwop->entrances())); + func_args.arg(genComputeBlockDim()); indent() << genCall("welford::gridWelford", template_args, func_args) << ";\n"; @@ -2751,6 +2781,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { 
func_args.arg(read_pred).arg(write_pred); // init_val func_args.arg(genCall("LocalTuple", data_type_args, init_args)); + // block_dim + func_args.arg(genComputeBlockDim()); // reduction_op func_args.arg(genTemplate( "welfordCombine", ArgumentBuilder().arg(data_type).arg(index_type))); @@ -2877,6 +2909,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genInline(write_pred)); } func_args.arg(genCall(data_type, genInline(init))); + func_args.arg(genComputeBlockDim()); indent() << genCall("blockIterGroupedYdimReduce", template_args, func_args) << ";\n"; @@ -3315,6 +3348,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { .append(sync_idx) .append("]"); sync_call_args.arg(sync_segment_size); + sync_call_args.arg(genComputeBlockDim()); auto sync_call = genCall("grid_sync::sync", sync_call_template_parms, sync_call_args); diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h index f618de16021..e2e9de423c1 100644 --- a/csrc/parallel_dimension_map.h +++ b/csrc/parallel_dimension_map.h @@ -61,6 +61,11 @@ class ParallelDimensionMap { //! buffer tensors. Val* getNumComputeThreadsEachBlock() const; + //! Get if the kernel uses warp specialization + bool hasWarpSpecialization() const { + return !warp_specialized_types_.empty(); + } + bool has(ParallelType pt) const { return dim_map_.count(pt) > 0; } diff --git a/runtime/basic_type_traits.cu b/runtime/basic_type_traits.cu index 98eb2695ce1..b2f299cbd73 100644 --- a/runtime/basic_type_traits.cu +++ b/runtime/basic_type_traits.cu @@ -47,6 +47,8 @@ template <class _Tp, class _Up> struct is_same : public false_type {}; template <class _Tp> struct is_same<_Tp, _Tp> : public true_type {}; +template <class T, class U> +constexpr bool is_same_v = is_same<T, U>::value; // is_integral, for some types. template <class _Tp> diff --git a/runtime/block_reduction.cu b/runtime/block_reduction.cu index 817093eb63b..648b0b18714 100644 --- a/runtime/block_reduction.cu +++ b/runtime/block_reduction.cu @@ -21,7 +21,8 @@ template < bool Z_REDUCE, bool Aligned, typename T, - typename Func> + typename Func, + typename BlockDimT> __device__ void blockReduce( T& out, const T& inp_val, @@ -29,28 +30,32 @@ __device__ void blockReduce( T* shared_mem, bool read_pred, bool write_pred, - T init_val) { + T init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { // If this thread will output a final result bool should_write = index_utils::maskedIsZero<X_REDUCE, Y_REDUCE, Z_REDUCE>(threadIdx); // Size of the reduction segments unsigned int reduction_size = - index_utils::maskedSize<X_REDUCE, Y_REDUCE, Z_REDUCE>(blockDim); + index_utils::maskedSize<X_REDUCE, Y_REDUCE, Z_REDUCE>(block_dim); // Index into the reduction segment unsigned int reduction_tid = index_utils::maskedOffset<X_REDUCE, Y_REDUCE, Z_REDUCE>( - threadIdx, blockDim); + threadIdx, block_dim); // Index of the reduction segment unsigned int reduction_idx = index_utils::maskedOffset<!X_REDUCE, !Y_REDUCE, !Z_REDUCE>( - threadIdx, blockDim); + threadIdx, block_dim); // number of reductions per block unsigned int reduction_num = - index_utils::maskedSize<!X_REDUCE, !Y_REDUCE, !Z_REDUCE>(blockDim); + index_utils::maskedSize<!X_REDUCE, !Y_REDUCE, !Z_REDUCE>(block_dim); // smem_offset is the offset into shared memory for the current thread. 
// To ensure coalesced access to shared memory, we need to ensure @@ -65,15 +70,15 @@ __device__ void blockReduce( // To avoid this, we should always use the offset based on the indexing of // threads within a block. // Offset into smem for the current thread - unsigned int smem_offset = threadIdx.x + threadIdx.y * blockDim.x + - threadIdx.z * blockDim.x * blockDim.y; + unsigned int smem_offset = threadIdx.x + threadIdx.y * block_dim.x + + threadIdx.z * block_dim.x * block_dim.y; // The peer stride represents the distance between the current element and its // nearest reduction peer. It depends on the reduction dimension. A reduction // peer refers to elements that belong to the same reduction segment. For // example, if the reduction is across TIDy, all the elements in the same // column (with the same TIDx) are considered peers of each other. The - // distance between an element and its nearest peer is blockDim.x. + // distance between an element and its nearest peer is block_dim.x. constexpr int num_redu_dims = (int)X_REDUCE + (int)Y_REDUCE + (int)Z_REDUCE; constexpr bool xz_reduce = (num_redu_dims == 2 && !Y_REDUCE); // reduction in 3 dimensions, XYZ, stride is 1 @@ -82,27 +87,27 @@ __device__ void blockReduce( // Reduction only in 1 dimension, X or Y or Z // e.g. inner or outer reduction // If X_REDUCE, reducing in neighbor cols in smem, peer_stride is 1 - // If Y_REDUCE, reducing in neighbor rows in smem, peer_stride is blockDim.x - // If Z_REDUCE, reducing in neighbor planes in smem, peer_stride is - // blockDim.x * blockDim.y + // If Y_REDUCE, reducing in neighbor rows in smem, peer_stride is + // block_dim.x If Z_REDUCE, reducing in neighbor planes in smem, peer_stride + // is block_dim.x * block_dim.y peer_stride = X_REDUCE ? 1 - : Y_REDUCE ? blockDim.x - : blockDim.x * blockDim.y; + : Y_REDUCE ? block_dim.x + : block_dim.x * block_dim.y; } else if (num_redu_dims == 2) { // Reduction in 2 dimensions, only one dimension is not reduced, !X, !Y, !Z // If !Z_REDUCE, merge XY, reducing neighbor cols, peer_stride is 1 - // If !X_REDUCE, merge ZY, reducing neighbor rows, peer_stride is blockDim.x - // If !Y_REDUCE, if blockDim.y == 1, merge XZ, peer_stride is 1. - // otherwise, needs carefully calculate offset to the reduction peer: + // If !X_REDUCE, merge ZY, reducing neighbor rows, peer_stride is + // block_dim.x If !Y_REDUCE, if block_dim.y == 1, merge XZ, peer_stride + // is 1. otherwise, needs carefully calculate offset to the reduction peer: // (1) redu_offset = reduction_tid + tree_fold_factor - // (2) idz = redu_offset / blockDim.x - // (3) idx = redu_offset % blockDim.x - // (4) smem_offset = idx + threadIdx.y * blockDim.x + idz * blockDim.x * - // blockDim.y + // (2) idz = redu_offset / block_dim.x + // (3) idx = redu_offset % block_dim.x + // (4) smem_offset = idx + threadIdx.y * block_dim.x + idz * block_dim.x * + // block_dim.y if (!Y_REDUCE) { peer_stride = 1; } else { - peer_stride = !Z_REDUCE ? 1 : blockDim.x; + peer_stride = !Z_REDUCE ? 
1 : block_dim.x; } } @@ -112,41 +117,41 @@ __device__ void blockReduce( } else { shared_mem[smem_offset] = init_val; } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Reduce down to nearest power of 2 for the tree reduction: int np2 = 1 << (31 - __clz(reduction_size)); if (reduction_tid < np2 && reduction_tid + np2 < reduction_size) { int peer_offset = smem_offset + np2 * peer_stride; if constexpr (xz_reduce) { - if (blockDim.y > 1) { + if (block_dim.y > 1) { int redu_offset = reduction_tid + np2; - int idz = redu_offset / blockDim.x; - int idx = redu_offset % blockDim.x; + int idz = redu_offset / block_dim.x; + int idx = redu_offset % block_dim.x; peer_offset = - idx + threadIdx.y * blockDim.x + idz * blockDim.x * blockDim.y; + idx + threadIdx.y * block_dim.x + idz * block_dim.x * block_dim.y; } } reduction_op(shared_mem[smem_offset], shared_mem[peer_offset]); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // loop peel the final iteration to save one syncthread for the end for (int factor = np2 / 2; factor > 1; factor >>= 1) { if (reduction_tid < factor) { int peer_offset = smem_offset + factor * peer_stride; if constexpr (xz_reduce) { - if (blockDim.y > 1) { + if (block_dim.y > 1) { int redu_offset = reduction_tid + factor; - int idz = redu_offset / blockDim.x; - int idx = redu_offset % blockDim.x; + int idz = redu_offset / block_dim.x; + int idx = redu_offset % block_dim.x; peer_offset = - idx + threadIdx.y * blockDim.x + idz * blockDim.x * blockDim.y; + idx + threadIdx.y * block_dim.x + idz * block_dim.x * block_dim.y; } } reduction_op(shared_mem[smem_offset], shared_mem[peer_offset]); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } if (should_write && write_pred) { @@ -157,7 +162,7 @@ __device__ void blockReduce( } out = result; } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } // Use the same pred for both reads and writes @@ -167,14 +172,19 @@ template < bool Z_REDUCE, bool Aligned, typename T, - typename Func> + typename Func, + typename BlockDimT> __device__ void blockReduce( T& out, const T& inp_val, Func reduction_op, T* shared_mem, bool read_write_pred, - T init_val) { + T init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { blockReduce<X_REDUCE, Y_REDUCE, Z_REDUCE, Aligned, T, Func>( out, inp_val, @@ -182,7 +192,8 @@ __device__ void blockReduce( shared_mem, read_write_pred, read_write_pred, - init_val); + init_val, + block_dim); } // Each thread in the iteration dimension processes N elements @@ -195,7 +206,8 @@ template < bool Aligned, int N, // Number of elements per input array typename T, - typename Func> + typename Func, + typename BlockDimT> __device__ void blockIterGroupedYdimReduce( T out[N], const T inp_val[N], @@ -203,14 +215,18 @@ __device__ void blockIterGroupedYdimReduce( T* shared_mem, bool read_pred, bool write_pred, - T init_val) { + T init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. 
+ BlockDimT block_dim) { // N should be a valid vectorization factor static_assert( N == 2 || N == 4 || N == 8 || N == 16, "N should be a valid vectorization factor, one of (2, 4, 8, 16)!"); bool should_write = threadIdx.y == 0; - unsigned int reduction_size = blockDim.y; + unsigned int reduction_size = block_dim.y; unsigned int reduction_tid = threadIdx.y; // In shared memory, each row has 128 bytes, if sizeof(T) * N = 32 bytes, each @@ -228,11 +244,12 @@ __device__ void blockIterGroupedYdimReduce( // assume TIDy is the reduction dimension, TIDx is the iteration dimension // TIDz is not used - unsigned int peer_stride = elements_per_load * blockDim.x; + unsigned int peer_stride = elements_per_load * block_dim.x; - unsigned int smem_offset_inter = blockDim.x * blockDim.y * elements_per_load; + unsigned int smem_offset_inter = + block_dim.x * block_dim.y * elements_per_load; unsigned int smem_offset_intra = - (threadIdx.y * blockDim.x + threadIdx.x) * elements_per_load; + (threadIdx.y * block_dim.x + threadIdx.x) * elements_per_load; // load to [total_loads] sections of shared memory #pragma unroll @@ -241,7 +258,7 @@ __device__ void blockIterGroupedYdimReduce( shared_mem + smem_offset_inter * i + smem_offset_intra, const_cast<T*>(inp_val) + i * elements_per_load); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Reduce down to nearest power of 2 for the tree reduction: // Perform parallel reduction for each element in the array @@ -272,7 +289,7 @@ __device__ void blockIterGroupedYdimReduce( shared_mem + self_offset, self + i * elements_per_load); } } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Tree reduction for (int factor = np2 / 2; factor > 1; factor >>= 1) { @@ -302,7 +319,7 @@ __device__ void blockIterGroupedYdimReduce( shared_mem + self_offset, self + i * elements_per_load); } } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } // last reduction @@ -347,7 +364,7 @@ __device__ void blockIterGroupedYdimReduce( out[i] = result[i]; } } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } // Use the same pred for both reads and writes @@ -355,14 +372,19 @@ template < bool Aligned, int N, // Number of elements per input array typename T, - typename Func> + typename Func, + typename BlockDimT> __device__ void blockIterGroupedYdimReduce( T out[N], const T inp_val[N], Func reduction_op, T* shared_mem, bool read_write_pred, - T init_val) { + T init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { blockIterGroupedYdimReduce<Aligned, N, T, Func>( out, inp_val, @@ -370,5 +392,6 @@ __device__ void blockIterGroupedYdimReduce( shared_mem, read_write_pred, read_write_pred, - init_val); + init_val, + block_dim); } diff --git a/runtime/block_sync_default.cu b/runtime/block_sync_default.cu index 2a3048c1a8b..1c865320fa6 100644 --- a/runtime/block_sync_default.cu +++ b/runtime/block_sync_default.cu @@ -6,18 +6,34 @@ */ // clang-format on +// Basically just blockDim, but wrapped as a struct so that we have a mechanism +// to know at compile time that whether we are just using blockDim or some +// custom value. For a kernel without warp specialization, we just use blockDim, +// but for a kernel with warp specialization, we use a custom block_dim whose +// dimension are the dimensions of the compute warps. 
+struct DefaultBlockDim { + const uint32_t x, y, z; + __device__ DefaultBlockDim() : x(blockDim.x), y(blockDim.y), z(blockDim.z) {} + __device__ operator dim3() const { + return blockDim; + } +}; + // Default block synchronization. Just use __barrier_sync namespace block_sync { __forceinline__ __device__ void init() {} // Thread-block synchronization -template <bool aligned> -__forceinline__ __device__ void sync() { +template <bool aligned, typename BlockDimT> +__forceinline__ __device__ void sync(BlockDimT block_dim) { if constexpr (aligned) { __syncthreads(); - } else { + } else if constexpr (std::is_same_v<BlockDimT, DefaultBlockDim>) { __barrier_sync(0); + } else { + uint32_t num_threads = block_dim.x * block_dim.y * block_dim.z; + asm volatile("bar.sync 0, %0;" : : "r"(num_threads) : "memory"); } } diff --git a/runtime/block_welford_outer.cu b/runtime/block_welford_outer.cu index 3b758fe61d3..ded37e75098 100644 --- a/runtime/block_welford_outer.cu +++ b/runtime/block_welford_outer.cu @@ -56,11 +56,21 @@ namespace impl { // registers than just returing the output. Results would vary // depending on compiler versions, but it seems safer to return outputs // as a new value. -template <bool Aligned, int NumVals, typename DataType, int BDIMX, int BDIMY> +template < + bool Aligned, + int NumVals, + typename DataType, + int BDIMX, + int BDIMY, + typename BlockDimT> __inline__ __device__ WelfordTriplet<DataType> blockWelfordOuter( DataType* inp_avg, DataType* inp_var, nvfuser_index_t inp_N, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, DataType* smem) { constexpr int num_warps = BDIMX * BDIMY / 32; static_assert(num_warps >= 1, "There must be at least a single warp"); @@ -188,7 +198,7 @@ __inline__ __device__ WelfordTriplet<DataType> blockWelfordOuter( } } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // The next step is to let each thread of a warp independently // accumulate the partial results on the shared memory @@ -245,7 +255,7 @@ __inline__ __device__ WelfordTriplet<DataType> blockWelfordOuter( } } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Nothing to do for warps whose wid is larger than NunVals if (wid >= NumVals) { diff --git a/runtime/broadcast.cu b/runtime/broadcast.cu index 127fc6ea762..2bc82b9bd97 100644 --- a/runtime/broadcast.cu +++ b/runtime/broadcast.cu @@ -16,30 +16,40 @@ namespace broadcast { // inp_val: Per-thread source value. Only valid when the thread is a source. // out: Per-thread output location // -template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, bool Aligned, typename T> +template < + bool X_THREAD, + bool Y_THREAD, + bool Z_THREAD, + bool Aligned, + typename T, + typename BlockDimT> __device__ void blockBroadcast( T& out, const T& inp_val, T* shared_mem, - bool read_write_pred) { + bool read_write_pred, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. 
+ BlockDimT block_dim) { const bool has_valid_data = (!X_THREAD || threadIdx.x == 0) && (!Y_THREAD || threadIdx.y == 0) && (!Z_THREAD || threadIdx.z == 0); const auto shared_offset = index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); if (has_valid_data && read_write_pred) { shared_mem[shared_offset] = inp_val; } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); if (read_write_pred) { out = shared_mem[shared_offset]; } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } } // namespace broadcast diff --git a/runtime/fused_reduction.cu b/runtime/fused_reduction.cu index 731403067d9..9967078fc24 100644 --- a/runtime/fused_reduction.cu +++ b/runtime/fused_reduction.cu @@ -204,6 +204,7 @@ template < bool FORWARD_PROTECT_SMEM, bool Aligned, typename LocalTupleT, + typename BlockDimT, typename... Funcs> struct BlockReduceEach { __inline__ __device__ static void reduce( @@ -215,9 +216,20 @@ struct BlockReduceEach { int num_threads_per_reduction, int num_elements_per_reduction, int reduction_idx, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, Funcs... funcs) { // Finish the reduction of each tuple value with a smaller offset - BlockReduceEach<idx - 1, BROADCAST, true, Aligned, LocalTupleT, Funcs...>:: + BlockReduceEach< + idx - 1, + BROADCAST, + true, + Aligned, + LocalTupleT, + BlockDimT, + Funcs...>:: reduce( block_result, partial_result, @@ -227,6 +239,7 @@ struct BlockReduceEach { num_threads_per_reduction, num_elements_per_reduction, reduction_idx, + block_dim, funcs...); if (num_elements_per_reduction == 1) { @@ -252,7 +265,7 @@ struct BlockReduceEach { copyTuple(shared_buf, smem_offset, block_result_i); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); if (tid_in_reduction < np2 && tid_in_reduction + np2 < num_elements_per_reduction) { @@ -265,7 +278,7 @@ struct BlockReduceEach { } // Always sync when communicating across smem - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Reduce down to 2 values, last thread will do the final reduction and // can save a syncthreads this way @@ -278,7 +291,7 @@ struct BlockReduceEach { smem_offset + factor, funcs...); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } copyTuple(block_result_i, shared_buf, smem_offset); @@ -300,7 +313,7 @@ struct BlockReduceEach { } // Sync threads to make sure result is in smem - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); copyTuple( block_result_i, @@ -311,7 +324,7 @@ struct BlockReduceEach { block_result.val<idx>(0) = block_result_i.val<0>(0); if (FORWARD_PROTECT_SMEM) { - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } } }; @@ -322,6 +335,7 @@ template < bool FORWARD_PROTECT_SMEM, bool Aligned, typename LocalTupleT, + typename BlockDimT, typename... Funcs> struct BlockReduceEach< -1, @@ -329,6 +343,7 @@ struct BlockReduceEach< FORWARD_PROTECT_SMEM, Aligned, LocalTupleT, + BlockDimT, Funcs...> { __inline__ __device__ static void reduce( LocalTupleT& block_result, @@ -339,6 +354,10 @@ struct BlockReduceEach< int num_threads_per_reduction, int num_elements_per_reduction, int reduction_idx, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. 
If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, Funcs... funcs) {} }; @@ -360,6 +379,7 @@ template < bool FORWARD_PROTECT_SMEM, bool Aligned, typename LocalTupleT, + typename BlockDimT, typename... Funcs> __inline__ __device__ void blockReduceEach( LocalTupleT& block_result, @@ -370,6 +390,10 @@ __inline__ __device__ void blockReduceEach( int num_threads_per_reduction, int num_elements_per_reduction, int reduction_idx, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, Funcs... reduction_ops) { BlockReduceEach< LocalTupleT::num_vals - 1, @@ -377,6 +401,7 @@ __inline__ __device__ void blockReduceEach( FORWARD_PROTECT_SMEM, Aligned, LocalTupleT, + BlockDimT, Funcs...>:: reduce( block_result, @@ -387,6 +412,7 @@ __inline__ __device__ void blockReduceEach( num_threads_per_reduction, num_elements_per_reduction, reduction_idx, + block_dim, reduction_ops...); } @@ -466,7 +492,7 @@ class ParallelReduce { // reduceGroup does not support Welford-style reductions that reduce // all values of a tuple together, so this is the only entry point // for Welford for now. - template <bool Aligned, typename Func, typename... Types> + template <bool Aligned, typename Func, typename BlockDimT, typename... Types> __device__ __inline__ void reduce( RefTuple<Types...> out, const ConstRefTuple<Types...>& inp, @@ -477,10 +503,14 @@ class ParallelReduce { bool read_pred, // Prevent reading from out of bounds memory bool write_pred, // Prevent from writing out of bounds const LocalTuple<Types...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, Func reduction_op); //! Profiled version - template <bool Aligned, typename Func, typename... Types> + template <bool Aligned, typename Func, typename BlockDimT, typename... Types> __device__ __inline__ void reduce( RefTuple<Types...> out, const ConstRefTuple<Types...>& inp, @@ -491,6 +521,10 @@ class ParallelReduce { bool read_pred, // Prevent reading from out of bounds memory bool write_pred, // Prevent from writing out of bounds const LocalTuple<Types...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, Func reduction_op, int64_t& cycles, int64_t& count); @@ -505,6 +539,7 @@ class ParallelReduce { //! no need to accumulate into the out parameter. template < bool Aligned, + typename BlockDimT, typename... DataTypes, typename... Funcs, typename... BoolTypes> @@ -513,6 +548,10 @@ class ParallelReduce { const ConstRefTuple<DataTypes...>& inp, VolatilePtrTuple<DataTypes...> global_work_buffer, const LocalTuple<DataTypes...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, int64_t* global_sync_buffer, void* shared_mem, const LocalTuple<BoolTypes...>& read_preds, @@ -522,6 +561,7 @@ class ParallelReduce { //! 
Profiled version template < bool Aligned, + typename BlockDimT, typename... DataTypes, typename... Funcs, typename... BoolTypes> @@ -530,6 +570,10 @@ class ParallelReduce { const ConstRefTuple<DataTypes...>& inp, VolatilePtrTuple<DataTypes...> global_work_buffer, const LocalTuple<DataTypes...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, int64_t* global_sync_buffer, void* shared_mem, const LocalTuple<BoolTypes...>& read_preds, @@ -551,7 +595,12 @@ class ParallelReduce { // simplicity. In practice, it should be really uncommon to group // welford ops with different data types, so this restriction // shouldn't be an issue. - template <bool Aligned, int NumArgs, typename DataType, typename IndexType> + template < + bool Aligned, + int NumArgs, + typename DataType, + typename IndexType, + typename BlockDimT> __device__ __inline__ void welfordGroup( typename MakeRefTuple<NumArgs, DataType>::type out_avg, typename MakeRefTuple<NumArgs, DataType>::type out_var, @@ -562,6 +611,10 @@ class ParallelReduce { const typename MakeLocalTuple<NumArgs, DataType>::type& init_avg, const typename MakeLocalTuple<NumArgs, DataType>::type& init_var, const typename MakeLocalTuple<NumArgs, IndexType>::type& init_N, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, typename MakeVolatilePtrTuple<NumArgs, DataType>::type global_work_buffer_avg, typename MakeVolatilePtrTuple<NumArgs, DataType>::type @@ -574,7 +627,12 @@ class ParallelReduce { const typename MakeLocalTuple<NumArgs, bool>::type& write_preds); //! Profiled version - template <bool Aligned, int NumArgs, typename DataType, typename IndexType> + template < + bool Aligned, + int NumArgs, + typename DataType, + typename IndexType, + typename BlockDimT> __device__ __inline__ void welfordGroup( typename MakeRefTuple<NumArgs, DataType>::type out_avg, typename MakeRefTuple<NumArgs, DataType>::type out_var, @@ -585,6 +643,10 @@ class ParallelReduce { const typename MakeLocalTuple<NumArgs, DataType>::type& init_avg, const typename MakeLocalTuple<NumArgs, DataType>::type& init_var, const typename MakeLocalTuple<NumArgs, IndexType>::type& init_N, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, typename MakeVolatilePtrTuple<NumArgs, DataType>::type global_work_buffer_avg, typename MakeVolatilePtrTuple<NumArgs, DataType>::type @@ -601,7 +663,13 @@ class ParallelReduce { // This is highly specific to the outer-reduction pattern. All the // assumptions should be asserted with static_assert at the begging of // the fuction. 
- template <bool Aligned, int NumVals, typename DataType, int BDIMX, int BDIMY> + template < + bool Aligned, + int NumVals, + typename DataType, + int BDIMX, + int BDIMY, + typename BlockDimT> __device__ __inline__ void welfordGroupOuter( DataType out_avg[NumVals], DataType out_var[NumVals], @@ -609,6 +677,10 @@ class ParallelReduce { const DataType in_avg[NumVals], const DataType in_var[NumVals], nvfuser_index_t in_N, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, DataType* global_buf_avg, DataType* global_buf_var, nvfuser_index_t* global_buf_N, @@ -616,7 +688,13 @@ class ParallelReduce { int64_t* global_sync_buffer); // Profiled version - template <bool Aligned, int NumVals, typename DataType, int BDIMX, int BDIMY> + template < + bool Aligned, + int NumVals, + typename DataType, + int BDIMX, + int BDIMY, + typename BlockDimT> __device__ __inline__ void welfordGroupOuter( DataType out_avg[NumVals], DataType out_var[NumVals], @@ -624,6 +702,10 @@ class ParallelReduce { const DataType in_avg[NumVals], const DataType in_var[NumVals], nvfuser_index_t in_N, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, DataType* global_buf_avg, DataType* global_buf_var, nvfuser_index_t* global_buf_N, @@ -651,12 +733,17 @@ class ParallelReduce { template < bool BLOCK_BROADCAST, bool Aligned, + typename BlockDimT, typename... DataTypes, typename... Funcs, typename... BoolTypes> __device__ __inline__ static LocalTuple<DataTypes...> reduceGroupBlock( const ConstRefTuple<DataTypes...>& inp, const LocalTuple<DataTypes...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, void* shared_mem, const LocalTuple<BoolTypes...>& read_preds, bool block_reduce_participate, @@ -668,6 +755,7 @@ class ParallelReduce { //! but it isn't synchronized when returning from this function. template < bool Aligned, + typename BlockDimT, typename... DataTypes, typename... Funcs, typename... BoolTypes> @@ -675,6 +763,10 @@ class ParallelReduce { RefTuple<DataTypes...>& out, const VolatilePtrTuple<DataTypes...>& global_work_buffer, const LocalTuple<DataTypes...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, void* shared_mem, nvfuser_index_t block_red_idx_offset, nvfuser_index_t num_thread_iters, @@ -692,21 +784,35 @@ class ParallelReduce { bool Aligned, int NumVals, typename DataType, - typename IndexType> + typename IndexType, + typename BlockDimT> __device__ __inline__ static void welfordGroupBlock( LocalWelfordTripletTuple<NumVals, DataType, IndexType>& block_result, const ConstRefWelfordTripletTuple<NumVals, DataType, IndexType>& inp, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. 
+ BlockDimT block_dim, PtrTuple<DataType, DataType, IndexType> shared_buf, const typename MakeLocalTuple<NumVals, bool>::type& read_preds, bool block_reduce_participate); //! Welford version of reduceGrouplLastBlock - template <bool Aligned, int NumVals, typename DataType, typename IndexType> + template < + bool Aligned, + int NumVals, + typename DataType, + typename IndexType, + typename BlockDimT> __device__ __inline__ static void welfordGroupLastBlock( RefWelfordTripletTuple<NumVals, DataType, IndexType>& out, const VolatilePtrWelfordTripletTuple<NumVals, DataType, IndexType>& global_work_buffer, const LocalWelfordTripletTuple<NumVals, DataType, IndexType>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, PtrTuple<DataType, DataType, IndexType> shared_buf, nvfuser_index_t block_red_idx_offset, nvfuser_index_t num_thread_iters, @@ -729,7 +835,7 @@ template < int Z_THREAD, bool PERSISTENT_REDUCTION, bool BROADCAST> -template <bool Aligned, typename Func, typename... Types> +template <bool Aligned, typename Func, typename BlockDimT, typename... Types> __device__ __inline__ void ParallelReduce< X_BLOCK, Y_BLOCK, @@ -749,6 +855,10 @@ __device__ __inline__ void ParallelReduce< bool read_pred, // Prevent reading from out of bounds memory bool write_pred, // Prevent from writing out of bounds const LocalTuple<Types...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, Func reduction_op) { // If no reduction needed, just return input if (!BLOCK_REDUCE && !GRID_REDUCE) { @@ -785,14 +895,14 @@ __device__ __inline__ void ParallelReduce< // to number of threads int block_reduction_size = index_utils:: maskedSize<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>( - blockDim); + block_dim); // Index in the reduction segment, can be an int since it's limited to // number of threads int tid_in_block_reduction = index_utils::maskedOffset< isReduce(X_THREAD), isReduce(Y_THREAD), - isReduce(Z_THREAD)>(threadIdx, blockDim); + isReduce(Z_THREAD)>(threadIdx, block_dim); // ID of the block reduction this thread is participating in // @@ -802,7 +912,7 @@ __device__ __inline__ void ParallelReduce< // dimension int block_reduction_idx = index_utils:: maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); // Shared memory buffer is 2D // [iter dimension, reduction dimension] @@ -817,7 +927,7 @@ __device__ __inline__ void ParallelReduce< } // Sync to make sure smem is completely initialized - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Round reduction size down to nearest power of 2 int np2 = 1 << (31 - __clz(block_reduction_size)); @@ -834,7 +944,7 @@ __device__ __inline__ void ParallelReduce< } // Always need to sync while operating on shared memory - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Reduce down until 2 values, leaving 2 values allows us to manually // perform the last reduction and avoid a syncthreads @@ -847,7 +957,7 @@ __device__ __inline__ void ParallelReduce< block_reduce_smem_offset + factor, reduction_op); } - block_sync::sync<Aligned>(); + 
block_sync::sync<Aligned>(block_dim); } // Accumulate that last valid result @@ -883,7 +993,7 @@ __device__ __inline__ void ParallelReduce< } // Sync threads to make sure result is in smem - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // If the thread is participating, and is not attempting to write out // of bounds, return the broadcasted value. if (block_reduce_participate && write_pred) { @@ -898,7 +1008,7 @@ __device__ __inline__ void ParallelReduce< // // This could be avoided in some cases if we added thread syncs from // block reductions in the syncthread insertion pass. - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); return; } } @@ -929,7 +1039,7 @@ __device__ __inline__ void ParallelReduce< gridDim) * index_utils:: maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - blockDim) * + block_dim) * grid_red_size; global_work_buffer += global_buffer_size; } @@ -948,13 +1058,13 @@ __device__ __inline__ void ParallelReduce< // How many grid reductions have to be performed, in the block dimension const auto num_thread_iters = index_utils:: maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - blockDim); + block_dim); // Which grid reduction does this thread participate in, in the block // dimension const auto thread_red_idx_offset = index_utils:: maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); // 3D buffer of reductions: // [reduction_offset(grid), iter_offset(grid), iter_offset(block)] @@ -989,7 +1099,10 @@ __device__ __inline__ void ParallelReduce< isReduce(Z_BLOCK), PERSISTENT_REDUCTION, Aligned>( - global_sync_buffer[block_red_idx_offset], grid_red_size, last_block); + global_sync_buffer[block_red_idx_offset], + grid_red_size, + last_block, + block_dim); } // -- START BLOCK CLEANUP -- // @@ -1011,12 +1124,12 @@ __device__ __inline__ void ParallelReduce< int tid_in_block_reduction_2 = index_utils::maskedOffset< activeNotIter(X_THREAD), activeNotIter(Y_THREAD), - activeNotIter(Z_THREAD)>(threadIdx, blockDim); + activeNotIter(Z_THREAD)>(threadIdx, block_dim); int block_reduction_size_2 = index_utils::maskedSize< activeNotIter(X_THREAD), activeNotIter(Y_THREAD), - activeNotIter(Z_THREAD)>(blockDim); + activeNotIter(Z_THREAD)>(block_dim); // 3D buffer of reductions: // [reduction_offset(grid), iter_offset(grid), iter_offset(block)] @@ -1049,7 +1162,7 @@ __device__ __inline__ void ParallelReduce< // Which block reduction this thread is participating in int block_reduction_idx = index_utils:: maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); // Offset in smem for this thread's result auto smem_offset = @@ -1064,7 +1177,7 @@ __device__ __inline__ void ParallelReduce< copyTuple(shared_buf, smem_offset, last_block_result); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); if (tid_in_block_reduction_2 < np2 && tid_in_block_reduction_2 + np2 < @@ -1078,7 +1191,7 @@ __device__ __inline__ void ParallelReduce< } // Always sync when communicating across smem - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Reduce down to 2 values, last thread will do the final reduction and // can save a syncthreads this way @@ -1091,7 +1204,7 @@ __device__ __inline__ void ParallelReduce< smem_offset + factor, reduction_op); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } // If this thread in each block has the final result before 
broadcasting @@ -1113,7 +1226,7 @@ __device__ __inline__ void ParallelReduce< if (grid_reduce_participate && PERSISTENT_REDUCTION) { // If persistent reduction, always broadcast reduced values copyTuple(shared_buf, smem_offset, last_block_result); - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); if (write_pred && block_reduce_participate) { copyTuple( out, shared_buf, block_reduction_idx * block_reduction_size_2); @@ -1133,7 +1246,7 @@ __device__ __inline__ void ParallelReduce< } } // Forward protect the smem used in this reduction - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } } @@ -1147,7 +1260,7 @@ template < int Z_THREAD, bool PERSISTENT_REDUCTION, bool BROADCAST> -template <bool Aligned, typename Func, typename... Types> +template <bool Aligned, typename Func, typename BlockDimT, typename... Types> __device__ __inline__ void ParallelReduce< X_BLOCK, Y_BLOCK, @@ -1167,6 +1280,10 @@ __device__ __inline__ void ParallelReduce< bool read_pred, // Prevent reading from out of bounds memory bool write_pred, // Prevent from writing out of bounds const LocalTuple<Types...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, Func reduction_op, int64_t& cycles, int64_t& count) { @@ -1186,6 +1303,7 @@ __device__ __inline__ void ParallelReduce< read_pred, write_pred, init_val, + block_dim, reduction_op); if (isLastBlockInGrid() && @@ -1206,6 +1324,7 @@ template < bool BROADCAST> template < bool Aligned, + typename BlockDimT, typename... DataTypes, typename... Funcs, typename... BoolTypes> @@ -1223,6 +1342,10 @@ __device__ __inline__ void ParallelReduce< const ConstRefTuple<DataTypes...>& inp, VolatilePtrTuple<DataTypes...> global_work_buffer, const LocalTuple<DataTypes...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, int64_t* global_sync_buffer, void* shared_mem, const LocalTuple<BoolTypes...>& read_preds, @@ -1257,6 +1380,7 @@ __device__ __inline__ void ParallelReduce< const auto block_result = reduceGroupBlock < !GRID_REDUCE && BROADCAST, Aligned > (inp, init_val, + block_dim, shared_mem, read_preds, block_reduce_participate, @@ -1273,7 +1397,7 @@ __device__ __inline__ void ParallelReduce< // forward-protect the smem buffer. This block sync is not // necessary when a grid reduction follows since a block sync is // done just before the grid sync. 
- block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); return; } @@ -1309,13 +1433,13 @@ __device__ __inline__ void ParallelReduce< // How many grid reductions have to be performed, in the block dimension const auto num_thread_iters = index_utils:: maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - blockDim); + block_dim); // Which grid reduction does this thread participate in, in the block // dimension const auto thread_red_idx_offset = index_utils:: maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); // 3D buffer of reductions: // [reduction_offset(grid), iter_offset(grid), iter_offset(block)] @@ -1336,7 +1460,7 @@ __device__ __inline__ void ParallelReduce< gridDim) * index_utils:: maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - blockDim) * + block_dim) * grid_red_size; global_work_buffer += global_buffer_size; } @@ -1362,7 +1486,10 @@ __device__ __inline__ void ParallelReduce< isReduce(Z_BLOCK), PERSISTENT_REDUCTION, Aligned>( - global_sync_buffer[block_red_idx_offset], grid_red_size, last_block); + global_sync_buffer[block_red_idx_offset], + grid_red_size, + last_block, + block_dim); } // -- START BLOCK CLEANUP -- // @@ -1370,6 +1497,7 @@ __device__ __inline__ void ParallelReduce< out, global_work_buffer, init_val, + block_dim, shared_mem, block_red_idx_offset, num_thread_iters, @@ -1382,7 +1510,7 @@ __device__ __inline__ void ParallelReduce< funcs...); // Forward protect the smem buffer - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } template < @@ -1396,6 +1524,7 @@ template < bool BROADCAST> template < bool Aligned, + typename BlockDimT, typename... DataTypes, typename... Funcs, typename... BoolTypes> @@ -1413,6 +1542,10 @@ __device__ __inline__ void ParallelReduce< const ConstRefTuple<DataTypes...>& inp, VolatilePtrTuple<DataTypes...> global_work_buffer, const LocalTuple<DataTypes...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, int64_t* global_sync_buffer, void* shared_mem, const LocalTuple<BoolTypes...>& read_preds, @@ -1432,6 +1565,7 @@ __device__ __inline__ void ParallelReduce< inp, global_work_buffer, init_val, + block_dim, global_sync_buffer, shared_mem, read_preds, @@ -1457,6 +1591,7 @@ template < template < bool BLOCK_BROADCAST, bool Aligned, + typename BlockDimT, typename... DataTypes, typename... Funcs, typename... BoolTypes> @@ -1472,6 +1607,10 @@ __device__ __inline__ LocalTuple<DataTypes...> ParallelReduce< reduceGroupBlock( const ConstRefTuple<DataTypes...>& inp, const LocalTuple<DataTypes...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. 
+ BlockDimT block_dim, void* shared_mem, const LocalTuple<BoolTypes...>& read_preds, bool block_reduce_participate, @@ -1489,13 +1628,13 @@ __device__ __inline__ LocalTuple<DataTypes...> ParallelReduce< // to number of threads const int block_reduction_size = index_utils:: maskedSize<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>( - blockDim); + block_dim); // Index in the reduction segment, can be an int since it's limited to // number of threads const int tid_in_block_reduction = index_utils:: maskedOffset<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); // ID of the block reduction this thread is participating in // @@ -1505,7 +1644,7 @@ __device__ __inline__ LocalTuple<DataTypes...> ParallelReduce< // dimension const int block_reduction_idx = index_utils:: maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); // Do not protect the smem buffer as it's not always necessary. impl::blockReduceEach< @@ -1513,6 +1652,7 @@ __device__ __inline__ LocalTuple<DataTypes...> ParallelReduce< false, Aligned, LocalTuple<DataTypes...>, + BlockDimT, Funcs...>( block_result, block_result, @@ -1522,6 +1662,7 @@ __device__ __inline__ LocalTuple<DataTypes...> ParallelReduce< block_reduction_size, block_reduction_size, block_reduction_idx, + block_dim, funcs...); return block_result; @@ -1538,6 +1679,7 @@ template < bool BROADCAST> template < bool Aligned, + typename BlockDimT, typename... DataTypes, typename... Funcs, typename... BoolTypes> @@ -1554,6 +1696,10 @@ __device__ __inline__ void ParallelReduce< RefTuple<DataTypes...>& out, const VolatilePtrTuple<DataTypes...>& global_work_buffer, const LocalTuple<DataTypes...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. 
+ BlockDimT block_dim, void* shared_mem, nvfuser_index_t block_red_idx_offset, nvfuser_index_t num_thread_iters, @@ -1583,12 +1729,12 @@ __device__ __inline__ void ParallelReduce< int tid_in_block_reduction = index_utils::maskedOffset< activeNotIter(X_THREAD), activeNotIter(Y_THREAD), - activeNotIter(Z_THREAD)>(threadIdx, blockDim); + activeNotIter(Z_THREAD)>(threadIdx, block_dim); int block_reduction_size = index_utils::maskedSize< activeNotIter(X_THREAD), activeNotIter(Y_THREAD), - activeNotIter(Z_THREAD)>(blockDim); + activeNotIter(Z_THREAD)>(block_dim); bool has_block_result = index_utils::maskedIsZero< activeNotIter(X_THREAD), @@ -1620,13 +1766,14 @@ __device__ __inline__ void ParallelReduce< // Which block reduction this thread is participating in int block_reduction_idx = index_utils:: maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); impl::blockReduceEach< BROADCAST, false, Aligned, LocalTuple<DataTypes...>, + BlockDimT, Funcs...>( last_block_result, last_block_result, @@ -1636,6 +1783,7 @@ __device__ __inline__ void ParallelReduce< block_reduction_size, min(grid_red_size, block_reduction_size), block_reduction_idx, + block_dim, reduction_ops...); copyTupleIf( diff --git a/runtime/fused_welford_impl.cu b/runtime/fused_welford_impl.cu index 314a8b405b5..eeeebc36bdc 100644 --- a/runtime/fused_welford_impl.cu +++ b/runtime/fused_welford_impl.cu @@ -72,7 +72,8 @@ template < bool Aligned, int NumVals, typename DataType, - typename IndexType> + typename IndexType, + typename BlockDimT> struct BlockWelfordEach { __inline__ __device__ static void reduce( LocalWelfordTripletTuple<NumVals, DataType, IndexType>& block_result, @@ -83,7 +84,11 @@ struct BlockWelfordEach { int tid_in_reduction, int num_threads_per_reduction, int num_elements_per_reduction, - int reduction_idx) { + int reduction_idx, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. 
+ BlockDimT block_dim) { // Finish the reduction of each tuple value with a smaller offset BlockWelfordEach< idx - 1, @@ -92,7 +97,8 @@ struct BlockWelfordEach { Aligned, NumVals, DataType, - IndexType>:: + IndexType, + BlockDimT>:: reduce( block_result, partial_result, @@ -101,7 +107,8 @@ struct BlockWelfordEach { tid_in_reduction, num_threads_per_reduction, num_elements_per_reduction, - reduction_idx); + reduction_idx, + block_dim); if (num_elements_per_reduction == 1) { if (has_block_result) { @@ -125,7 +132,7 @@ struct BlockWelfordEach { copyTuple(shared_buf, smem_offset, block_result_i); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); if (tid_in_reduction < np2 && tid_in_reduction + np2 < num_elements_per_reduction) { impl::reduceTuple( @@ -141,7 +148,7 @@ struct BlockWelfordEach { } // Always sync when communicating across smem - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Reduce down to 2 values, last thread will do the final reduction and // can save a syncthreads this way @@ -154,7 +161,7 @@ struct BlockWelfordEach { smem_offset + factor, welfordCombine<DataType, IndexType>); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } copyTuple(block_result_i, shared_buf, smem_offset); @@ -180,7 +187,7 @@ struct BlockWelfordEach { } // Sync threads to make sure result is in smem - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); copyTuple( block_result_i, @@ -193,7 +200,7 @@ struct BlockWelfordEach { block_result.N.val<idx>(0) = block_result_i.val<2>(0); if (FORWARD_PROTECT_SMEM) { - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } } }; @@ -205,7 +212,8 @@ template < bool Aligned, int NumVals, typename DataType, - typename IndexType> + typename IndexType, + typename BlockDimT> struct BlockWelfordEach< -1, BROADCAST, @@ -213,7 +221,8 @@ struct BlockWelfordEach< Aligned, NumVals, DataType, - IndexType> { + IndexType, + BlockDimT> { __inline__ __device__ static void reduce( LocalWelfordTripletTuple<NumVals, DataType, IndexType>& block_result, const LocalWelfordTripletTuple<NumVals, DataType, IndexType>& @@ -223,7 +232,11 @@ struct BlockWelfordEach< int tid_in_reduction, int num_threads_per_reduction, int num_elements_per_reduction, - int reduction_idx) {} + int reduction_idx, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) {} }; //! Welford version of blockReduceEach. Perform block-parallel Welford @@ -234,7 +247,8 @@ template < bool Aligned, int NumVals, typename DataType, - typename IndexType> + typename IndexType, + typename BlockDimT> __inline__ __device__ void blockWelfordEach( LocalWelfordTripletTuple<NumVals, DataType, IndexType>& block_result, const LocalWelfordTripletTuple<NumVals, DataType, IndexType>& @@ -244,7 +258,11 @@ __inline__ __device__ void blockWelfordEach( int tid_in_reduction, int num_threads_per_reduction, int num_elements_per_reduction, - int reduction_idx) { + int reduction_idx, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. 
+ BlockDimT block_dim) { BlockWelfordEach< NumVals - 1, BROADCAST, @@ -252,7 +270,8 @@ __inline__ __device__ void blockWelfordEach( Aligned, NumVals, DataType, - IndexType>:: + IndexType, + BlockDimT>:: reduce( block_result, partial_result, @@ -261,7 +280,8 @@ __inline__ __device__ void blockWelfordEach( tid_in_reduction, num_threads_per_reduction, num_elements_per_reduction, - reduction_idx); + reduction_idx, + block_dim); } } // namespace impl @@ -275,7 +295,12 @@ template < int Z_THREAD, bool PERSISTENT_REDUCTION, bool BROADCAST> -template <bool Aligned, int NumArgs, typename DataType, typename IndexType> +template < + bool Aligned, + int NumArgs, + typename DataType, + typename IndexType, + typename BlockDimT> __device__ __inline__ void ParallelReduce< X_BLOCK, Y_BLOCK, @@ -295,6 +320,10 @@ __device__ __inline__ void ParallelReduce< const typename MakeLocalTuple<NumArgs, DataType>::type& init_avg, const typename MakeLocalTuple<NumArgs, DataType>::type& init_var, const typename MakeLocalTuple<NumArgs, IndexType>::type& init_N, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, typename MakeVolatilePtrTuple<NumArgs, DataType>::type global_work_buffer_avg, typename MakeVolatilePtrTuple<NumArgs, DataType>::type @@ -338,7 +367,12 @@ __device__ __inline__ void ParallelReduce< NumArgs, DataType, IndexType>( - block_result, inp, shared_buf, read_preds, block_reduce_participate); + block_result, + inp, + block_dim, + shared_buf, + read_preds, + block_reduce_participate); // If block reduction only, save to out and exit if (!GRID_REDUCE) { @@ -352,7 +386,7 @@ __device__ __inline__ void ParallelReduce< // forward-protect the smem buffer. This block sync is not // necessary when a grid reduction follows since a block sync is // done just before the grid sync. 
- block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); return; } @@ -388,13 +422,13 @@ __device__ __inline__ void ParallelReduce< // How many grid reductions have to be performed, in the block dimension const auto num_thread_iters = index_utils:: maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - blockDim); + block_dim); // Which grid reduction does this thread participate in, in the block // dimension const auto thread_red_idx_offset = index_utils:: maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); // 3D buffer of reductions: // [reduction_offset(grid), iter_offset(grid), iter_offset(block)] @@ -419,7 +453,7 @@ __device__ __inline__ void ParallelReduce< gridDim) * index_utils:: maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - blockDim) * + block_dim) * grid_red_size; global_work_buffer += global_buffer_size; } @@ -445,7 +479,10 @@ __device__ __inline__ void ParallelReduce< isReduce(Z_BLOCK), PERSISTENT_REDUCTION, Aligned>( - global_sync_buffer[block_red_idx_offset], grid_red_size, last_block); + global_sync_buffer[block_red_idx_offset], + grid_red_size, + last_block, + block_dim); } // -- START BLOCK CLEANUP -- // @@ -454,6 +491,7 @@ __device__ __inline__ void ParallelReduce< global_work_buffer, LocalWelfordTripletTuple<NumArgs, DataType, IndexType>( init_avg, init_var, init_N), + block_dim, shared_buf, block_red_idx_offset, num_thread_iters, @@ -465,7 +503,7 @@ __device__ __inline__ void ParallelReduce< grid_reduce_participate); // Forward protect the smem buffer - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } template < @@ -477,7 +515,12 @@ template < int Z_THREAD, bool PERSISTENT_REDUCTION, bool BROADCAST> -template <bool Aligned, int NumArgs, typename DataType, typename IndexType> +template < + bool Aligned, + int NumArgs, + typename DataType, + typename IndexType, + typename BlockDimT> __device__ __inline__ void ParallelReduce< X_BLOCK, Y_BLOCK, @@ -497,6 +540,10 @@ __device__ __inline__ void ParallelReduce< const typename MakeLocalTuple<NumArgs, DataType>::type& init_avg, const typename MakeLocalTuple<NumArgs, DataType>::type& init_var, const typename MakeLocalTuple<NumArgs, IndexType>::type& init_N, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, typename MakeVolatilePtrTuple<NumArgs, DataType>::type global_work_buffer_avg, typename MakeVolatilePtrTuple<NumArgs, DataType>::type @@ -526,6 +573,7 @@ __device__ __inline__ void ParallelReduce< init_avg, init_var, init_N, + block_dim, global_work_buffer_avg, global_work_buffer_var, global_work_buffer_N, @@ -555,7 +603,8 @@ template < bool Aligned, int NumVals, typename DataType, - typename IndexType> + typename IndexType, + typename BlockDimT> __device__ __inline__ void ParallelReduce< X_BLOCK, Y_BLOCK, @@ -568,6 +617,10 @@ __device__ __inline__ void ParallelReduce< welfordGroupBlock( LocalWelfordTripletTuple<NumVals, DataType, IndexType>& block_result, const ConstRefWelfordTripletTuple<NumVals, DataType, IndexType>& inp, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. 
+ BlockDimT block_dim, PtrTuple<DataType, DataType, IndexType> shared_buf, const typename MakeLocalTuple<NumVals, bool>::type& read_preds, bool block_reduce_participate) { @@ -582,13 +635,13 @@ __device__ __inline__ void ParallelReduce< // to number of threads const int block_reduction_size = index_utils:: maskedSize<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>( - blockDim); + block_dim); // Index in the reduction segment, can be an int since it's limited to // number of threads const int tid_in_block_reduction = index_utils:: maskedOffset<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); // ID of the block reduction this thread is participating in // @@ -598,7 +651,7 @@ __device__ __inline__ void ParallelReduce< // dimension const int block_reduction_idx = index_utils:: maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); // Do not protect the smem buffer as it's not always necessary. impl::blockWelfordEach< @@ -607,7 +660,8 @@ __device__ __inline__ void ParallelReduce< Aligned, NumVals, DataType, - IndexType>( + IndexType, + BlockDimT>( block_result, block_result, shared_buf, @@ -615,7 +669,8 @@ __device__ __inline__ void ParallelReduce< tid_in_block_reduction, block_reduction_size, block_reduction_size, - block_reduction_idx); + block_reduction_idx, + block_dim); } template < @@ -627,7 +682,12 @@ template < int Z_THREAD, bool PERSISTENT_REDUCTION, bool BROADCAST> -template <bool Aligned, int NumVals, typename DataType, typename IndexType> +template < + bool Aligned, + int NumVals, + typename DataType, + typename IndexType, + typename BlockDimT> __device__ __inline__ void ParallelReduce< X_BLOCK, Y_BLOCK, @@ -642,6 +702,10 @@ __device__ __inline__ void ParallelReduce< const VolatilePtrWelfordTripletTuple<NumVals, DataType, IndexType>& global_work_buffer, const LocalWelfordTripletTuple<NumVals, DataType, IndexType>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. 
+ BlockDimT block_dim, PtrTuple<DataType, DataType, IndexType> shared_buf, nvfuser_index_t block_red_idx_offset, nvfuser_index_t num_thread_iters, @@ -670,12 +734,12 @@ __device__ __inline__ void ParallelReduce< int tid_in_block_reduction = index_utils::maskedOffset< activeNotIter(X_THREAD), activeNotIter(Y_THREAD), - activeNotIter(Z_THREAD)>(threadIdx, blockDim); + activeNotIter(Z_THREAD)>(threadIdx, block_dim); int block_reduction_size = index_utils::maskedSize< activeNotIter(X_THREAD), activeNotIter(Y_THREAD), - activeNotIter(Z_THREAD)>(blockDim); + activeNotIter(Z_THREAD)>(block_dim); bool has_block_result = index_utils::maskedIsZero< activeNotIter(X_THREAD), @@ -700,7 +764,7 @@ __device__ __inline__ void ParallelReduce< // Which block reduction this thread is participating in int block_reduction_idx = index_utils:: maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); impl::blockWelfordEach< BROADCAST, @@ -716,7 +780,8 @@ __device__ __inline__ void ParallelReduce< tid_in_block_reduction, block_reduction_size, min(grid_red_size, block_reduction_size), - block_reduction_idx); + block_reduction_idx, + block_dim); copyWelfordTripletTupleIf( out, diff --git a/runtime/fused_welford_impl_outer.cu b/runtime/fused_welford_impl_outer.cu index bd705bbec71..4d314c3bbac 100644 --- a/runtime/fused_welford_impl_outer.cu +++ b/runtime/fused_welford_impl_outer.cu @@ -212,7 +212,13 @@ template < int Z_THREAD, bool PERSISTENT_REDUCTION, bool BROADCAST> -template <bool Aligned, int NumVals, typename DataType, int BDIMX, int BDIMY> +template < + bool Aligned, + int NumVals, + typename DataType, + int BDIMX, + int BDIMY, + typename BlockDimT> __device__ __inline__ void ParallelReduce< X_BLOCK, Y_BLOCK, @@ -229,6 +235,10 @@ __device__ __inline__ void ParallelReduce< const DataType in_avg[NumVals], const DataType in_var[NumVals], nvfuser_index_t in_N, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, DataType* global_buf_avg, DataType* global_buf_var, nvfuser_index_t* global_buf_N, @@ -258,11 +268,11 @@ __device__ __inline__ void ParallelReduce< auto iter_tid = index_utils:: maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); auto per_block_result = impl::blockWelfordOuter<Aligned, NumVals, DataType, BDIMX, BDIMY>( - out_avg, out_var, in_N, shared_buf); + out_avg, out_var, in_N, block_dim, shared_buf); // At this point, threads with tid_in_group == 0 has valid partial // results. Store them to global buffer. 
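
The same `BlockDimT block_dim` parameter is threaded through every signature in these runtime files, replacing direct reads of `blockDim`. As a minimal sketch of the call-site pattern in isolation (the `linearTid` helper and the 128-thread compute-warp shape below are illustrative, not nvFuser code):

```cpp
#include <cuda_runtime.h>

// Illustrative stand-in for the masked index utilities: any type exposing
// .x/.y/.z works as block_dim, whether it is blockDim itself or a dim3
// describing only the compute warps.
template <typename BlockDimT>
__device__ unsigned linearTid(BlockDimT block_dim) {
  return threadIdx.x + threadIdx.y * block_dim.x +
      threadIdx.z * block_dim.x * block_dim.y;
}

__global__ void exampleKernel(unsigned* out) {
  // Without warp specialization: pass the launch shape through unchanged.
  unsigned tid_full_cta = linearTid(blockDim);
  // With warp specialization (hypothetical split): index only within the
  // 128 compute threads, ignoring the warp group reserved for loads.
  unsigned tid_compute = linearTid(dim3(128, 1, 1));
  if (tid_full_cta == 0) {
    out[0] = tid_compute;
  }
}
```
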
@@ -310,7 +320,8 @@ __device__ __inline__ void ParallelReduce< isReduce(Y_BLOCK), isReduce(Z_BLOCK), PERSISTENT_REDUCTION, - Aligned>(global_sync_buffer[blockIdx.x], gridDim.y, last_block); + Aligned>( + global_sync_buffer[blockIdx.x], gridDim.y, last_block, block_dim); auto partial_results = welfordGroupAccumulateGlobalBuffer<NumVals, DataType, BDIMX, BDIMY>( @@ -321,6 +332,7 @@ __device__ __inline__ void ParallelReduce< partial_results.avg_.array, partial_results.var_.array, partial_results.N_, + block_dim, shared_buf); // At this point, each thread of the groups with tid_in_group=0 @@ -361,7 +373,13 @@ template < int Z_THREAD, bool PERSISTENT_REDUCTION, bool BROADCAST> -template <bool Aligned, int NumVals, typename DataType, int BDIMX, int BDIMY> +template < + bool Aligned, + int NumVals, + typename DataType, + int BDIMX, + int BDIMY, + typename BlockDimT> __device__ __inline__ void ParallelReduce< X_BLOCK, Y_BLOCK, @@ -378,6 +396,10 @@ __device__ __inline__ void ParallelReduce< const DataType in_avg[NumVals], const DataType in_var[NumVals], nvfuser_index_t in_N, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, DataType* global_buf_avg, DataType* global_buf_var, nvfuser_index_t* global_buf_N, @@ -399,6 +421,7 @@ __device__ __inline__ void ParallelReduce< in_avg, in_var, in_N, + block_dim, global_buf_avg, global_buf_var, global_buf_N, diff --git a/runtime/grid_broadcast.cu b/runtime/grid_broadcast.cu index 5f3db3ffd4a..8beb4055f7c 100644 --- a/runtime/grid_broadcast.cu +++ b/runtime/grid_broadcast.cu @@ -28,13 +28,18 @@ template < bool Y_THREAD, bool Z_THREAD, bool Aligned, - typename T> + typename T, + typename BlockDimT> __device__ void broadcast( T& out, const T& inp_val, volatile T* work_buf, Tensor<int64_t, 1> sync_flags, - bool read_write_pred) { + bool read_write_pred, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. 
+ BlockDimT block_dim) { // Number of values broadcasted in the grid dimensions const auto grid_seg_size = index_utils::maskedSize<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim); @@ -47,13 +52,13 @@ __device__ void broadcast( // Number of threads not participating in a broadcast dimension, this is the // number of thread entries to expect in the work buffer, therefore a striding const auto block_stride = - index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim); + index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(block_dim); // Which broadcast in the block this is to line up the entry with the work // buffer const auto thread_offset = index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); const bool has_valid_data = (!X_BLOCK || blockIdx.x == gridDim.x - 1) && (!Y_BLOCK || blockIdx.y == gridDim.y - 1) && @@ -67,7 +72,7 @@ __device__ void broadcast( } grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, true, Aligned>( - sync_flags[grid_seg_idx], grid_seg_size); + sync_flags[grid_seg_idx], grid_seg_size, block_dim); if (read_write_pred) { out = work_buf[grid_seg_idx * block_stride + thread_offset]; @@ -76,6 +81,6 @@ __device__ void broadcast( // Make sure everyone has read from the buffer before continuing the kernel // and potentially overwriting grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, true, Aligned>( - sync_flags[grid_seg_idx], grid_seg_size); + sync_flags[grid_seg_idx], grid_seg_size, block_dim); } } // namespace grid_broadcast diff --git a/runtime/grid_reduction.cu b/runtime/grid_reduction.cu index c4a8638910f..3a81432e27a 100644 --- a/runtime/grid_reduction.cu +++ b/runtime/grid_reduction.cu @@ -75,7 +75,8 @@ template < bool Z_THREAD, bool Aligned, typename T, - typename Func> + typename Func, + typename BlockDimT> __device__ void gridReduceLastBlock( T& out, const volatile T* in, @@ -87,7 +88,11 @@ __device__ void gridReduceLastBlock( Func reduction_op, T* shared_buf, bool write_pred, - T init_val) { + T init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { // We have to do num_reductions across reduction_size. The reductions are // contiguous, but offset by reduction_size. There is an entry in "in" for // every block, and every thread marked as true. 
Threads in dimensions marked @@ -96,18 +101,18 @@ __device__ void gridReduceLastBlock( // Find the reduction id of the participating threads const auto block_reduction_segment_idx = index_utils::maskedOffset<X_THREAD, Y_THREAD, Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); // Find an id associated within a reduction segment for all // "non-participating" threads, which will parallelize the reductions for the // "participating" threads const auto id_in_block_segment = index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); // Stride by the "non-participating" threads const auto input_stride_for_thread_in_segment = - index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim); + index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(block_dim); T inp = init_val; @@ -123,7 +128,7 @@ __device__ void gridReduceLastBlock( // Block reduce the per thread values into per "participating" thread values T inp_tmp = init_val; blockReduce<!X_THREAD, !Y_THREAD, !Z_THREAD, Aligned>( - inp_tmp, inp, reduction_op, shared_buf, true, init_val); + inp_tmp, inp, reduction_op, shared_buf, true, init_val, block_dim); const bool should_write = (X_THREAD || threadIdx.x == 0) && (Y_THREAD || threadIdx.y == 0) && (Z_THREAD || threadIdx.z == 0); if (should_write && write_pred) { @@ -191,7 +196,8 @@ template < bool PERSISTENT_REDUCTION, bool Aligned, typename T, - typename Func> + typename Func, + typename BlockDimT> __device__ void gridReduce( T& out, const T& inp_val, @@ -203,7 +209,11 @@ __device__ void gridReduce( bool write_pred, T init_val, const nvfuser_index_t entrance_ind, - const nvfuser_index_t n_entrances) { + const nvfuser_index_t n_entrances, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. 
+ BlockDimT block_dim) { T block_reduction_val = init_val; // Do block reduction when required @@ -215,7 +225,8 @@ __device__ void gridReduce( shared_buf, read_pred, true, - init_val); + init_val, + block_dim); } else if (read_pred) { block_reduction_val = inp_val; } @@ -233,7 +244,7 @@ __device__ void gridReduce( // Number of threads we can use in final reduction, Seems to assume all // threads in the block participate const auto block_reduction_segment_size = - index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim); + index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(block_dim); // Number of reductions in the grid const nvfuser_index_t grid_segment_size = PERSISTENT_REDUCTION @@ -251,20 +262,23 @@ __device__ void gridReduce( index_utils::maskedOffset<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim); auto thread_offset = index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); auto work_buf_offset = block_offset * block_reduction_segment_size + thread_offset; work_buf[work_buf_offset] = block_reduction_val; } if (PERSISTENT_REDUCTION) { grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + sync_flags[idx_in_grid_segment], + grid_reduction_segment_size, + block_dim); } else { // Use a different sync flag for each call grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( sync_flags[entrance_ind * grid_segment_size + idx_in_grid_segment], - grid_reduction_segment_size); + grid_reduction_segment_size, + block_dim); } bool last_block = @@ -280,14 +294,17 @@ __device__ void gridReduce( reduction_op, shared_buf, write_pred, - init_val); + init_val, + block_dim); } if (PERSISTENT_REDUCTION) { // Make sure we're done with global memory before we allow the kernel to // continue grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + sync_flags[idx_in_grid_segment], + grid_reduction_segment_size, + block_dim); } } @@ -306,7 +323,8 @@ template < bool PERSISTENT_REDUCTION, bool Aligned, typename T, - typename Func> + typename Func, + typename BlockDimT> __device__ void gridReduce( T& out, const T& inp_val, @@ -319,6 +337,10 @@ __device__ void gridReduce( T init_val, const nvfuser_index_t entrance_ind, const nvfuser_index_t n_entrances, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, int64_t& cycles, int64_t& count) { int64_t start_counter = 0; @@ -349,7 +371,8 @@ __device__ void gridReduce( write_pred, init_val, entrance_ind, - n_entrances); + n_entrances, + block_dim); if (index_utils::maskedIsLast<true, true, true>(blockIdx, gridDim) && index_utils::maskedIsZero<true, true, true>(threadIdx)) { @@ -368,11 +391,16 @@ template < bool Z_THREAD, bool Aligned, typename T, - typename Func> + typename Func, + typename BlockDimT> __device__ void gridReduce2PartialReduction( const T& inp_val, T init_val, Func reduction_op, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. 
+ BlockDimT block_dim, volatile T* work_buf, T* shared_buf, bool read_pred, @@ -390,7 +418,8 @@ __device__ void gridReduce2PartialReduction( shared_buf, read_pred, true, - init_val); + init_val, + block_dim); } else if (read_pred) { block_reduction_val = inp_val; } @@ -401,7 +430,7 @@ __device__ void gridReduce2PartialReduction( index_utils::maskedOffset<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim); auto thread_offset = index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); auto work_buf_offset = block_offset * block_reduction_segment_size + thread_offset; work_buf[work_buf_offset] = block_reduction_val; @@ -421,7 +450,8 @@ template < typename T1, typename Func1, typename T2, - typename Func2> + typename Func2, + typename BlockDimT> __device__ void gridReduceGroup( T1& out1, const T1& inp_val1, @@ -438,7 +468,11 @@ __device__ void gridReduceGroup( bool read_pred, bool write_pred, const nvfuser_index_t entrance_ind, - const nvfuser_index_t n_entrances) { + const nvfuser_index_t n_entrances, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { // Number of values to reduce in the reduction segment const auto grid_reduction_segment_size = index_utils::maskedSize<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim); @@ -452,7 +486,7 @@ __device__ void gridReduceGroup( // Number of threads we can use in final reduction, Seems to assume all // threads in the block participate const auto block_reduction_segment_size = - index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim); + index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(block_dim); // Number of reductions in the grid const nvfuser_index_t grid_segment_size = PERSISTENT_REDUCTION @@ -478,6 +512,7 @@ __device__ void gridReduceGroup( inp_val1, init_val1, reduction_op1, + block_dim, work_buf1, (T1*)shared_buf, read_pred, @@ -496,6 +531,7 @@ __device__ void gridReduceGroup( inp_val2, init_val2, reduction_op2, + block_dim, work_buf2, (T2*)shared_buf, read_pred, @@ -505,11 +541,14 @@ __device__ void gridReduceGroup( if (PERSISTENT_REDUCTION) { grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + sync_flags[idx_in_grid_segment], + grid_reduction_segment_size, + block_dim); } else { grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( sync_flags[entrance_ind * grid_segment_size + idx_in_grid_segment], - grid_reduction_segment_size); + grid_reduction_segment_size, + block_dim); } bool last_block = @@ -525,7 +564,8 @@ __device__ void gridReduceGroup( reduction_op1, (T1*)shared_buf, write_pred, - init_val1); + init_val1, + block_dim); gridReduceLastBlock<!X_THREAD, !Y_THREAD, !Z_THREAD, Aligned>( out2, work_buf2, @@ -534,14 +574,17 @@ __device__ void gridReduceGroup( reduction_op2, (T2*)shared_buf, write_pred, - init_val2); + init_val2, + block_dim); } if (PERSISTENT_REDUCTION) { // Make sure we're done with global memory before we allow the kernel to // continue grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + sync_flags[idx_in_grid_segment], + grid_reduction_segment_size, + block_dim); } } @@ -558,7 +601,8 @@ template < typename T1, typename Func1, typename T2, - typename Func2> + typename Func2, + typename 
BlockDimT> __device__ void gridReduceGroup( T1& out1, const T1& inp_val1, @@ -576,6 +620,10 @@ __device__ void gridReduceGroup( bool write_pred, const nvfuser_index_t entrance_ind, const nvfuser_index_t n_entrances, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, int64_t& cycles, int64_t& count) { int64_t start_counter = 0; @@ -613,7 +661,8 @@ __device__ void gridReduceGroup( read_pred, write_pred, entrance_ind, - n_entrances); + n_entrances, + block_dim); if (index_utils::maskedIsLast<true, true, true>(blockIdx, gridDim) && index_utils::maskedIsZero<true, true, true>(threadIdx)) { @@ -705,7 +754,8 @@ template < bool Aligned, int vec_size, typename T, - typename Func> + typename Func, + typename BlockDimT> __device__ void iterGroupedGridReduceLastBlock( T* out, const volatile T* in, @@ -719,7 +769,11 @@ __device__ void iterGroupedGridReduceLastBlock( bool write_pred, T init_val, const nvfuser_index_t grid_segment_size, - const nvfuser_index_t idx_in_grid_segment) { + const nvfuser_index_t idx_in_grid_segment, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { // We have to do num_reductions across reduction_size. The reductions are // contiguous, but offset by reduction_size. There is an entry in "in" for // every block, and every thread marked as true. Threads in dimensions marked @@ -728,14 +782,14 @@ __device__ void iterGroupedGridReduceLastBlock( // Find the reduction id of the participating threads const auto block_reduction_segment_idx = index_utils::maskedOffset<X_THREAD, Y_THREAD, Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); // Find an id associated within a reduction segment for all // "non-participating" threads, which will parallelize the reductions for the // "participating" threads const auto id_in_block_segment = index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); // index into iteration dim. // Its calculation is same to that in [iterGroupedGridReduce]. Becuase when @@ -743,11 +797,11 @@ __device__ void iterGroupedGridReduceLastBlock( // X_THREAD, Y_THREAD, Z_THREAD are flipped. 
const auto thread_offset = index_utils::maskedOffset<X_THREAD, Y_THREAD, Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); // Stride by the "non-participating" threads const auto input_stride_for_thread_in_segment = - index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim); + index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(block_dim); constexpr unsigned int max_align_bytes = 16; constexpr unsigned int vec_bytes = sizeof(T) * vec_size; @@ -814,7 +868,7 @@ __device__ void iterGroupedGridReduceLastBlock( inp_tmp[i] = init_val; } blockIterGroupedYdimReduce<Aligned, vec_size>( - inp_tmp, inp, reduction_op, shared_buf, true, init_val); + inp_tmp, inp, reduction_op, shared_buf, true, init_val, block_dim); const bool should_write = (X_THREAD || threadIdx.x == 0) && (Y_THREAD || threadIdx.y == 0) && (Z_THREAD || threadIdx.z == 0); if (should_write && write_pred) { @@ -846,7 +900,8 @@ template < bool Aligned, int vec_size, typename T, - typename Func> + typename Func, + typename BlockDimT> __device__ void iterGroupedGridReduce( T* out, const T* inp_val, @@ -856,7 +911,11 @@ __device__ void iterGroupedGridReduce( T* shared_buf, bool read_pred, bool write_pred, - T init_val) { + T init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { // inp or block reduction results T block_reduction_val[vec_size]; @@ -873,7 +932,8 @@ __device__ void iterGroupedGridReduce( shared_buf, read_pred, true, - init_val); + init_val, + block_dim); } else if (read_pred) { #pragma unroll for (int i = 0; i < vec_size; i++) { @@ -893,7 +953,7 @@ __device__ void iterGroupedGridReduce( // Number of reductions in each block const auto block_segment_size = - index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim); + index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(block_dim); // Number of reductions in the grid const nvfuser_index_t grid_segment_size = PERSISTENT_REDUCTION @@ -908,7 +968,7 @@ __device__ void iterGroupedGridReduce( index_utils::maskedOffset<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim); auto thread_offset = index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); // Max vectorized load/store size is 16 bytes, if each thread has more than // 16 bytes, split into multiple sections to ensure each thread occupies @@ -959,12 +1019,16 @@ __device__ void iterGroupedGridReduce( if (PERSISTENT_REDUCTION) { grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + sync_flags[idx_in_grid_segment], + grid_reduction_segment_size, + block_dim); } else { // there is only one vectorized call grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + sync_flags[idx_in_grid_segment], + grid_reduction_segment_size, + block_dim); } bool last_block = @@ -987,14 +1051,17 @@ __device__ void iterGroupedGridReduce( write_pred, init_val, grid_segment_size, - idx_in_grid_segment); + idx_in_grid_segment, + block_dim); } if (PERSISTENT_REDUCTION) { // Make sure we're done with global memory before we allow the kernel to // continue grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + 
sync_flags[idx_in_grid_segment], + grid_reduction_segment_size, + block_dim); } } } // namespace reduction diff --git a/runtime/grid_sync.cu b/runtime/grid_sync.cu index 4e5289323b1..ba7c115a4d0 100644 --- a/runtime/grid_sync.cu +++ b/runtime/grid_sync.cu @@ -29,16 +29,21 @@ template < bool Y_BLOCK, bool Z_BLOCK, bool PERSISTENT, - bool Aligned> + bool Aligned, + typename BlockDimT> __device__ void sync( int64_t& semaphore, const uint64_t& segment_size, - const bool last_block) { + const bool last_block, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { // Finish all global memory transactions before synchronizing __threadfence(); // Synchronize all threads in a block before synchronizing blocks - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Only allow linear_tid == 0 to participate in the synchronization if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { @@ -78,7 +83,7 @@ __device__ void sync( } // Sync block to make sure all other threads are waiting on the sync - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } template < @@ -86,12 +91,20 @@ template < bool Y_BLOCK, bool Z_BLOCK, bool PERSISTENT, - bool Aligned> -__device__ void sync(int64_t& semaphore, const uint64_t& segment_size) { + bool Aligned, + typename BlockDimT> +__device__ void sync( + int64_t& semaphore, + const uint64_t& segment_size, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT, Aligned>( semaphore, segment_size, - index_utils::maskedIsLast<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim)); + index_utils::maskedIsLast<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim), + block_dim); } // Grid sync that can be called multiple times in the same kernel without all @@ -105,16 +118,25 @@ __device__ void sync(int64_t& semaphore, const uint64_t& segment_size) { // // Note that this is not currently used by grid and welford reduction // as they use a separate sync flag for each each grid sync call. -template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, bool Aligned> +template < + bool X_BLOCK, + bool Y_BLOCK, + bool Z_BLOCK, + bool Aligned, + typename BlockDimT> __device__ void sync( int64_t& semaphore, const uint64_t& segment_size, - const nvfuser_index_t n_entrances) { + const nvfuser_index_t n_entrances, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. 
+ BlockDimT block_dim) { // Finish all global memory transactions before synchronizing __threadfence(); // Synchronize all threads in a block before synchronizing blocks - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Only allow linear_tid == 0 to participate in the synchronization if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { @@ -147,7 +169,7 @@ __device__ void sync( } // Sync block to make sure all other threads are waiting on the sync - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } // Non-blocking function to read the semaphore value in each calling thread diff --git a/runtime/warp.cu b/runtime/warp.cu index 03d78e74798..5678e4d05b0 100644 --- a/runtime/warp.cu +++ b/runtime/warp.cu @@ -21,14 +21,23 @@ __device__ __forceinline__ std::complex<T> shfl_xor( return std::complex<T>(real, imag); } -template <bool SINGLE_WARP, bool Aligned, typename T, typename Func> +template < + bool SINGLE_WARP, + bool Aligned, + typename T, + typename Func, + typename BlockDimT> __device__ void warpReduceTIDX( T& out, const T& inp_val, Func reduction_op, T* shared_mem, bool read_write_pred, - T init_val) { + T init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { constexpr int WARP_SIZE = 32; // Assume input padded to multiples of a warp @@ -49,19 +58,19 @@ __device__ void warpReduceTIDX( if (!SINGLE_WARP) { unsigned int warp_idx = threadIdx.x / WARP_SIZE; unsigned int lane_idx = threadIdx.x % WARP_SIZE; - unsigned int reduce_group_id = threadIdx.z * blockDim.y + threadIdx.y; + unsigned int reduce_group_id = threadIdx.z * block_dim.y + threadIdx.y; bool is_warp_head = lane_idx == 0; - unsigned int reduction_size = blockDim.x; + unsigned int reduction_size = block_dim.x; unsigned int num_of_warps = reduction_size / WARP_SIZE; unsigned int smem_offset = reduce_group_id * num_of_warps; - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); if (is_warp_head) { shared_mem[smem_offset + warp_idx] = reduce_val; } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); if (warp_idx == 0) { // This assumes num_of_warps will be < 32, meaning < 1024 threads. @@ -82,20 +91,30 @@ __device__ void warpReduceTIDX( } // needs sync, otherwise other warps may access shared memory before this // reduction is done. - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } else { reduction_op(out, reduce_val); } } -template <int BDIMX, int BDIMY, bool Aligned, typename T, typename Func> +template < + int BDIMX, + int BDIMY, + bool Aligned, + typename T, + typename Func, + typename BlockDimT> __device__ void warpReduceTIDXY( T& out, const T& inp_val, Func reduction_op, T* shared_mem, bool read_write_pred, - T init_val) { + T init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. 
+ BlockDimT block_dim) { constexpr int WARP_SIZE = 32; constexpr int num_of_warps = BDIMX * BDIMY / WARP_SIZE; @@ -118,11 +137,11 @@ __device__ void warpReduceTIDXY( unsigned int idx = threadIdx.x + threadIdx.y * BDIMX; unsigned int warp_idx = idx / WARP_SIZE; unsigned int lane_idx = idx % WARP_SIZE; - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); if (lane_idx == 0) { shared_mem[warp_idx] = reduce_val; } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); if (warp_idx == 0) { reduce_val = lane_idx < num_of_warps ? shared_mem[lane_idx] : init_val; @@ -137,7 +156,7 @@ __device__ void warpReduceTIDXY( } // needs sync, otherwise other warps may access shared memory before this // reduction is done. - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } else { reduction_op(out, reduce_val); } diff --git a/runtime/welford.cu b/runtime/welford.cu index aba379bffcb..35c5f69109b 100644 --- a/runtime/welford.cu +++ b/runtime/welford.cu @@ -114,7 +114,8 @@ template < bool Z_REDUCE, bool Aligned, typename T, - typename TN> + typename TN, + typename BlockDimT> __inline__ __device__ void blockWelford( T& out_avg, T& out_M2, @@ -127,24 +128,28 @@ __inline__ __device__ void blockWelford( TN* shared_mem_N, bool read_pred, bool write_pred, - T init_val) { + T init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { // If this thread will output a final result bool should_write = index_utils::maskedIsZero<X_REDUCE, Y_REDUCE, Z_REDUCE>(threadIdx); // Size of the reduction segments unsigned int reduction_size = - index_utils::maskedSize<X_REDUCE, Y_REDUCE, Z_REDUCE>(blockDim); + index_utils::maskedSize<X_REDUCE, Y_REDUCE, Z_REDUCE>(block_dim); // Index into the reduction segment unsigned int reduction_tid = index_utils::maskedOffset<X_REDUCE, Y_REDUCE, Z_REDUCE>( - threadIdx, blockDim); + threadIdx, block_dim); // Index of the reduction segment unsigned int reduction_idx = index_utils::maskedOffset<!X_REDUCE, !Y_REDUCE, !Z_REDUCE>( - threadIdx, blockDim); + threadIdx, block_dim); // Offset into smem for the current thread unsigned int smem_offset = reduction_idx * reduction_size + reduction_tid; @@ -159,7 +164,7 @@ __inline__ __device__ void blockWelford( shared_mem_N[smem_offset] = 0; } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Reduce down to nearest power of 2: int np2 = 1 << (31 - __clz(reduction_size)); @@ -172,7 +177,7 @@ __inline__ __device__ void blockWelford( shared_mem_M2[smem_offset + np2], shared_mem_N[smem_offset + np2]); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // loop peel the final iteration to save one syncthread for the end for (int factor = np2 / 2; factor > 1; factor >>= 1) { @@ -185,7 +190,7 @@ __inline__ __device__ void blockWelford( shared_mem_M2[smem_offset + factor], shared_mem_N[smem_offset + factor]); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } if (should_write && write_pred) { @@ -212,7 +217,7 @@ __inline__ __device__ void blockWelford( out_M2 = res_M2; out_N = res_N; } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } // Use the same pred for both reads and writes @@ -222,7 +227,8 @@ template < bool Z_REDUCE, bool Aligned, typename T, - typename TN> + typename TN, + typename BlockDimT> __inline__ __device__ void 
blockWelford( T& out_avg, T& out_M2, @@ -234,7 +240,11 @@ __inline__ __device__ void blockWelford( T* shared_mem_M2, TN* shared_mem_N, bool read_write_pred, - T init_val) { + T init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { blockWelford<X_REDUCE, Y_REDUCE, Z_REDUCE, Aligned, T, TN>( out_avg, out_M2, @@ -247,7 +257,8 @@ __inline__ __device__ void blockWelford( shared_mem_N, read_write_pred, read_write_pred, - init_val); + init_val, + block_dim); } // ----------------------------------------------------------------------------------------------- // Grid Welford Prototype @@ -260,7 +271,8 @@ template < bool Z_THREAD, bool Aligned, typename T, - typename TN> + typename TN, + typename BlockDimT> __device__ void gridWelfordLastBlock( T& out_avg, T& out_M2, @@ -271,8 +283,11 @@ __device__ void gridWelfordLastBlock( const nvfuser_index_t grid_reduction_segment_size, // Number of reductions across // grid reduce dimensions - const nvfuser_index_t - block_reduction_segment_size, // Number of reductions across the block + const nvfuser_index_t block_reduction_segment_size, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, T* shared_buf_avg, T* shared_buf_M2, TN* shared_buf_N, @@ -286,18 +301,18 @@ __device__ void gridWelfordLastBlock( // Find the reduction id of the participating threads const auto block_reduction_segment_idx = index_utils::maskedOffset<X_THREAD, Y_THREAD, Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); // Find an id associated within a reduction segment for all // "non-participating" threads, which will parallelize the reductions for the // "participating" threads const auto id_in_block_segment = index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); // Stride by the "non-participating" threads const auto input_stride_for_thread_in_segment = - index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim); + index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(block_dim); T inp_avg = init_val; T inp_M2 = init_val; @@ -333,7 +348,8 @@ __device__ void gridWelfordLastBlock( shared_buf_M2, shared_buf_N, true, - init_val); + init_val, + block_dim); const bool should_write = (X_THREAD || threadIdx.x == 0) && (Y_THREAD || threadIdx.y == 0) && (Z_THREAD || threadIdx.z == 0); if (should_write && write_pred) { @@ -352,7 +368,8 @@ template < bool PERSISTENT_REDUCTION, bool Aligned, typename T, - typename TN> + typename TN, + typename BlockDimT> __device__ void gridWelford( T& out_avg, T& out_M2, @@ -371,7 +388,11 @@ __device__ void gridWelford( bool write_pred, T init_val, const nvfuser_index_t entrance_ind, - const nvfuser_index_t n_entrances) { + const nvfuser_index_t n_entrances, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { // entrance index only matters for non-persistent re-entrant grid reductions. const nvfuser_index_t entrance_ind_ = PERSISTENT_REDUCTION ? 0 : entrance_ind; const nvfuser_index_t n_entrances_ = PERSISTENT_REDUCTION ? 
1 : n_entrances; @@ -389,7 +410,7 @@ __device__ void gridWelford( // Number of threads we can use in final reduction, Seems to assume all // threads in the block participate const auto block_reduction_segment_size = - index_utils::maskedSize<X_THREAD, Y_THREAD, Z_THREAD>(blockDim); + index_utils::maskedSize<X_THREAD, Y_THREAD, Z_THREAD>(block_dim); // Number of reductions in the grid const nvfuser_index_t grid_segment_size = PERSISTENT_REDUCTION @@ -411,7 +432,7 @@ __device__ void gridWelford( index_utils::maskedOffset<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim); auto thread_offset = index_utils::maskedOffset<X_THREAD, Y_THREAD, Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); auto work_buf_offset = block_offset * block_reduction_segment_size + thread_offset; if (read_pred) { @@ -427,12 +448,15 @@ __device__ void gridWelford( if (PERSISTENT_REDUCTION) { grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + sync_flags[idx_in_grid_segment], + grid_reduction_segment_size, + block_dim); } else { // Use a different sync flag for each call grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( sync_flags[entrance_ind_ * grid_segment_size + idx_in_grid_segment], - grid_reduction_segment_size); + grid_reduction_segment_size, + block_dim); } bool last_block = @@ -449,6 +473,7 @@ __device__ void gridWelford( work_buf_N, grid_reduction_segment_size, block_reduction_segment_size, + block_dim, shared_buf_avg, shared_buf_M2, shared_buf_N, @@ -460,7 +485,9 @@ __device__ void gridWelford( // Make sure we're done with global memory before we allow the kernel to // continue grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + sync_flags[idx_in_grid_segment], + grid_reduction_segment_size, + block_dim); } } diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index 725f2c128b9..a5eb595af9c 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1366,12 +1366,6 @@ TEST_P(TmaCircularBufferingTest, PointwiseCpAsync) { TEST_P(TmaCircularBufferingTest, InnerReduction) { NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); - if (std::holds_alternative<WarpSpecialized>(circular_buffer_type)) { - GTEST_SKIP() - << "This test uses block reduce, which uses hard-coded blockDim, " - << "which can cause deadlock when combined with warp specialization."; - } - std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>(); FusionGuard fg(fusion.get()); @@ -1486,13 +1480,6 @@ TEST_P(TmaCircularBufferingTest, OuterReduction) { TEST_P(TmaCircularBufferingTest, Persistent) { NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); - if (std::holds_alternative<WarpSpecialized>(circular_buffer_type)) { - GTEST_SKIP() - << "This test uses block reduce and block broadcast, " - << "which has hard-coded blockDim, " - << "which can cause deadlock when combined with warp specialization."; - } - constexpr at::ScalarType dtype = at::ScalarType::Float; constexpr int64_t correction = 0; constexpr int64_t reduction_axis = 1; diff --git a/tests/cpp/test_gpu2.cpp b/tests/cpp/test_gpu2.cpp index 08f7f631cca..c6df7767aca 100644 --- a/tests/cpp/test_gpu2.cpp +++ b/tests/cpp/test_gpu2.cpp @@ -2086,7 +2086,8 @@ __global__ void kernel1( (float*)mem_M2, (long*)mem_N, (bool)(threadIdx.x<inp.logical_size[0]), - 0.f); + 0.f, + blockDim); __syncthreads(); if(threadIdx.x<out_var.logical_size[0] && 
threadIdx.y==0){ welfordCombine( @@ -2175,7 +2176,8 @@ __global__ void kernel1( (float*)mem_M2, (long*)mem_N, (bool)(threadIdx.x<inp.logical_size[0]), - 0.f); + 0.f, + blockDim); __syncthreads(); if(threadIdx.x<out_var.logical_size[0] && threadIdx.y==0 && threadIdx.z==0){ out_avg[threadIdx.x*out_var.alloc_stride[0]]=tmp_avg; @@ -2251,7 +2253,8 @@ __global__ void kernel1( threadIdx.x<out_var.logical_size[0], 0.f, 0, - 1); + 1, + blockDim); if(blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1){ out_avg[threadIdx.x*out_avg.alloc_stride[0]]=tmp_avg; out_var[threadIdx.x*out_var.alloc_stride[0]]=tmp_M2/tmp_N; From 5a2184ca920bc36a808bebc4455ab7161c1ed5a0 Mon Sep 17 00:00:00 2001 From: Ryan Spring <rspring@nvidia.com> Date: Mon, 9 Dec 2024 14:26:09 -0800 Subject: [PATCH 2/3] Add support for 32B and 64B swizzles to hopper matmul scheduler (#3544) This PR adds support for 32B and 64B swizzles to StMatrix indexing and to the hopper matmul scheduler. ### Key Index Change The number of distinct swizzle rows is number of bytes for swizzle divided by size of megabank (16B). The number of times a swizzle pattern is repeated to fill core (8, 8) matrix is number of swizzle rows (8) divided by number of distinct rows. ```cpp MmaInputSmemSwizzle swizzle = getSwizzle(out_tv); int64_t swizzle_bytes = getBytesFromSwizzle(swizzle); constexpr int64_t megabank_size_bytes = 16; const int64_t distinct_swizzle_row_size = swizzle_bytes / megabank_size_bytes; int row = ...; int col = ...; constexpr int64_t swizzle_row_size = 8; const int64_t swizzle_row_repetitions = swizzle_row_size / distinct_swizzle_row_size; int64_t row_in_swizzle_pattern = (row % swizzle_row_size) / swizzle_row_repetitions; int64_t swizzle_col = col ^ row_in_swizzle_pattern; ``` ### Testing Changes * Added `mma_macro` as testing value. * Created separate test suite called `Swizzle/HopperMatmulSchedulerTest` to test `32B`, `64B`, `128B` swizzles. --- csrc/device_lower/pass/index.cpp | 66 ++++++++++++++++++-------- csrc/scheduler/hopper_multi_matmul.cpp | 18 +++---- csrc/scheduler/hopper_multi_matmul.h | 7 ++- tests/cpp/test_matmul_scheduler.cpp | 61 ++++++++++++++++++++---- tests/cpp/utils.h | 34 +++++++++++++ 5 files changed, 144 insertions(+), 42 deletions(-) diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 7f1dc742917..74632876c8c 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -1559,15 +1559,15 @@ void IndexLowering::handleCpAsyncBulkStore(const LoadStoreOp* ldst) { } static DataType getMmaInputAType(MmaMacro macro) { - int warp_group_size = isHopper(macro) ? 128 : 32; - int size = getM(macro) * getK(macro) / warp_group_size / - 2 /* halves per 32bit register */; + int64_t warp_group_size = isHopper(macro) ? 
128L : 32L; + int64_t size = getM(macro) * getK(macro) / warp_group_size / + 2L /* halves per 32bit register */; return ArrayType{std::make_shared<DataType>(DataType::UInt32), (size_t)size}; } static DataType getMmaInputBType(MmaMacro macro) { - int size = getN(macro) * getK(macro) / 32 /* threads per warp */ / - 2 /* halves per 32bit register */; + int64_t size = getN(macro) * getK(macro) / 32L /* threads per warp */ / + 2L /* halves per 32bit register */; return ArrayType{std::make_shared<DataType>(DataType::UInt32), (size_t)size}; } @@ -1842,8 +1842,8 @@ Val* hardCodedIndexGenerationForStMatrix( // To account for the threadIdx.y, we have to add it to the offset: // offset_from_tdy = threadIdx.y * tma_m * tma_n * 2 (half) // -// Now, lets apply stmatrix tile to the TMA Box. -// [NO(2), MO(4), MI(16), NIO(4), NII(16)]. +// Now, lets apply stmatrix tile (16, 16) to the TMA Box [NO(2), M(64), NI(64)]. +// [NO(2), MO(4), MI(16), NIO(4), NII(16)]. // // A warp group of 128 threads contains four warps. StMatrix is a warp-level // operation, so four StMatrix operations can be issued simultaneously by the @@ -1865,6 +1865,7 @@ Val* hardCodedIndexGenerationForStMatrix( // domain is scheduled as [NO(2), M(64), NI(64)]. Therefore, we must store the // data in shared memory in [M(64), NI(64)] contiguous tiles. // +// NOTE: This offset is skipped if for-loop is trivial // To account for the outer_index, we have to add it to the offset: // offset_from_outer_index = outer_index * tma_m * NI(64) * 2 (half) // @@ -1928,8 +1929,13 @@ Val* hardCodedIndexGenerationForStMatrix( // with the 8 rows of the matrix to avoid bank conflicts. This swizzle pattern // is repeated along the rows of the TMA box. // +// The number of distinct swizzle rows is number of bytes for swizzle divided by +// size of megabank (16B). The number of times a swizzle pattern is repeated to +// fill core (8, 8) matrix is number of swizzle rows (8) divided by number of +// distinct rows. 
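+// For example, a 64B swizzle has 64B / 16B = 4 distinct swizzle rows, so each
+// distinct row repeats 8 / 4 = 2 times (rows 0-1 map to swizzle row 0, rows
+// 2-3 to row 1, and so on); a 128B swizzle has all 8 distinct rows and no
+// repetition.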
+// // Swizzle column -// row_in_swizzle_pattern = row % swizzle_row_size(8) +// row_in_swizzle_pattern = (row % swizzle_row_size(8)) / swizzle_repetitions // swizzle_col = column XOR row_in_swizzle_pattern // // Calculate Tile Offset @@ -1939,7 +1945,7 @@ Val* hardCodedIndexGenerationForStMatrix( // // Get shared memory offset // smem_offset = offset_from_tdy + offset_from_outer_index + tile_offset -Val* hardCodedIndexGenerationForStMatrix128BSwizzle( +Val* hardCodedIndexGenerationForStMatrixSwizzle( const LoadStoreOp* ldst, ForLoop* loop, const int64_t stsm_m_tile, @@ -1958,16 +1964,19 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle( NVF_ERROR(ldst->out()->isA<TensorView>()); TensorView* out_tv = ldst->out()->as<TensorView>(); - NVF_ERROR(getSwizzle(out_tv) == MmaInputSmemSwizzle::B128); + MmaInputSmemSwizzle swizzle = getSwizzle(out_tv); + int64_t swizzle_bytes = getBytesFromSwizzle(swizzle); // Constants constexpr int64_t dtype_size = 2; constexpr int64_t warp_size = 32; constexpr int64_t swizzle_row_size = 8; constexpr int64_t stsm_column_size = 8; - constexpr int64_t swizzle_n_tile = 64; + constexpr int64_t megabank_size_bytes = 16; // Derived constants + const int64_t swizzle_n_tile = swizzle_bytes / dtype_size; + const int64_t distinct_swizzle_row_size = swizzle_bytes / megabank_size_bytes; constexpr int64_t stsm_column_stride = stsm_column_size * dtype_size; const int64_t swizzle_n_iter = swizzle_n_tile / stsm_n_tile; const int64_t swizzle_n_tile_stride = swizzle_n_tile * dtype_size; @@ -2000,8 +2009,6 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle( Val* warp_id = SimplifyingIrBuilder::divExpr(TDX, warp_size_val); Val* lane_id = SimplifyingIrBuilder::modExpr(TDX, warp_size_val); - Val* outer_index = - SimplifyingIrBuilder::divExpr(loop->index(), swizzle_n_iter_val); Val* inner_index = SimplifyingIrBuilder::modExpr(loop->index(), swizzle_n_iter_val); @@ -2021,6 +2028,17 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle( // Swizzle Column Val* row_in_swizzle_pattern = SimplifyingIrBuilder::modExpr(row, swizzle_row_size_val); + + // The swizzle pattern is repeated to fill (8, 8) matrix for 64B and 32B + // swizzles. swizzle_row_iter is the number of repetitions to fill 8 rows + // with distict swizzle rows. + const int64_t swizzle_row_iter = swizzle_row_size / distinct_swizzle_row_size; + if (swizzle_row_iter > 1) { + Val* swizzle_row_iter_val = + IrBuilder::create<Val>(swizzle_row_iter, DataType::Index); + row_in_swizzle_pattern = SimplifyingIrBuilder::divExpr( + row_in_swizzle_pattern, swizzle_row_iter_val); + } Val* swizzle_col = bitwise_xor(col, row_in_swizzle_pattern); // Calculate Tile Offset @@ -2031,16 +2049,22 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle( Val* offset = SimplifyingIrBuilder::addExpr(row_offset, col_offset); // Calculate Tile offset - Val* tile_offset = IrBuilder::mulExpr(outer_index, tile_stride_val); + // Skip tile offset if loop is trivial. 
+ if (!loop->stop()->isOneInt()) { + Val* outer_index = + SimplifyingIrBuilder::divExpr(loop->index(), swizzle_n_iter_val); + Val* tile_offset = + SimplifyingIrBuilder::mulExpr(outer_index, tile_stride_val); + offset = SimplifyingIrBuilder::addExpr(tile_offset, offset); + } // Calculate TDY offset - Val* tdy_offset = IrBuilder::mulExpr(TDY, tdy_stride_val); + Val* tdy_offset = SimplifyingIrBuilder::mulExpr(TDY, tdy_stride_val); + offset = SimplifyingIrBuilder::addExpr(tdy_offset, offset); // Create shared memory TensorIndex Val* out_index = SimplifyingIrBuilder::addExpr( - IrBuilder::baseAddressExpr(ir_utils::getTvOutput(ldst)), - SimplifyingIrBuilder::addExpr( - tdy_offset, SimplifyingIrBuilder::addExpr(tile_offset, offset))); + IrBuilder::baseAddressExpr(ir_utils::getTvOutput(ldst)), offset); Val* out = IrBuilder::create<kir::TensorIndex>( dynamic_cast<TensorView*>(ldst->out()), out_index); return out; @@ -2092,11 +2116,11 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { ldst, for_loops_[0], m_tile, n_tile, m, n); break; case MmaInputSmemSwizzle::B128: - out = hardCodedIndexGenerationForStMatrix128BSwizzle( + case MmaInputSmemSwizzle::B64: + case MmaInputSmemSwizzle::B32: + out = hardCodedIndexGenerationForStMatrixSwizzle( ldst, for_loops_[0], m_tile, n_tile, m, n); break; - case MmaInputSmemSwizzle::B32: - case MmaInputSmemSwizzle::B64: default: NVF_ERROR("Unsupported Swizzle Type for StMatrix"); } diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index dad548d3fd0..4320fe80953 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -1027,13 +1027,6 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() { const int64_t tma_m = getM(params_->mma_macro); const int64_t tma_n = getN(params_->mma_macro); - NVF_ERROR( - tma_n >= 64, - "Scheduler only supports 128B swizzle that requires N dimension of MMA ", - "macro to be >= 64, but received ", - tma_n, - "."); - fusion_->manage("st_matrix_m_tile", stmatrix_tile_m); fusion_->manage("st_matrix_n_tile", stmatrix_tile_n); fusion_->manage("st_matrix_m", tma_m); @@ -1084,12 +1077,14 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() { dc->setAllocationDomain(s.as<IterDomain*>(), true); } + MmaInputSmemSwizzle swizzle = tmaSwizzleSharedMemory(d_smem); + // Schedule shared memory cache; Output from StMatrix scheduleStMatrixForMmaOutput( - d_smem, stmatrix_tile_m, stmatrix_tile_n, tma_m, tma_n); + d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n, tma_m, tma_n); // Schedule global memory output; Output from TMA Store - scheduleTMAStoreForMmaOutput(d, tma_m, tma_n); + scheduleTMAStoreForMmaOutput(d, swizzle, tma_m, tma_n); } } } @@ -1247,6 +1242,7 @@ void HopperMultipleMatmulScheduler::setUpCircularBuffering() { void HopperMultipleMatmulScheduler::scheduleStMatrixForMmaOutput( TensorView* tv, + MmaInputSmemSwizzle swizzle, int64_t tile_m, int64_t tile_n, int64_t tma_m, @@ -1263,7 +1259,7 @@ void HopperMultipleMatmulScheduler::scheduleStMatrixForMmaOutput( mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv->getLoopDomain()); // Create tma store allocation domain with swizzle - scheduleTMAStoreForMmaOutput(tv, tma_m, tma_n); + scheduleTMAStoreForMmaOutput(tv, swizzle, tma_m, tma_n); tv->setLoopDomain(s.as<IterDomain*>()); @@ -1290,6 +1286,7 @@ void HopperMultipleMatmulScheduler::scheduleStMatrixForMmaOutput( void HopperMultipleMatmulScheduler::scheduleTMAStoreForMmaOutput( TensorView* tv, + MmaInputSmemSwizzle swizzle, int64_t m, int64_t n) { // 
[M(m), N(n)] -> [MO(1), MI(m), NO(1), NI(n)] @@ -1301,7 +1298,6 @@ void HopperMultipleMatmulScheduler::scheduleTMAStoreForMmaOutput( // [BDX, BDY, TDY, MO(1), NO(1), MI, NI] // skip the first 5 iterDomains int64_t num_ids_to_skip = 5; - MmaInputSmemSwizzle swizzle = MmaInputSmemSwizzle::B128; NVF_ERROR(num_ids_to_skip >= 0); if (swizzle == MmaInputSmemSwizzle::None) { diff --git a/csrc/scheduler/hopper_multi_matmul.h b/csrc/scheduler/hopper_multi_matmul.h index e44a6d8830b..8a16225f017 100644 --- a/csrc/scheduler/hopper_multi_matmul.h +++ b/csrc/scheduler/hopper_multi_matmul.h @@ -182,6 +182,7 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler { //! registers to shared memory. void scheduleStMatrixForMmaOutput( TensorView* tv, + MmaInputSmemSwizzle swizzle, int64_t tile_m, int64_t tile_n, int64_t tma_m, @@ -189,7 +190,11 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler { //! Schedules the copy operation of output of a Mma op which resided in the //! shared memory to global memory. - void scheduleTMAStoreForMmaOutput(TensorView* tv, int64_t m, int64_t n); + void scheduleTMAStoreForMmaOutput( + TensorView* tv, + MmaInputSmemSwizzle swizzle, + int64_t m, + int64_t n); // Map TensorView's iterDomain to its ValGroup. // Then, find the MatmulDimRole for the ValGroup. diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp index 1a25a1ebfe4..e74844d22d3 100644 --- a/tests/cpp/test_matmul_scheduler.cpp +++ b/tests/cpp/test_matmul_scheduler.cpp @@ -3119,8 +3119,8 @@ using HopperMatmulSchedulerTestParams = std::tuple< bool, // b_k_inner int64_t, // M int64_t, // N - int64_t // K - >; + int64_t, // K + MmaMacro>; std::string hopperTestName( const testing::TestParamInfo<HopperMatmulSchedulerTestParams>& info) { @@ -3128,23 +3128,42 @@ std::string hopperTestName( bool use_smem_epilogue; bool a_k_inner, b_k_inner; int64_t M, N, K; - std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K) = info.param; + MmaMacro mma_macro; + std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K, mma_macro) = + info.param; os << (a_k_inner ? "K" : "M"); os << (b_k_inner ? "K" : "N"); os << "_" << M << "_" << N << "_" << K; + os << "_MmaMacro_" << mma_macro_to_str_map.at(mma_macro); if (use_smem_epilogue) { os << "_tma_store"; } return os.str(); } +std::string hopperTestNameSwizzle( + const testing::TestParamInfo<HopperMatmulSchedulerTestParams>& info) { + std::unordered_map<MmaMacro, std::string> mma_macro_to_swizzle_str_map = { + {MmaMacro::Hopper_64_256_16, "128BSwizzle"}, + {MmaMacro::Hopper_64_128_16, "128BSwizzle"}, + {MmaMacro::Hopper_64_64_16, "128BSwizzle"}, + {MmaMacro::Hopper_64_32_16, "64BSwizzle"}, + {MmaMacro::Hopper_64_16_16, "32BSwizzle"}}; + MmaMacro mma_macro = std::get<6>(info.param); + std::ostringstream os; + os << hopperTestName(info); + os << "_" << mma_macro_to_swizzle_str_map.at(mma_macro); + return os.str(); +} + class HopperMatmulSchedulerTest : public NVFuserFixtureParamTest<HopperMatmulSchedulerTestParams> { protected: void SetUp() { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(9, 0, 10, 0); - std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K) = GetParam(); + std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K, mma_macro) = + GetParam(); if (a_k_inner) { layout = b_k_inner ? MmaLayout::TN : MmaLayout::TT; @@ -3159,14 +3178,17 @@ class HopperMatmulSchedulerTest // Create custom Matmul Params MatMulTileOptions gemm_tile; // TODO cta tile is a multiple of mma macro for hopper. 
- gemm_tile.cta_tile = GemmTile(128, 256, 16); + // Default cta_tile configuration is 2-CTA. + gemm_tile.cta_tile = + GemmTile(2 * getM(mma_macro), getN(mma_macro), getK(mma_macro)); // TODO warp tile is (macroM, macroN, macroK) for hopper. - gemm_tile.warp_tile = GemmTile(64, 128, 16); + gemm_tile.warp_tile = + GemmTile(getM(mma_macro), getN(mma_macro), getK(mma_macro)); mparams.supported_vec_size = {8, 8, 4}; - mparams.mma_macro = MmaMacro::Hopper_64_128_16; + mparams.mma_macro = mma_macro; mparams.use_smem_epilogue = use_smem_epilogue; @@ -3203,6 +3225,7 @@ class HopperMatmulSchedulerTest bool use_smem_epilogue; bool a_k_inner, b_k_inner; int64_t M, N, K; + MmaMacro mma_macro; std::unique_ptr<Fusion> fusion_up; Fusion* fusion; std::unique_ptr<FusionGuard> fusion_guard; @@ -3275,7 +3298,7 @@ TEST_P(HopperMatmulSchedulerTest, FusedMultiplySum) { } INSTANTIATE_TEST_SUITE_P( - , + General, HopperMatmulSchedulerTest, testing::Combine( testing::Bool(), // use_smem_epilogue @@ -3283,8 +3306,28 @@ INSTANTIATE_TEST_SUITE_P( testing::Bool(), // b_k_inner testing::Values(512), // M testing::Values(256), // N - testing::Values(64) // K + testing::Values(64), // K + testing::Values(MmaMacro::Hopper_64_128_16) // mma_macros ), hopperTestName); +INSTANTIATE_TEST_SUITE_P( + Swizzle, + HopperMatmulSchedulerTest, + testing::Combine( + testing::Values(true), // use_smem_epilogue + testing::Bool(), // a_k_inner + testing::Bool(), // b_k_inner + testing::Values(512), // M + testing::Values(256), // N + testing::Values(64), // K + testing::Values( + MmaMacro::Hopper_64_256_16, + MmaMacro::Hopper_64_128_16, + MmaMacro::Hopper_64_64_16, + MmaMacro::Hopper_64_32_16, + MmaMacro::Hopper_64_16_16) // mma_macros + ), + hopperTestNameSwizzle); + } // namespace nvfuser diff --git a/tests/cpp/utils.h b/tests/cpp/utils.h index 766a5e369c7..f835766e576 100644 --- a/tests/cpp/utils.h +++ b/tests/cpp/utils.h @@ -703,6 +703,40 @@ static auto kAllHopperMacros = testing::Values( MmaMacro::Hopper_64_248_16, MmaMacro::Hopper_64_256_16); +static std::unordered_map<MmaMacro, std::string> mma_macro_to_str_map = { + {MmaMacro::Hopper_64_8_16, "m64_n8_k16"}, + {MmaMacro::Hopper_64_16_16, "m64_n16_k16"}, + {MmaMacro::Hopper_64_24_16, "m64_n24_k16"}, + {MmaMacro::Hopper_64_32_16, "m64_n32_k16"}, + {MmaMacro::Hopper_64_40_16, "m64_n40_k16"}, + {MmaMacro::Hopper_64_48_16, "m64_n48_k16"}, + {MmaMacro::Hopper_64_56_16, "m64_n56_k16"}, + {MmaMacro::Hopper_64_64_16, "m64_n64_k16"}, + {MmaMacro::Hopper_64_72_16, "m64_n72_k16"}, + {MmaMacro::Hopper_64_80_16, "m64_n80_k16"}, + {MmaMacro::Hopper_64_88_16, "m64_n88_k16"}, + {MmaMacro::Hopper_64_96_16, "m64_n96_k16"}, + {MmaMacro::Hopper_64_104_16, "m64_n104_k16"}, + {MmaMacro::Hopper_64_112_16, "m64_n112_k16"}, + {MmaMacro::Hopper_64_120_16, "m64_n120_k16"}, + {MmaMacro::Hopper_64_128_16, "m64_n128_k16"}, + {MmaMacro::Hopper_64_136_16, "m64_n136_k16"}, + {MmaMacro::Hopper_64_144_16, "m64_n144_k16"}, + {MmaMacro::Hopper_64_152_16, "m64_n152_k16"}, + {MmaMacro::Hopper_64_160_16, "m64_n160_k16"}, + {MmaMacro::Hopper_64_168_16, "m64_n168_k16"}, + {MmaMacro::Hopper_64_176_16, "m64_n176_k16"}, + {MmaMacro::Hopper_64_184_16, "m64_n184_k16"}, + {MmaMacro::Hopper_64_192_16, "m64_n192_k16"}, + {MmaMacro::Hopper_64_200_16, "m64_n200_k16"}, + {MmaMacro::Hopper_64_208_16, "m64_n208_k16"}, + {MmaMacro::Hopper_64_216_16, "m64_n216_k16"}, + {MmaMacro::Hopper_64_224_16, "m64_n224_k16"}, + {MmaMacro::Hopper_64_232_16, "m64_n232_k16"}, + {MmaMacro::Hopper_64_240_16, "m64_n240_k16"}, + {MmaMacro::Hopper_64_248_16, 
"m64_n248_k16"}, + {MmaMacro::Hopper_64_256_16, "m64_n256_k16"}}; + // Utility to generate matmul input tensors based on given layout at::Tensor atMatmul(at::Tensor a, at::Tensor b, MmaLayout layout); From 4a897a417c8d14839453baa4925ae2048016b411 Mon Sep 17 00:00:00 2001 From: Jingyue Wu <wujingyue@gmail.com> Date: Mon, 9 Dec 2024 15:30:30 -0800 Subject: [PATCH 3/3] Allgather with DID loop split (#3284) Another baby step towards #2563 --- csrc/multidevice/communication.cpp | 4 +- csrc/multidevice/communication.h | 5 + csrc/multidevice/lower_communication.cpp | 2 +- csrc/multidevice/utils.cpp | 188 +++++++++++------- csrc/multidevice/utils.h | 13 +- tests/cpp/multidevice.cpp | 16 +- .../test_multidevice_lower_communication.cpp | 67 +++++++ tests/cpp/test_multidevice_sharding.cpp | 48 +++++ tests/cpp/test_multidevice_transformer.cpp | 124 ++++++------ 9 files changed, 326 insertions(+), 141 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 522b755b96d..af122ee6e3d 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -328,8 +328,8 @@ c10::intrusive_ptr<c10d::Work> postAllgather( c10d::Backend* backend, at::Tensor input_tensor, at::Tensor output_tensor) { - auto splits = at::split(output_tensor, /*split_size=*/1, /*dim=*/0); - assertBufferCount(splits, communication->team().size()); + auto splits = + at::tensor_split(output_tensor, communication->team_size(), /*dim=*/0); assertBuffersHaveSameSize({input_tensor}, splits); // allgather primitive in c10d induces extra buffering time to copy out the diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 45c104b36d3..8631a1a04e5 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -90,6 +90,11 @@ class Communication : public Expr { return attribute<Team>(1); } + // A convenience helper so the user doesn't need to convert size_t to int64_t. + int64_t team_size() const { + return static_cast<int64_t>(team().size()); + } + DeviceIdxType root() const { return attribute<DeviceIdxType>(2); } diff --git a/csrc/multidevice/lower_communication.cpp b/csrc/multidevice/lower_communication.cpp index c8068b5a113..4b878ac7376 100644 --- a/csrc/multidevice/lower_communication.cpp +++ b/csrc/multidevice/lower_communication.cpp @@ -196,7 +196,7 @@ void lowerToReduceScatter( std::vector<Communication*>& comms) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); auto reduction_axis = output_tv->getReductionAxis().value(); - auto scattered_axis = getShardedAxis(output_tv); + auto scattered_axis = getShardedLogicalAxis(output_tv, ParallelType::DIDx); // The output tensor is sharded on scattered_axis and needs to be mapped // back onto the input. The input has an reduced axis, so the scattered axis // is adjusted to account for this. 
Ex: [DIDx(i0), i1] -> [r0, DIDx(i1)] The diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 24b7e582104..54f1303bc16 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -121,48 +121,133 @@ bool isSharded(const TensorView* tv) { return is_sharded; } -std::vector<int64_t> unshardedSizes( - const TensorView* tv, - c10::IntArrayRef sizes) { - std::vector<int64_t> unsharded_sizes = sizes.vec(); - - for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) { - const ParallelType parallel_type = alloc_id->getParallelType(); +namespace { +// Collect device-parallel IterDomains in `domain` and return them as a +// ParallelType-to-IterDomain map. +std::unordered_map<ParallelType, IterDomain*> mapDeviceParallelTypeToId( + const std::vector<IterDomain*>& domain) { + std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id; + parallel_type_to_id.reserve(kParallelTypeDIDs.size()); + for (IterDomain* id : domain) { + const ParallelType parallel_type = id->getParallelType(); if (!isParallelTypeDeviceDim(parallel_type)) { continue; } - const auto inputs = IterVisitor::getInputsTo( - {alloc_id}, - {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()}); NVF_ERROR( - !inputs.empty(), - "IterVisitor::getInputsTo shouldn't return empty unless `of` is empty."); - NVF_ERROR( - inputs.size() == 1, - "Failed to find the single logical input to ", - alloc_id, - ". This is likely because there's a Merge expression from logical to allocation, which isn't supported. Inputs are: ", - toDelimitedString(inputs)); - - const auto iter = std::find( - tv->getLogicalDomain().begin(), - tv->getLogicalDomain().end(), - inputs[0]); + parallel_type_to_id.try_emplace(parallel_type, id).second, + "Found multiple loop IterDomains with the same parallel type (", + parallel_type, + "): ", + toDelimitedString(domain)); + } + return parallel_type_to_id; +} + +std::unordered_map<IterDomain*, int64_t> mapIterDomainToTensorAxis( + const std::vector<IterDomain*>& domain) { + std::unordered_map<IterDomain*, int64_t> id_to_axis; + int64_t axis = 0; + for (auto* id : domain) { + // Reduction IterDomains are not materialized as an at::Tensor axis. + if (id->isReduction()) { + continue; + } + id_to_axis[id] = axis; + axis++; + } + return id_to_axis; +} + +} // namespace + +int64_t getShardedLogicalAxis( + const TensorView* tv, + const ParallelType parallel_type) { + std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id = + mapDeviceParallelTypeToId(tv->getMaybeAllocationDomain()); + IterDomain* alloc_id = getOrDefault(parallel_type_to_id, parallel_type); + if (alloc_id == nullptr) { + return -1; + } + + std::unordered_map<IterDomain*, int64_t> logical_id_to_axis = + mapIterDomainToTensorAxis(tv->getLogicalDomain()); + IterDomain* id = alloc_id; + while (logical_id_to_axis.count(id) == 0) { + Expr* def = id->definition(); NVF_ERROR( - iter != tv->getLogicalDomain().end(), - "The found input IterDomain isn't logical. This is likely because logical doesn't dominate allocation: ", - inputs[0]); - - // Count the number of non-reduction IterDomains before `iter`. Reduction - // IterDomains are not materialized in the at::Tensor's shape. 
- const auto index = std::count_if( - tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool { - return !id->isReduction(); - }); - unsharded_sizes.at(index) *= tv->getDeviceMesh().size(parallel_type); + def != nullptr, + "Failed to find a non-reduction logical IterDomain that produces ", + alloc_id); + if (auto* split = dynamic_cast<Split*>(def)) { + // Returning just which tensor axis is sharded isn't sufficient to let + // shardTensor, a user of this function, know how to shard the tensor. + // For example, + // + // t = makeContigConcreteTensor({6}); + // t->split(0, 2, /*inner_split=*/true); + // t->axis(-1)->parallelize(DIDx); + // // [i{3}, iDIDx{2}] + // + // and the unsharded tensor is [0, 1, 2, 3, 4, 5], regardless of the + // stride. The sharded tensor ought to be [0, 2, 4] for GPU 0 and [1, 3, + // 5] for GPU 1. However, shardTensor as is will return [0, 1, 2] and [3, + // 4, 5], assuming the axis is sharded outermost. + // + // One potential way to solve the general problem is to replay and rewind + // the splits on the at::Tensor. For example, + // + // t = makeContigConcreteTensor({30}); + // t->split(0, 5); + // t->split(0, 3); + // t->axis(0)->parallelize(Host); + // t->axis(1)->parallelize(DIDx); + // // [iHost{2}, iDIDx{3}, i{5}] + // + // Given an unsharded at::Tensor of shape [30], we'll first replay the + // splits using `torch.view` to get a tensor of shape [2,3,5]. Then, we + // `torch.slice` axis 1 for DIDx to get a tensor of shape [2,1,5]. Then, + // we rewind the splits (and therefore apply merging) using + // `torch.reshape` to get a sharded tensor of shape [10]. + NVF_ERROR( + split->outer() == id, + "Currently, we don't support DID on inner splits: ", + split); + id = split->in(); + } else if (auto* merge = dynamic_cast<Merge*>(def)) { + // For example, + // + // t = makeContigTensor(2); + // t->merge(0, 1); + // t->axis(0)->parallelize(DIDx); + // + // When `unshardedSizes` is given a local tensor of shape [1, 1], it's + // unclear the global shape is [1, D] or [D, 1] or even [2, D/2], etc. + NVF_THROW( + "Failed to attribute the sharding to a single tensor axis and therefore bailed out: ", + merge); + } else { + NVF_THROW( + "Unexpected transforms from logical to a DID-parallel allocation IterDomain: ", + def); + } } + return logical_id_to_axis.at(id); +} + +std::vector<int64_t> unshardedSizes( + const TensorView* tv, + c10::IntArrayRef sizes) { + std::vector<int64_t> unsharded_sizes = sizes.vec(); + for (ParallelType parallel_type : kParallelTypeDIDs) { + const int64_t sharded_axis = getShardedLogicalAxis(tv, parallel_type); + if (sharded_axis == -1) { + continue; + } + unsharded_sizes.at(sharded_axis) *= tv->getDeviceMesh().size(parallel_type); + } return unsharded_sizes; } @@ -174,27 +259,6 @@ int64_t numDeviceDims(const TensorView* tv) { } namespace { -// Collect device-parallel IterDomains in `loop_domain` and return them as a -// ParallelType-to-IterDomain map. 
-std::unordered_map<ParallelType, IterDomain*> mapParallelTypeToId( - const std::vector<IterDomain*>& loop_domain) { - std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id; - parallel_type_to_id.reserve(kParallelTypeDIDs.size()); - for (IterDomain* loop_id : loop_domain) { - const ParallelType parallel_type = loop_id->getParallelType(); - if (!isParallelTypeDeviceDim(parallel_type)) { - continue; - } - - NVF_ERROR( - parallel_type_to_id.try_emplace(parallel_type, loop_id).second, - "Found multiple loop IterDomains with the same parallel type (", - parallel_type, - "): ", - toDelimitedString(loop_domain)); - } - return parallel_type_to_id; -} std::vector<IterDomain*> getInputsInTargetDomain( IterDomain* loop_id, @@ -294,9 +358,9 @@ bool haveDifferentShardings( // 3. Check if the two loop IterDomains are almost-exactly mapped in the // IdModel. std::unordered_map<ParallelType, IterDomain*> p_parallel_type_to_id = - mapParallelTypeToId(producer->getLoopDomain()); + mapDeviceParallelTypeToId(producer->getLoopDomain()); std::unordered_map<ParallelType, IterDomain*> c_parallel_type_to_id = - mapParallelTypeToId(consumer->getLoopDomain()); + mapDeviceParallelTypeToId(consumer->getLoopDomain()); for (const auto parallel_type : kParallelTypeDIDs) { IterDomain* p_loop_id = getOrDefault(p_parallel_type_to_id, parallel_type); @@ -502,16 +566,6 @@ std::set<DeviceIdxType> involvedDevices(Expr* expr) { return ret; } -int64_t getShardedAxis(TensorView* tv) { - auto ids = TensorDomain::noReductions(tv->getLogicalDomain()); - for (size_t i = 0; i < ids.size(); ++i) { - if (ids[i]->getParallelType() == ParallelType::DIDx) { - return static_cast<int64_t>(i); - } - } - return -1; -} - void reorderDIDToFront(TensorView* tv) { // new position to old position std::unordered_map<int64_t, int64_t> order_map; diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index 5be2e11bd15..ef88fbdcf80 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -123,9 +123,16 @@ int64_t requestedNumberOfDevices(Fusion*); void unshard(Fusion*); void unshard(TensorView*); -// Returns the index of the a sharded axis if none return -1. -// TODO: Assumes no merges/splits on sharded axis. -int64_t getShardedAxis(TensorView*); +// Returns the index of the sharded logical axis that produces the allocation +// IterDomain sharded on `parallel_type`. If `tv` isn't sharded on the parallel +// type, returns -1. +// +// This is used to correlate `tv` and its corresponding at::Tensor, e.g., by +// `unshardedSizes` and `shardTensor`. `at::Tensor::sizes` and +// `tv->getLogicalDomain()` map one-to-one modulo reduction. However, a size in +// `at::Tensor::sizes` is a factor of the corresponding logical IterDomain's +// extent if that IterDomain is sharded. +int64_t getShardedLogicalAxis(const TensorView* tv, ParallelType parallel_type); // Reorders a TensorView so that the DID parallelized axis are in front. 
void reorderDIDToFront(TensorView*); diff --git a/tests/cpp/multidevice.cpp b/tests/cpp/multidevice.cpp index bab5cdccc5e..22897dc5311 100644 --- a/tests/cpp/multidevice.cpp +++ b/tests/cpp/multidevice.cpp @@ -128,7 +128,10 @@ at::Tensor MultiDeviceTest::shardTensor(at::Tensor tensor, TensorView* tv) { return tensor; } NVF_ERROR(tv->hasDeviceMesh(), "`tv` has no DeviceMesh: ", tv); - return shardTensor(tensor, getShardedAxis(tv), tv->getDeviceMesh()); + return shardTensor( + tensor, + getShardedLogicalAxis(tv, ParallelType::DIDx), + tv->getDeviceMesh()); } at::Tensor MultiDeviceTest::shardTensor( @@ -144,13 +147,10 @@ at::Tensor MultiDeviceTest::shardTensor( auto stride = extent / nslices; // TODO: returning slice 0 temporarily when device is not in the mesh. i = (i < 0) ? 0 : i; - auto slice = tensor.slice(axis, i * stride, (i + 1) * stride).contiguous(); - // Temporary until https://github.com/NVIDIA/Fuser/issues/2563. Adds DIDx - // axis in front representing the sharded extent of the tensor. - if (stride > 1) { - slice = slice.unsqueeze(0); - } - return slice; + // The following slicing is problematic when DID is on an inner split (cf. + // MultiDeviceTest.ShardTensor_InnerSplit). We currently disallow that and + // it's enforced by getShardedLogicalAxis. + return tensor.slice(axis, i * stride, (i + 1) * stride).contiguous(); } } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_lower_communication.cpp b/tests/cpp/test_multidevice_lower_communication.cpp index 643b5b2220d..d1f06d80e1d 100644 --- a/tests/cpp/test_multidevice_lower_communication.cpp +++ b/tests/cpp/test_multidevice_lower_communication.cpp @@ -202,6 +202,73 @@ TEST_F(LowerCollectiveTest, Allgather) { EXPECT_TRUE(at::equal(out_tensor, unsharded_tensor)); } +TEST_F(LowerCollectiveTest, Allgather_LoopSplit) { + auto fusion = std::make_unique<Fusion>(); + FusionGuard fg(fusion.get()); + + const auto num_devices = communicator_->size(); + auto mesh = DeviceMesh::createForNumDevices(num_devices); + + TensorView* in = makeContigTensor(1); + in->setDeviceMesh(mesh); + TensorView* out = set(in); + fusion->addInput(in); + fusion->addOutput(out); + + in->split(0, num_devices, /*inner_split=*/false); + in->axis(0)->parallelize(ParallelType::DIDx); + in->setAllocationDomain(in->getLoopDomain(), true); + + out->split(0, num_devices, /*inner_split=*/false); + out->setAllocationDomain(out->getLoopDomain(), true); + + at::Tensor unsharded_tensor = + at::randn({num_devices * kTensorSize}, at::kFloat); + at::Tensor in_tensor = + shardTensor(unsharded_tensor, in).to(communicator_->device()); + + FusionExecutorCache fec(std::move(fusion)); + at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0]; + assertIsCompiledToHostIrContainer(fec); + + EXPECT_TRUE(at::equal(out_tensor.cpu(), unsharded_tensor)); +} + +// This currently fails due to getShardingChanges reads root/logical only: +// https://github.com/NVIDIA/Fuser/blob/1dda106a946adcfd1526b83e4f2d4abebb9e32e4/csrc/multidevice/utils.cpp#L77. +// Will try to fix this in a follow-up PR and reenable the test. 
+TEST_F(LowerCollectiveTest, DISABLED_Allgather_LoopSplit_Noncontiguous) { + auto fusion = std::make_unique<Fusion>(); + FusionGuard fg(fusion.get()); + + const auto num_devices = communicator_->size(); + auto mesh = DeviceMesh::createForNumDevices(num_devices); + + TensorView* in = makeContigTensor(2); + in->setDeviceMesh(mesh); + TensorView* out = set(in); + fusion->addInput(in); + fusion->addOutput(out); + + in->split(1, num_devices, /*inner_split=*/false); + in->axis(1)->parallelize(ParallelType::DIDx); + in->setAllocationDomain(in->getLoopDomain(), true); + + out->split(1, num_devices, /*inner_split=*/false); + out->setAllocationDomain(out->getLoopDomain(), true); + + at::Tensor unsharded_tensor = + at::arange(2 * num_devices * 3, at::kFloat).view({2, num_devices * 3}); + at::Tensor in_tensor = + shardTensor(unsharded_tensor, in).to(communicator_->device()); + + FusionExecutorCache fec(std::move(fusion)); + at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0]; + assertIsCompiledToHostIrContainer(fec); + + EXPECT_TRUE(at::equal(out_tensor.cpu(), unsharded_tensor)); +} + TEST_F(LowerCollectiveTest, Broadcast) { auto fusion = std::make_unique<Fusion>(); FusionGuard fg(fusion.get()); diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 3adac90bc5e..aaa5d3a3218 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -491,4 +491,52 @@ TEST_P(MultiDeviceBroadcastTest, Expanded) { INSTANTIATE_TEST_SUITE_P(, MultiDeviceBroadcastTest, testing::Bool()); +TEST_F(MultiDeviceTest, ShardTensor_OuterSplit) { + const int d = communicator_->size(); + + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv = makeContigConcreteTensor({2, d * 3}); + tv->setDeviceMesh(DeviceMesh::createForNumDevices(d)); + tv->split(1, d, /*inner_split=*/false); + tv->axis(1)->parallelize(ParallelType::DIDx); + tv->setAllocationDomain(tv->getLoopDomain(), true); + + fusion.addInput(tv); + fusion.addOutput(tv); + + at::Tensor unsharded = at::arange(2 * d * 3).view({2, d * 3}); + at::Tensor sharded = shardTensor(unsharded, tv); + + EXPECT_THAT(sharded.sizes(), ElementsAre(2, 3)); + at::Tensor expected = unsharded.view({2, d, 3}).index( + {torch::indexing::Slice(), + communicator_->deviceId(), + torch::indexing::Slice()}); + EXPECT_TRUE(at::equal(sharded, expected)); +} + +TEST_F(MultiDeviceTest, ShardTensor_InnerSplit) { + const int d = communicator_->size(); + + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv = makeContigConcreteTensor({d * 3}); + tv->setDeviceMesh(DeviceMesh::createForNumDevices(d)); + tv->split(0, d, /*inner_split=*/true); + tv->axis(-1)->parallelize(ParallelType::DIDx); + tv->setAllocationDomain(tv->getLoopDomain(), true); + + fusion.addInput(tv); + fusion.addOutput(tv); + + at::Tensor unsharded = at::arange(d * 3); + EXPECT_THAT( + [&]() { shardTensor(unsharded, tv); }, + ::testing::ThrowsMessage<nvfuser::nvfError>( + ::testing::HasSubstr("DID on inner splits"))); +} + } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp index 2ef33dcdf8f..0f39ae6f6e5 100644 --- a/tests/cpp/test_multidevice_transformer.cpp +++ b/tests/cpp/test_multidevice_transformer.cpp @@ -720,14 +720,14 @@ TEST_P(DistributedTransformerTest, MLP_Layer) { std::vector<c10::IValue> inputs = { x, - shardTensor(w0, 0, mesh), - shardTensor(b0, 0, mesh), - shardTensor(w1, 1, mesh), + shardTensor(w0, 0, mesh).unsqueeze(0), + shardTensor(b0, 0, 
mesh).unsqueeze(0), + shardTensor(w1, 1, mesh).unsqueeze(0), b1}; std::vector<at::Tensor> expected_outputs = { - shardTensor(reference_outs[0], 1, mesh), - shardTensor(reference_outs[1], 1, mesh), + shardTensor(reference_outs[0], 1, mesh).unsqueeze(0), + shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), reference_outs[2], reference_outs[3]}; @@ -801,17 +801,17 @@ TEST_P(DistributedTransformerTest, Sequence_Parallel_MLP_Layer) { auto mask_ = reference_outs[4]; std::vector<c10::IValue> inputs = { - shardTensor(x_, 0, mesh), - shardTensor(w0_, 0, mesh), - shardTensor(b0_, 0, mesh), - shardTensor(w1_, 1, mesh), + shardTensor(x_, 0, mesh).unsqueeze(0), + shardTensor(w0_, 0, mesh).unsqueeze(0), + shardTensor(b0_, 0, mesh).unsqueeze(0), + shardTensor(w1_, 1, mesh).unsqueeze(0), b1_}; std::vector<at::Tensor> expected_outputs = { - shardTensor(reference_outs[0], 1, mesh), - shardTensor(reference_outs[1], 1, mesh), - shardTensor(reference_outs[2], 0, mesh), - shardTensor(reference_outs[3], 0, mesh)}; + shardTensor(reference_outs[0], 1, mesh).unsqueeze(0), + shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), + shardTensor(reference_outs[2], 0, mesh).unsqueeze(0), + shardTensor(reference_outs[3], 0, mesh).unsqueeze(0)}; FusionExecutorCache executor_cache(std::move(fusion)); at::manual_seed(getATenRandomSeed()); @@ -866,12 +866,12 @@ TEST_P(DistributedTransformerTest, MultiheadAttention) { x, shardTensor(w0.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), shardTensor(b0.view({3, E}), 1, mesh).view({1, 3 * E / D}), - shardTensor(w1, 1, mesh), + shardTensor(w1, 1, mesh).unsqueeze(0), b1}; std::vector<at::Tensor> expected_outputs = { shardTensor(reference_outs[0].view({B * S, 3, E}), 2, mesh) .view({1, B * S, 3 * E / D}), - shardTensor(reference_outs[1], 1, mesh), + shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), reference_outs[2], reference_outs[3]}; @@ -929,17 +929,17 @@ TEST_P(DistributedTransformerTest, MultiheadAttention_SP) { at::manual_seed(getATenRandomSeed()); auto reference_outs = reference_mha(x, w0, b0, w1, b1); std::vector<c10::IValue> inputs = { - shardTensor(x, 0, mesh), + shardTensor(x, 0, mesh).unsqueeze(0), shardTensor(w0.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), shardTensor(b0.view({3, E}), 1, mesh).view({1, 3 * E / D}), - shardTensor(w1, 1, mesh), + shardTensor(w1, 1, mesh).unsqueeze(0), b1}; std::vector<at::Tensor> expected_outputs = { shardTensor(reference_outs[0].view({B * S, 3, E}), 2, mesh) .view({1, B * S, 3 * E / D}), - shardTensor(reference_outs[1], 1, mesh), - shardTensor(reference_outs[2], 0, mesh), - shardTensor(reference_outs[3], 0, mesh)}; + shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), + shardTensor(reference_outs[2], 0, mesh).unsqueeze(0), + shardTensor(reference_outs[3], 0, mesh).unsqueeze(0)}; FusionExecutorCache fec(std::move(fusion)); at::manual_seed(getATenRandomSeed()); @@ -1003,16 +1003,16 @@ TEST_P(DistributedTransformerTest, MLP_Backward) { grad_, x_, mask_, - shardTensor(mlp_w0_, 0, mesh), - shardTensor(mlp_w1_, 1, mesh), - shardTensor(linear0_, 1, mesh)}; + shardTensor(mlp_w0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_w1_, 1, mesh).unsqueeze(0), + shardTensor(linear0_, 1, mesh).unsqueeze(0)}; std::vector<at::Tensor> expected_outputs = { outs[0], // dropout grad - shardTensor(outs[1], 1, mesh), // linear1 weight grad + shardTensor(outs[1], 1, mesh).unsqueeze(0), // linear1 weight grad outs[2], // linear1 bias grad - shardTensor(outs[3], 1, mesh), // gelu grad - shardTensor(outs[4], 0, mesh), // linear0 weight grad - 
shardTensor(outs[5], 0, mesh), // linear0 bias grad + shardTensor(outs[3], 1, mesh).unsqueeze(0), // gelu grad + shardTensor(outs[4], 0, mesh).unsqueeze(0), // linear0 weight grad + shardTensor(outs[5], 0, mesh).unsqueeze(0), // linear0 bias grad outs[6]}; // linear0 grad x FusionExecutorCache executor_cache(std::move(fusion)); @@ -1094,22 +1094,23 @@ TEST_P(DistributedTransformerTest, MHA_Backward) { std::vector<c10::IValue> inputs = { x, shardTensor(w0.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), - shardTensor(w1, 1, mesh), + shardTensor(w1, 1, mesh).unsqueeze(0), grad, mask, - shardTensor(reference_outs[0], 1, mesh), // sdpa.output - shardTensor(reference_outs[1], 1, mesh), // sdpa.log_sumexp + shardTensor(reference_outs[0], 1, mesh).unsqueeze(0), // sdpa.output + shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), // sdpa.log_sumexp reference_outs[2], // sdpa.seed reference_outs[3], // sdpa.offset - shardTensor(reference_outs[13], 1, mesh) // linear0 + shardTensor(reference_outs[13], 1, mesh).unsqueeze(0) // linear0 }; std::vector<at::Tensor> expected_outputs = { reference_outs[4], // dropout grad - shardTensor(reference_outs[5], 1, mesh), // linear1 weight grad + shardTensor(reference_outs[5], 1, mesh) + .unsqueeze(0), // linear1 weight grad reference_outs[6], // linear1 bias grad - shardTensor(reference_outs[7], 1, mesh), // q grad - shardTensor(reference_outs[8], 1, mesh), // k grad - shardTensor(reference_outs[9], 1, mesh), // v grad + shardTensor(reference_outs[7], 1, mesh).unsqueeze(0), // q grad + shardTensor(reference_outs[8], 1, mesh).unsqueeze(0), // k grad + shardTensor(reference_outs[9], 1, mesh).unsqueeze(0), // v grad shardTensor(reference_outs[10].view({3, E, E}), 1, mesh) .view({1, 3 * E / D, E}), // linear0 weight grad shardTensor(reference_outs[11].view({3, E}), 1, mesh) @@ -1234,26 +1235,26 @@ TEST_P(DistributedTransformerTest, Forward_SP) { auto at_out = (resid0_ + mlp_out_).to(at_dtype); std::vector<c10::IValue> inputs = { - shardTensor(x_, 0, mesh), + shardTensor(x_, 0, mesh).unsqueeze(0), ln0_w_, ln0_b_, shardTensor(mha_w0_.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), shardTensor(mha_b0_.view({3, E}), 1, mesh).view({1, 3 * E / D}), - shardTensor(mha_w1_, 1, mesh), + shardTensor(mha_w1_, 1, mesh).unsqueeze(0), mha_b1_, ln1_w_, ln1_b_, - shardTensor(mlp_w0_, 0, mesh), - shardTensor(mlp_b0_, 0, mesh), - shardTensor(mlp_w1_, 1, mesh), + shardTensor(mlp_w0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_b0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_w1_, 1, mesh).unsqueeze(0), mlp_b1_}; std::vector<at::Tensor> expected_outputs = { - shardTensor(ln0_out_, 0, mesh), - shardTensor(mha_out_, 0, mesh), - shardTensor(ln1_out_, 0, mesh), - shardTensor(mlp_out_, 0, mesh), - shardTensor(at_out, 0, mesh)}; + shardTensor(ln0_out_, 0, mesh).unsqueeze(0), + shardTensor(mha_out_, 0, mesh).unsqueeze(0), + shardTensor(ln1_out_, 0, mesh).unsqueeze(0), + shardTensor(mlp_out_, 0, mesh).unsqueeze(0), + shardTensor(at_out, 0, mesh).unsqueeze(0)}; FusionExecutorCache fec(std::move(fusion)); at::manual_seed(getATenRandomSeed()); @@ -1367,13 +1368,13 @@ TEST_P(DistributedTransformerTest, Forward) { ln0_b_, shardTensor(mha_w0_.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), shardTensor(mha_b0_.view({3, E}), 1, mesh).view({1, 3 * E / D}), - shardTensor(mha_w1_, 1, mesh), + shardTensor(mha_w1_, 1, mesh).unsqueeze(0), mha_b1_, ln1_w_, ln1_b_, - shardTensor(mlp_w0_, 0, mesh), - shardTensor(mlp_b0_, 0, mesh), - shardTensor(mlp_w1_, 1, mesh), + shardTensor(mlp_w0_, 0, mesh).unsqueeze(0), 
+ shardTensor(mlp_b0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_w1_, 1, mesh).unsqueeze(0), mlp_b1_}; std::vector<at::Tensor> expected_outputs = { @@ -1620,13 +1621,16 @@ TEST_P(DistributedTransformerTest, Backward) { auto dx_ = (ln0_x_grad_ + resid1_grad_).to(at_dtype); auto expected_outputs = { - shardTensor(mlp_grads_[1], 1, mesh), // mlp_linear1_weight_grad + shardTensor(mlp_grads_[1], 1, mesh) + .unsqueeze(0), // mlp_linear1_weight_grad mlp_grads_[2], // mlp_linear1_bias_grad - shardTensor(mlp_grads_[4], 0, mesh), // mlp_linear0_weight_grad - shardTensor(mlp_grads_[5], 0, mesh), // mlp_linear0_bias_grad + shardTensor(mlp_grads_[4], 0, mesh) + .unsqueeze(0), // mlp_linear0_weight_grad + shardTensor(mlp_grads_[5], 0, mesh).unsqueeze(0), // mlp_linear0_bias_grad ln1_w_grad_, ln1_b_grad_, - shardTensor(mha_grads_[5], 1, mesh), // mha linear1 weight grad + shardTensor(mha_grads_[5], 1, mesh) + .unsqueeze(0), // mha linear1 weight grad mha_grads_[6], // mha linear1 bias grad shardTensor( mha_grads_[10].view({3, E, E}), 1, mesh) // failing starting here @@ -1641,13 +1645,13 @@ TEST_P(DistributedTransformerTest, Backward) { x_, grad_, shardTensor(mha_w0_.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), - shardTensor(mha_w1_, 1, mesh), - shardTensor(mlp_w0_, 0, mesh), - shardTensor(mlp_w1_, 1, mesh), + shardTensor(mha_w1_, 1, mesh).unsqueeze(0), + shardTensor(mlp_w0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_w1_, 1, mesh).unsqueeze(0), mlp_out_[4], // mlp dropout mask mha_out_[4], // mha dropout mask - shardTensor(mha_grads_[0], 1, mesh), // sdpa output - shardTensor(mha_grads_[1], 1, mesh), // sdpa logsum_exp + shardTensor(mha_grads_[0], 1, mesh).unsqueeze(0), // sdpa output + shardTensor(mha_grads_[1], 1, mesh).unsqueeze(0), // sdpa logsum_exp mha_grads_[2], // sdpa seed mha_grads_[3], // sdpa offset ln1_w_, @@ -1658,9 +1662,9 @@ TEST_P(DistributedTransformerTest, Backward) { ln0_b_, ln0_mean_, ln0_rstd_, - shardTensor(mha_out_[0], 1, mesh), // mha linear0 + shardTensor(mha_out_[0], 1, mesh).unsqueeze(0), // mha linear0 mha_out_[2].to(at::kFloat), // mha linear1 - shardTensor(mlp_out_[0], 1, mesh) // mlp linear1 + shardTensor(mlp_out_[0], 1, mesh).unsqueeze(0) // mlp linear1 }; FusionExecutorCache executor_cache(std::move(fusion));
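
For context on the `.unsqueeze(0)` calls that now appear at these call sites: the updated `shardTensor` in tests/cpp/multidevice.cpp no longer unsqueezes a leading DIDx axis onto the returned slice, so tests that still expect that leading device axis add it back explicitly. Below is a rough, self-contained sketch of the slicing that `shardTensor` performs for an outer DID split (the only case `getShardedLogicalAxis` permits), using only ATen; the helper name `sliceForOuterSplit` and the standalone form are illustrative and not part of the patch.

```C++
#include <ATen/ATen.h>

// Return the i-th of d contiguous chunks of `tensor` along `axis`.
// Outer splits shard a dimension into contiguous blocks, so a plain
// slice suffices; inner (interleaved) splits are rejected by
// getShardedLogicalAxis and never reach this point.
at::Tensor sliceForOuterSplit(
    const at::Tensor& tensor,
    int64_t axis,
    int64_t i,
    int64_t d) {
  const int64_t stride = tensor.size(axis) / d;
  return tensor.slice(axis, i * stride, (i + 1) * stride).contiguous();
}

// Example: for an unsharded [2, d * 3] tensor sharded on axis 1 with d = 2,
// rank 0 receives columns [0, 3) and rank 1 receives columns [3, 6),
// matching MultiDeviceTest.ShardTensor_OuterSplit above.
```
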