feat: add support for int8 quantization on linear layers #299

Draft · wants to merge 6 commits into main
97 changes: 77 additions & 20 deletions src/kernl/implementations/linear_layer.py
@@ -89,7 +89,7 @@ def get_configs_io_bound():
}
)
@triton.jit
def kernel_fma(
def kernel_linear(
C, # Pointers to matrices
ACT_INPUTS,
A,
@@ -124,19 +124,56 @@ def kernel_fma(
HAS_BIAS: tl.constexpr,
SHOULD_SAVE_ACT_INPUTS: tl.constexpr,
ACTIVATION: tl.constexpr,
# quantization scalers
ALPHA_SCALER: tl.constexpr,
BETA_SCALER: tl.constexpr,
ACC_TYPE: tl.constexpr,
):
"""
Kernel for computing Out = activation(A x W + C)
Matrix multiplication kernel with fused activation and bias.
Out = activation(alpha * (A x W) + beta * bias)

- Input has shape (M, K)
- Weight has shape (K, N)
- Bias has shape (N,)
- Bias has shape (N,) -> The bias is added to each row of the matmul output.
- Output has shape (M, N)
- ActInputs (optional) has shape (M, N)

'ActInputs' optionally saves the pre-activation intermediate (alpha * A x W + beta * bias) for the backward pass

This kernel will consolidate over K

:param C: Output matrix
:param ACT_INPUTS: (Optional) tensor to save the activation inputs (for backward)
:param A: Input matrix A (inputs)
:param B: Input matrix B (weights, transposed)
:param bias: Bias vector
:param M: Number of rows in A and C
:param N: Number of columns in B and C
:param K: Number of columns in A and rows in B
:param CACHE_KEY_M: Cache key for M
:param CACHE_KEY_N: Cache key for N
:param CACHE_KEY_K: Cache key for K
:param output_m_stride: Stride along M of the output matrix C
:param output_n_stride: Stride along N of the output matrix C
:param act_inputs_m_stride: Stride along M of the ACT_INPUTS matrix
:param act_inputs_n_stride: Stride along N of the ACT_INPUTS matrix
:param a_m_stride: Stride along M of the input matrix A
:param a_k_stride: Stride along K of the input matrix A
:param b_n_stride: Stride along N of the weight matrix B
:param b_k_stride: Stride along K of the weight matrix B
:param BLOCK_M: Block size in the M dimension
:param GROUP_M: Number of BLOCK_M-wide program groups scheduled together (L2 cache reuse optimization)
:param BLOCK_N: Block size in the N dimension
:param BLOCK_K: Block size in the K dimension
:param SPLIT_K: Number of blocks in the K dimension
:param K_LOAD_MASK_NEEDED: Whether or not to use a mask when loading from B
:param HAS_BIAS: Whether or not to add a bias to the result
:param SHOULD_SAVE_ACT_INPUTS: Whether or not to save the activation inputs
:param ACTIVATION: Activation function to apply
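:param ALPHA_SCALER: Scaling factor applied to the matmul accumulator (1.0 disables it)
:param BETA_SCALER: Scaling factor applied to the bias (1.0 disables it)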
:param ACC_TYPE: Accumulation type

:return: None
"""
program_idx = tl.program_id(axis=0)

@@ -164,11 +201,7 @@ def kernel_fma(
A = A + (m_offs[:, None] * a_m_stride + k_range_offs[None, :] * a_k_stride)
B = B + (k_range_offs[:, None] * b_k_stride + n_offs[None, :] * b_n_stride)

acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

if HAS_BIAS:
bias = tl.load(bias + n_offs, mask=n_offs < N, other=0.0).to(tl.float32)
acc += bias[None, :]
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)

for k in range(K, 0, -BLOCK_K):
if K_LOAD_MASK_NEEDED:
@@ -182,6 +215,15 @@
A += BLOCK_K * a_k_stride
B += BLOCK_K * b_k_stride

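# rescale the accumulator and the bias after the K reduction; on the int8 path these
# factors carry the input/weight/output quantization scales (1.0 leaves the result unchanged)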
if ALPHA_SCALER != 1.0:
acc *= ALPHA_SCALER

if HAS_BIAS:
bias = tl.load(bias + n_offs, mask=n_offs < N, other=0.0) # .to(ACC_TYPE) # TODO fix when Triton updated
if BETA_SCALER != 1.0:
bias *= BETA_SCALER
acc += bias[None, :]

# optional: save the activation inputs
if SHOULD_SAVE_ACT_INPUTS:
act_in_ptrs = ACT_INPUTS + m_offs[:, None] * act_inputs_m_stride + n_offs[None, :] * act_inputs_n_stride
@@ -212,7 +254,10 @@ def forward(
weight: torch.Tensor,
bias: Optional[torch.Tensor],
activation: str,
alpha_scaler: float,
beta_scaler: float,
act_inputs: Optional[torch.Tensor],
output: Optional[torch.Tensor],
) -> torch.Tensor:
"""
Compute output = activation(alpha_scaler * (x @ weight.T) + beta_scaler * bias).
@@ -221,16 +266,19 @@
:param x: input tensor
:param weight: weight matrix
:param bias: an optional bias tensor
:param activation: Activation name. Needs to be a Triton kernel.
:param activation: Activation name (relu, tanh, gelu, fast_gelu)
:param alpha_scaler: alpha scaling factor (applied to the matmul output)
:param beta_scaler: beta scaling factor (applied to the bias)
:param act_inputs: an optional tensor to save the activation inputs (for backward)
:param output: an optional pre-allocated output tensor (allocated if not provided)
:return: result tensor
"""
x_ = x if x.ndim == 2 else x.flatten(0, 1)

assert x.dtype == weight.dtype, f"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}"
if bias is not None:
assert x.dtype == bias.dtype, f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
assert x_.shape[1] == weight.shape[1], f"Incompatible dimensions: {x_.shape} - {weight.shape}"
# if bias is not None:
# assert x.dtype == bias.dtype, f"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}"
assert x_.shape[1] == weight.shape[1], f"Incompatible dimensions: {x_.shape} / {weight.shape}"

assert bias is None or bias.is_contiguous()
assert bias is None or bias.shape[0] == weight.shape[0], "Incompatible dimensions in between weight and bias"
@@ -239,13 +287,16 @@ def forward(
M, K = x_.shape
N, K = weight.shape

outputs = torch.empty((M, N), device=x.device, dtype=x.dtype)

if output is None:
output = torch.empty((M, N), device=x.device, dtype=x.dtype)
else:
output = output if output.ndim == 2 else output.flatten(0, 1)
acc_type = tl.float32 if output.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
# 1D launch kernel where each block gets its own program.
grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa

kernel_fma[grid](
outputs,
kernel_linear[grid](
output,
act_inputs,
x_,
weight, # data ptrs
@@ -256,8 +307,8 @@ def forward(
M // 32, # key for triton cache (limit number of compilations)
N // 32,
K // 32,
output_m_stride=outputs.stride(0), # strides
output_n_stride=outputs.stride(1),
output_m_stride=output.stride(0), # strides
output_n_stride=output.stride(1),
act_inputs_m_stride=act_inputs.stride(0) if act_inputs is not None else 0,
act_inputs_n_stride=act_inputs.stride(1) if act_inputs is not None else 0,
a_m_stride=x_.stride(0),
@@ -267,10 +318,13 @@
HAS_BIAS=bias is not None, # optional fused bias
SHOULD_SAVE_ACT_INPUTS=act_inputs is not None, # optional save activation inputs
ACTIVATION=activation, # optional fused activation ("" disables it)
ALPHA_SCALER=alpha_scaler, # optional alpha scaler (quantization)
BETA_SCALER=beta_scaler, # optional beta scaler (quantization)
ACC_TYPE=acc_type, # accumulator type
GROUP_M=8, # speed optimization: group the programs
)

outputs = outputs if x.ndim == 2 else outputs.reshape(x.shape[0], -1, N)
outputs = output if x.ndim == 2 else output.reshape(x.shape[0], -1, N)
ctx.save_for_backward(weight, bias, x)
return outputs

@@ -281,5 +335,8 @@ def linear_layer(
bias: Optional[torch.Tensor],
activation="",
act_inputs: Optional[torch.Tensor] = None,
alpha_scaler=1.0,
beta_scaler=1.0,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return LinearLayer.apply(x, weight, bias, activation, act_inputs)
return LinearLayer.apply(x, weight, bias, activation, alpha_scaler, beta_scaler, act_inputs, output)
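
For reference, a minimal PyTorch sketch of what the updated kernel computes (the helper name below is illustrative, not part of the PR): the weight is stored as (N, K), alpha_scaler rescales the matmul result, beta_scaler rescales the bias, and both default to 1.0 so the existing fp16/fp32 path is unchanged; on the int8 path the result is additionally written into an int8 output tensor.

import torch

def linear_reference(x, weight, bias=None, activation="", alpha_scaler=1.0, beta_scaler=1.0):
    # x: (M, K), weight: (N, K), optional bias: (N,) -> out: (M, N)
    acc = alpha_scaler * (x.float() @ weight.float().t())
    if bias is not None:
        acc = acc + beta_scaler * bias.float()
    if activation == "relu":  # other fused activations behave analogously
        acc = torch.relu(acc)
    return acc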
200 changes: 200 additions & 0 deletions src/kernl/implementations/linear_layer_quant.py
@@ -0,0 +1,200 @@
# Copyright 2022 Lefebvre Sarrut
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# code inspired by the torch-int package
# https://github.com/Guangxuan-Xiao/torch-int/blob/main/torch_int/nn/linear.py

import torch

from kernl.implementations.linear_layer import linear_layer


@torch.no_grad()
def quantize_per_tensor_absmax(t: torch.Tensor):
scale = t.abs().max() / 127
if not t.is_cuda:
# half rounding is not supported on CPU
t = t.float()
# use inplace operation to save memory
t.div_(scale).round_()
t_q = t.to(torch.int8)
return t_q, scale
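# Illustrative usage (editor's note, not part of the PR): the helper quantizes its
# CUDA input in place, so pass a copy if the original tensor must be preserved:
#   w_q, s = quantize_per_tensor_absmax(w.clone())
#   w_approx = w_q.float() * s  # reconstruction; w_q values lie in [-127, 127]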


class W8A8B8O8Linear(torch.nn.Module):
# For qkv_proj
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__()
self.in_features = in_features
self.out_features = out_features

self.register_buffer(
"weight",
torch.randint(-127, 127, (self.out_features, self.in_features), dtype=torch.int8, requires_grad=False),
)
self.register_buffer("bias", torch.zeros((1, self.out_features), dtype=torch.int8, requires_grad=False))
self.register_buffer("a", torch.tensor(alpha))
self.register_buffer("b", torch.tensor(beta))

def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.weight = self.weight.to(*args, **kwargs)
self.bias = self.bias.to(*args, **kwargs)
return self

@torch.no_grad()
def forward(self, x):
x_shape = x.shape
x = x.view(-1, x_shape[-1])
y = torch.empty((x.shape[0], self.weight.shape[0]), device=x.device, dtype=torch.int8)
linear_layer(
x=x,
weight=self.weight,
bias=self.bias,
activation="",
act_inputs=None,
alpha_scaler=self.a.item(),
beta_scaler=self.b.item(),
output=y,
)
y = y.view(*x_shape[:-1], -1)
return y

@staticmethod
def from_float(module: torch.nn.Linear, input_scale, output_scale):
int8_module = W8A8B8O8Linear(module.in_features, module.out_features)
int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias)
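# scale algebra: x ≈ x_q * input_scale, weight ≈ w_q * weight_scale, bias ≈ b_q * bias_scale,
# so y / output_scale ≈ (input_scale * weight_scale / output_scale) * (x_q @ w_q.T)
#                       + (bias_scale / output_scale) * b_q
# which are exactly the alpha and beta factors stored on the module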
alpha = input_scale * weight_scale / output_scale
beta = bias_scale / output_scale
int8_module.weight = int8_weight
int8_module.bias = int8_bias
int8_module.a = alpha
int8_module.b = beta
return int8_module


class W8A8B8O8LinearReLU(torch.nn.Module):
# For fc1
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__()
self.in_features = in_features
self.out_features = out_features

self.register_buffer(
"weight",
torch.randint(-127, 127, (self.out_features, self.in_features), dtype=torch.int8, requires_grad=False),
)
self.register_buffer("bias", torch.zeros((1, self.out_features), dtype=torch.int8, requires_grad=False))
self.register_buffer("a", torch.tensor(alpha))
self.register_buffer("b", torch.tensor(beta))

def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.weight = self.weight.to(*args, **kwargs)
self.bias = self.bias.to(*args, **kwargs)
return self

@torch.no_grad()
def forward(self, x):
x_shape = x.shape
x = x.view(-1, x_shape[-1])
y = torch.empty((x.shape[0], self.weight.shape[0]), device=x.device, dtype=torch.int8)

linear_layer(
x=x,
weight=self.weight,
bias=self.bias,
activation="relu",
act_inputs=None,
alpha_scaler=self.a.item(),
beta_scaler=self.b.item(),
output=y,
)
y = y.view(*x_shape[:-1], -1)
return y

@staticmethod
def from_float(module: torch.nn.Linear, input_scale, output_scale):
# TODO: add zero-point to prevent the bit waste
int8_module = W8A8B8O8LinearReLU(module.in_features, module.out_features)
int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
int8_bias, bias_scale = quantize_per_tensor_absmax(module.bias)
alpha = input_scale * weight_scale / output_scale
beta = bias_scale / output_scale
int8_module.weight = int8_weight
int8_module.bias = int8_bias
int8_module.a = alpha
int8_module.b = beta
return int8_module


class W8A8BFP32OFP32Linear(torch.nn.Module):
# For fc2 and out_proj
def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
super().__init__()
self.in_features = in_features
self.out_features = out_features

self.register_buffer(
"weight",
torch.randint(-127, 127, (self.out_features, self.in_features), dtype=torch.int8, requires_grad=False),
)
self.register_buffer("bias", torch.zeros((1, self.out_features), dtype=torch.float32, requires_grad=False))
self.register_buffer("a", torch.tensor(alpha))

def _apply(self, fn):
# prevent the bias from being converted to half
super()._apply(fn)
self.bias = self.bias.to(torch.float32)
return self

def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.weight = self.weight.to(*args, **kwargs)
self.bias = self.bias.to(*args, **kwargs)
self.bias = self.bias.to(torch.float32)
return self

@torch.no_grad()
def forward(self, x):
x_shape = x.shape
x = x.view(-1, x_shape[-1])
self.bias = self.bias.to(torch.float32)
y = torch.empty((x.shape[0], self.weight.shape[0]), device=x.device, dtype=torch.float32)

linear_layer(
x=x,
weight=self.weight,
bias=self.bias,
activation="",
act_inputs=None,
alpha_scaler=self.a.item(),
beta_scaler=1.0,
output=y,
)
y = y.view(*x_shape[:-1], -1)
return y

@staticmethod
def from_float(module: torch.nn.Linear, input_scale):
int8_module = W8A8BFP32OFP32Linear(module.in_features, module.out_features)
int8_weight, weight_scale = quantize_per_tensor_absmax(module.weight)
alpha = input_scale * weight_scale
int8_module.weight = int8_weight
int8_module.bias = module.bias.to(torch.float32)
int8_module.a = alpha
int8_module.input_scale = input_scale
int8_module.weight_scale = weight_scale
return int8_module
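
A short usage sketch for the new modules (assumptions: a CUDA device, and calibration scales input_scale / output_scale obtained elsewhere, e.g. from activation statistics; the values below are placeholders):

import torch
from kernl.implementations.linear_layer_quant import W8A8B8O8Linear

fp_linear = torch.nn.Linear(768, 768).cuda().half()
input_scale, output_scale = 0.02, 0.05  # placeholder calibration values

int8_linear = W8A8B8O8Linear.from_float(fp_linear, input_scale, output_scale).cuda()

x = torch.randn(4, 16, 768, device="cuda", dtype=torch.float16)
x_q = (x / input_scale).round().clamp(-127, 127).to(torch.int8)  # int8 activations
y_q = int8_linear(x_q)               # int8 output, shape (4, 16, 768)
y = y_q.float() * output_scale       # dequantized output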