Merge pull request #7 from unixpickle/fused-backward
Fused backward pass kernel
proger authored Mar 21, 2024
2 parents 5cb9403 + f47c2f1 commit 0e12f07
Showing 2 changed files with 129 additions and 51 deletions.
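This commit moves the backward pass of the linear recurrence h[t] = gates[t] * h[t-1] + tokens[t] from Python into the scan kernel itself. The token gradient is a reverse scan of the output gradient against the gates shifted left by one step, and the gate gradient is that result multiplied elementwise by the previous output; previously both shifts were materialized with torch.cat in warp.py, now the kernel performs the shifted loads and the multiply in the same pass. For reference, a minimal PyTorch sketch of the unfused computation being removed (scan_forward is the module's forward wrapper; shapes are (batch, dim, seqlen)):

import torch

def reference_backward(gates, states, grad_output):
    # Token gradient: reverse scan of grad_output with the gates shifted left by one
    # (the gate applied at t+1), padding the final position with 1.
    shifted_gates = torch.cat([gates, torch.ones_like(gates[:, :, :1])], dim=-1)[:, :, 1:].contiguous()
    d_tokens = scan_forward(shifted_gates, grad_output, reverse=True)
    # Gate gradient: previous output h[t-1] (zero at t = 0) times the token gradient.
    prev_states = torch.cat([torch.zeros_like(states[:, :, :1]), states], dim=-1)[:, :, :-1]
    d_gates = prev_states * d_tokens
    return d_gates, d_tokens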
167 changes: 123 additions & 44 deletions accelerated_scan/warp.cuh
@@ -4,6 +4,8 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#define CHECK_STRIDE(x) TORCH_CHECK(x.stride(-1) == 1 || x.size(-1) == 1);

template<typename weight_t, int N>
class UnalignedTuple {
public:
@@ -26,11 +28,33 @@ template<typename T, int N>
class alignas(16) AlignedTuple : public UnalignedTuple<T, N> {
};

template <typename Tuple, int kNThreadsPerWarp, int kNWarpsPerBlock, int kNChunksPerSequence>
template <typename Tuple, int offset>
__device__ Tuple load_shifted_tuple(const Tuple* ptr, int index, int minIdx, int maxIdx) {
using weight_t = typename Tuple::Type;

const weight_t* rawPtr = reinterpret_cast<const weight_t *>(ptr);
Tuple x;
for (int i = 0; i < Tuple::Size; i++) {
const int idx = index * Tuple::Size + i + offset;
if (idx >= minIdx * Tuple::Size && idx < maxIdx * Tuple::Size) {
x.data[i] = rawPtr[idx];
} else {
x.data[i] = 0.0;
}
}

return x;
}
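load_shifted_tuple reads Tuple::Size consecutive scalars starting `offset` elements away from the tuple at `index`, zero-filling any element that falls outside the [minIdx, maxIdx) tuple range of the current sequence. This is how the backward pass reads the gates shifted by +1 and the forward outputs shifted by -1 without materializing shifted copies. A rough NumPy equivalent of the indexing (illustrative only, not part of the extension):

import numpy as np

def load_shifted_tuple(raw, index, min_idx, max_idx, size, offset):
    # raw is the flat scalar view of the tuple array; index bounds are in tuple units.
    out = np.zeros(size, dtype=raw.dtype)
    for i in range(size):
        idx = index * size + i + offset
        if min_idx * size <= idx < max_idx * size:
            out[i] = raw[idx]
    return out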

template <typename Tuple, int kNThreadsPerWarp, int kNWarpsPerBlock, int kNChunksPerSequence, bool backward>
__global__ void scan(
const Tuple* gates,
const Tuple* tokens,
Tuple* result,
// Only passed if backward is True.
const Tuple* output,
Tuple* gateGradOut,
// Shape information
const int batch_stride,
const int dim_stride,
const bool reverse
@@ -51,6 +75,10 @@ __global__ void scan(
const weight_t kEmptyGate = 1.0;
const weight_t kEmptyToken = 0.0;

// Limits for loading shifted tuples during backward pass.
const int minIdx = seqoffset / Tuple::Size;
const int maxIdx = minIdx + blockDim.x * kNChunksPerSequence;

//
// Read from global memory.
// Scan sequentially in thread registers (level 0).
@@ -64,7 +92,12 @@ __global__ void scan(
__syncthreads();
}

Tuple loadedGate = gates[tupleOffset];
Tuple loadedGate;
if (backward) {
loadedGate = load_shifted_tuple<Tuple, 1>(gates, tupleOffset, minIdx, maxIdx);
} else {
loadedGate = gates[tupleOffset];
}
Tuple loadedToken = tokens[tupleOffset];
if (reverse) {
loadedGate.reverse();
@@ -174,43 +207,68 @@ __global__ void scan(
}
result[tupleOffset] = accToken;

if (backward) {
Tuple gateGrad = load_shifted_tuple<Tuple, -1>(output, tupleOffset, minIdx, maxIdx);
for (int i = 0; i < Tuple::Size; i++) {
gateGrad.data[i] = gateGrad.data[i] * accToken.data[i];
}
gateGradOut[tupleOffset] = gateGrad;
}

if (laneId == kWarpLast && warpId == kBlockLast) {
chunkAccGate = accGate.data[kThreadLast];
chunkAccToken = accToken.data[kThreadLast];
}
}
}
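In backward mode the kernel is launched with reverse = true, so accToken ends up holding the token gradient for each position; the gateGradOut block above then forms the gate gradient in the same pass as d_gates[t] = states[t-1] * d_tokens[t], reading states[t-1] via the -1-shifted load of `output`.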

#define DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, batch_stride, dim_stride, reverse) \
#define DISPATCH_SCAN_INNER(TupleT, backward, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse) \
scan<TupleT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, backward><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
reinterpret_cast<const TupleT *>(gates.data_ptr<torch_weight_t>()), \
reinterpret_cast<const TupleT *>(tokens.data_ptr<torch_weight_t>()), \
reinterpret_cast<TupleT *>(out.data_ptr<torch_weight_t>()), \
reinterpret_cast<const TupleT *>(output), \
reinterpret_cast<TupleT *>(gateGradOut), \
batch_stride, dim_stride, reverse \
);

#define DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse) \
using AlignedT = AlignedTuple<weight_t, kNStepsPerThread>; \
using UnalignedT = UnalignedTuple<weight_t, kNStepsPerThread>; \
if (kNStepsPerThread == 4 && \
((long)gates.data_ptr()) % 16 == 0 && \
((long)tokens.data_ptr()) % 16 == 0 && \
((long)out.data_ptr()) % 16 == 0) { \
scan<AlignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
reinterpret_cast<const AlignedT *>(gates.data_ptr<torch_weight_t>()), \
reinterpret_cast<const AlignedT *>(tokens.data_ptr<torch_weight_t>()), \
reinterpret_cast<AlignedT *>(out.data_ptr<torch_weight_t>()), \
batch_stride, dim_stride, reverse \
); \
((long)out.data_ptr()) % 16 == 0 && \
((long)output) % 16 == 0 && \
((long)gateGradOut) % 16 == 0) { \
if (output) { \
DISPATCH_SCAN_INNER(AlignedT, true, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
} else { \
DISPATCH_SCAN_INNER(AlignedT, false, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
} \
} else { \
scan<UnalignedT, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>( \
reinterpret_cast<const UnalignedT*>(gates.data_ptr<torch_weight_t>()), \
reinterpret_cast<const UnalignedT*>(tokens.data_ptr<torch_weight_t>()), \
reinterpret_cast<UnalignedT *>(out.data_ptr<torch_weight_t>()), \
batch_stride, dim_stride, reverse \
); \
if (output) { \
DISPATCH_SCAN_INNER(UnalignedT, true, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
} else { \
DISPATCH_SCAN_INNER(UnalignedT, false, weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut, batch_stride, dim_stride, reverse); \
} \
}
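The backward flag is a compile-time template parameter, so the extra loads and the gate-gradient store are compiled out of the forward kernel entirely. The 16-byte alignment check is also extended to the new output/gateGradOut pointers before taking the AlignedTuple fast path; null pointers trivially pass that check and simply select the backward = false instantiation.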

template <typename weight_t, typename torch_weight_t>
void
warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &out, const bool reverse) {
warpscan(
const at::Tensor &gates,
const at::Tensor &tokens,
const at::Tensor &out,
const void *output,
void *gateGradOut,
const bool reverse
) {
const auto strides = tokens.strides();
const int batch_stride = strides[0];
const int dim_stride = strides[1];
TORCH_CHECK(tokens.stride(-1) == 1 || tokens.size(-1) == 1);
TORCH_CHECK(gates.stride(-1) == 1 || gates.size(-1) == 1);
CHECK_STRIDE(tokens);
CHECK_STRIDE(gates);

const auto sizes = tokens.sizes();
const int batch_size = sizes[0];
@@ -227,119 +285,140 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
int kNThreads = seqlen / kNStepsPerThread;
constexpr int kNChunksPerSequence = 1;
DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
batch_stride, dim_stride, reverse);
} else if (seqlen == 64) {
constexpr int kNStepsPerThread = 2;
constexpr int kNWarpsPerBlock = 1;
constexpr int kNChunksPerSequence = 1;
int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
batch_stride, dim_stride, reverse);
} else if (seqlen == 128) {
constexpr int kNStepsPerThread = 1;
constexpr int kNWarpsPerBlock = 4;
int kNThreads = seqlen / kNStepsPerThread;
constexpr int kNChunksPerSequence = 1;
DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
batch_stride, dim_stride, reverse);
} else if (seqlen == 256) {
constexpr int kNStepsPerThread = 1;
constexpr int kNWarpsPerBlock = 8;
int kNThreads = seqlen / kNStepsPerThread;
constexpr int kNChunksPerSequence = 1;
DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
batch_stride, dim_stride, reverse);
} else if (seqlen == 512) {
constexpr int kNStepsPerThread = 1;
constexpr int kNWarpsPerBlock = 16;
int kNThreads = seqlen / kNStepsPerThread;
constexpr int kNChunksPerSequence = 1;
DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
batch_stride, dim_stride, reverse);
} else if (seqlen == 1024) {
constexpr int kNStepsPerThread = 2;
constexpr int kNWarpsPerBlock = 16;
int kNThreads = seqlen / kNStepsPerThread;
constexpr int kNChunksPerSequence = 1;
DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
batch_stride, dim_stride, reverse);
} else if (seqlen == 2048) {
constexpr int kNStepsPerThread = 2;
constexpr int kNWarpsPerBlock = 32;
int kNThreads = seqlen / kNStepsPerThread;
constexpr int kNChunksPerSequence = 1;
DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
batch_stride, dim_stride, reverse);
} else if (seqlen == 4096) {
constexpr int kNStepsPerThread = 4;
constexpr int kNWarpsPerBlock = 32;
int kNThreads = seqlen / kNStepsPerThread;
constexpr int kNChunksPerSequence = 1;
DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
batch_stride, dim_stride, reverse);
} else if (seqlen == 8192) {
constexpr int kNStepsPerThread = 4;
constexpr int kNWarpsPerBlock = 32;
constexpr int kNChunksPerSequence = 2;
int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
batch_stride, dim_stride, reverse);
} else if (seqlen == 16384) {
constexpr int kNStepsPerThread = 4;
constexpr int kNWarpsPerBlock = 32;
constexpr int kNChunksPerSequence = 4;
int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
batch_stride, dim_stride, reverse);
} else if (seqlen == 32768) {
constexpr int kNStepsPerThread = 4;
constexpr int kNWarpsPerBlock = 32;
constexpr int kNChunksPerSequence = 8;
int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
batch_stride, dim_stride, reverse);
} else if (seqlen == 65536) {
constexpr int kNStepsPerThread = 4;
constexpr int kNWarpsPerBlock = 32;
constexpr int kNChunksPerSequence = 16;
int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
DISPATCH_SCAN(weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out,
kNChunksPerSequence, grid, kNThreads, stream, gates, tokens, out, output, gateGradOut,
batch_stride, dim_stride, reverse);
} else {
TORCH_CHECK(false && "seqlen must be a power of 2, >= 32, <= 65536");
}
}

#define DISPATCH_WARPSCAN(gates, ...) \
if (gates.scalar_type() == at::ScalarType::BFloat16) { \
warpscan<__nv_bfloat16, at::BFloat16>(gates, __VA_ARGS__); \
} else if (gates.scalar_type() == at::ScalarType::Half) { \
warpscan<__half, at::Half>(gates, __VA_ARGS__); \
} else if (gates.scalar_type() == at::ScalarType::Float) { \
warpscan<float, float>(gates, __VA_ARGS__); \
} else { \
TORCH_CHECK(false && "Unsupported tensor dtype: expecting bfloat16, float16 or float32"); \
}

at::Tensor
warpscan_forward(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &out, const bool reverse) {
TORCH_CHECK(tokens.is_cuda());
TORCH_CHECK(gates.is_cuda());
TORCH_CHECK(tokens.is_contiguous());
TORCH_CHECK(gates.is_contiguous());
TORCH_CHECK(tokens.scalar_type() == gates.scalar_type());
TORCH_CHECK(tokens.scalar_type() == out.scalar_type());

if (tokens.scalar_type() == at::ScalarType::BFloat16) {
TORCH_CHECK(gates.scalar_type() == at::ScalarType::BFloat16);
warpscan<__nv_bfloat16, at::BFloat16>(gates, tokens, out, reverse);
} else if (tokens.scalar_type() == at::ScalarType::Half) {
TORCH_CHECK(gates.scalar_type() == at::ScalarType::Half);
warpscan<__half, at::Half>(gates, tokens, out, reverse);
} else if (tokens.scalar_type() == at::ScalarType::Float) {
TORCH_CHECK(gates.scalar_type() == at::ScalarType::Float);
warpscan<float, float>(gates, tokens, out, reverse);
} else {
TORCH_CHECK(false && "Unsupported tensor dtype: expecting bfloat16, float16 or float32");
}
DISPATCH_WARPSCAN(gates, tokens, out, nullptr, nullptr, reverse);
return out;
}
}

void
warpscan_backward(const at::Tensor &gates, const at::Tensor &output, const at::Tensor &outGrad, const at::Tensor& gateGradOut, const at::Tensor& tokenGradOut) {
TORCH_CHECK(gates.is_cuda());
TORCH_CHECK(output.is_cuda());
TORCH_CHECK(outGrad.is_cuda());
TORCH_CHECK(gateGradOut.is_contiguous());
TORCH_CHECK(tokenGradOut.is_contiguous());
TORCH_CHECK(gates.scalar_type() == output.scalar_type());
TORCH_CHECK(gates.scalar_type() == outGrad.scalar_type());
TORCH_CHECK(gates.scalar_type() == gateGradOut.scalar_type());
TORCH_CHECK(gates.scalar_type() == tokenGradOut.scalar_type());
TORCH_CHECK(gates.sizes() == output.sizes());
TORCH_CHECK(gates.sizes() == outGrad.sizes());
TORCH_CHECK(gates.sizes() == gateGradOut.sizes());
TORCH_CHECK(gates.sizes() == tokenGradOut.sizes());

DISPATCH_WARPSCAN(gates, outGrad, tokenGradOut, output.data_ptr(), gateGradOut.data_ptr(), true);
}
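Note how warpscan_backward reuses the forward dispatch: the kernel is launched with reverse = true, tokens = outGrad and out = tokenGradOut, while the forward output and gateGradOut are passed as raw pointers that switch the kernel into backward mode, so both gradients come out of a single launch.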
13 changes: 6 additions & 7 deletions accelerated_scan/warp.py
@@ -7,13 +7,14 @@

cpp_source = """
at::Tensor warpscan_forward(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &out, const bool reverse);
void warpscan_backward(const at::Tensor &gates, const at::Tensor &output, const at::Tensor &outGrad, const at::Tensor& gateGradOut, const at::Tensor& valueGradOut);
"""

module = load_inline(
name='warpscan',
cpp_sources=[cpp_source],
cuda_sources=[cuda_source],
functions=['warpscan_forward'],
functions=['warpscan_forward', 'warpscan_backward'],
verbose=True,
extra_cuda_cflags=[
"-O3",
Expand All @@ -26,6 +27,7 @@
]
)
warpscan_forward = module.warpscan_forward
warpscan_backward = module.warpscan_backward

def scan_forward(gates, tokens, reverse=False):
output = torch.zeros_like(tokens)
@@ -56,13 +58,10 @@ def backward(ctx, grad_output):
assert states.is_contiguous()
assert gates.is_contiguous()

padded_shifted_gates = torch.cat([gates, torch.ones_like(gates[:, :, :1])], dim=-1)[:, :, 1:].contiguous()
d_states = scan_forward(padded_shifted_gates, grad_output, reverse=True)
d_gates = torch.empty_like(gates)
d_tokens = torch.empty_like(gates)
warpscan_backward(gates, states, grad_output, d_gates, d_tokens)

padded_outputs = torch.cat([torch.zeros_like(states[:, :, :1]), states], dim=-1)[:, :, :-1]
d_gates = padded_outputs * d_states

d_tokens = d_states
return d_gates, d_tokens
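A minimal end-to-end sketch of the fused path, assuming the inline-compiled module from this file is in scope and a CUDA device is available (tensor names and sizes here are illustrative):

import torch

batch, dim, seqlen = 2, 4, 128          # seqlen must be a power of 2 in [32, 65536]
gates = torch.rand(batch, dim, seqlen, device="cuda")
tokens = torch.randn(batch, dim, seqlen, device="cuda")

# Forward: states[t] = gates[t] * states[t-1] + tokens[t]
states = torch.zeros_like(tokens)
warpscan_forward(gates, tokens, states, False)

# Fused backward: token and gate gradients from a single reverse scan.
grad_output = torch.randn_like(states)
d_gates = torch.empty_like(gates)
d_tokens = torch.empty_like(tokens)
warpscan_backward(gates, states, grad_output, d_gates, d_tokens)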


