diff --git a/benchmark/test_norm_perf.py b/benchmark/test_norm_perf.py
index 5722ce37..197988b4 100644
--- a/benchmark/test_norm_perf.py
+++ b/benchmark/test_norm_perf.py
@@ -73,6 +73,25 @@ def instancenorm_input_fn(shape, dtype, device):
         yield inp, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled


+def batchnorm_input_fn(shape, dtype, device):
+    C = shape[1]
+    inp = torch.randn(shape, dtype=dtype, device=device)
+    weight = torch.randn((C,), dtype=dtype, device=device)
+    bias = torch.randn((C,), dtype=dtype, device=device)
+    running_mean = None
+    running_var = None
+    training = True
+    momentum = 0.1
+    eps = 1e-5
+    cudnn_enabled = True
+    yield inp, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled
+
+    if Config.bench_level == BenchLevel.COMPREHENSIVE:
+        running_mean = torch.randn((C,), dtype=dtype, device=device)
+        running_var = torch.randn((C,), dtype=dtype, device=device)
+        yield inp, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled
+
+
 @pytest.mark.parametrize(
     "op_name, torch_op, input_fn",
     [
@@ -94,6 +113,12 @@ def instancenorm_input_fn(shape, dtype, device):
             instancenorm_input_fn,
             marks=pytest.mark.instance_norm,
         ),
+        pytest.param(
+            "batch_norm",
+            torch.batch_norm,
+            batchnorm_input_fn,
+            marks=pytest.mark.batch_norm,
+        ),
     ],
 )
 def test_group_and_layer_and_instance_norm_benchmark(op_name, torch_op, input_fn):
diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py
index f2ec69f4..08a0ac97 100644
--- a/src/flag_gems/__init__.py
+++ b/src/flag_gems/__init__.py
@@ -24,6 +24,7 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
         ("arange.start_step", arange_start, Autograd.disable),
         ("arange.start", arange_start, Autograd.disable),
         ("arange", arange, Autograd.disable),
+        ("batch_norm", batch_norm, Autograd.enable),
         ("bitwise_and.Tensor", bitwise_and_tensor, Autograd.disable),
         ("bitwise_and.Scalar", bitwise_and_scalar, Autograd.disable),
         ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor, Autograd.disable),
diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index c8b01414..73a5ac26 100755
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -7,6 +7,7 @@
 from .arange import arange, arange_start
 from .argmax import argmax
 from .attention import scaled_dot_product_attention
+from .batch_norm import batch_norm
 from .bitwise_and import (
     bitwise_and_scalar,
     bitwise_and_scalar_tensor,
@@ -139,6 +140,7 @@
     "addmm",
     "arange",
     "arange_start",
+    "batch_norm",
     "bitwise_and_tensor",
     "bitwise_and_scalar",
     "bitwise_and_scalar_tensor",
diff --git a/src/flag_gems/ops/batch_norm.py b/src/flag_gems/ops/batch_norm.py
new file mode 100644
index 00000000..3c975a53
--- /dev/null
+++ b/src/flag_gems/ops/batch_norm.py
@@ -0,0 +1,457 @@
+import logging
+
+import torch
+import triton
+import triton.language as tl
+from torch import Tensor
+
+from .. import runtime
+from ..runtime import torch_device_fn
+from ..utils import libentry, tl_extra_shim
+from ..utils.type_utils import get_accumulator_dtype
+
+rsqrt = tl_extra_shim.rsqrt
+
+
+def make_3d_for_bn(input: Tensor) -> Tensor:
+    """
+    Converts the input to a 3D view for batch normalization.
+
+    Args:
+        input: Input tensor to reshape.
+
+    Returns:
+        Input's 3D view.
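+        For example, an (N, C, H, W) input is flattened to (N, C, H * W),
+        and an (N, C) input becomes (N, C, 1).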
+ """ + if input.ndim == 2: + input = input.unsqueeze(-1) + + elif input.ndim >= 4: + input = input.flatten(2, -1) + + return input + + +# NOTE: This part of the kernel code is copied and modified +# from the https://github.com/BobMcDear/attorch codebase. + + +@libentry() +@triton.autotune( + configs=runtime.get_tuned_config("batch_norm"), + key=["batch_dim", "spatial_dim"], + restore_value=["running_mean_pointer", "running_var_pointer"], +) +@triton.heuristics(runtime.get_heuristic_config("batch_norm")) +@triton.jit +def batch_norm_forward_kernel( + input_pointer, + weight_pointer, + bias_pointer, + mean_pointer, + inv_std_pointer, + output_pointer, + running_mean_pointer, + running_var_pointer, + batch_dim, + spatial_dim, + input_batch_stride, + input_feat_stride, + input_spatial_stride, + output_batch_stride, + output_feat_stride, + output_spatial_stride, + momentum, + eps, + affine: tl.constexpr, + save_stats: tl.constexpr, + is_train: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + feat_pid = tl.program_id(axis=0) + + # traning mode default track_running_stat + if is_train: + mean = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + var = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + cnt = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32) + + m_num_steps = tl.cdiv(batch_dim, BLOCK_M) + n_num_steps = tl.cdiv(spatial_dim, BLOCK_N) + + for m_step in range(0, m_num_steps): + for n_step in range(0, n_num_steps): + spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N) + spatial_mask = spatial_offset < spatial_dim + + batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M) + batch_mask = batch_offset < batch_dim + + curr_input_pointer = ( + input_pointer + + input_feat_stride * feat_pid + + input_batch_stride * batch_offset[:, None] + + input_spatial_stride * spatial_offset[None, :] + ) + + mask = batch_mask[:, None] & spatial_mask[None, :] + curr_input = tl.load(curr_input_pointer, mask=mask).to(tl.float32) + + step = m_step * n_num_steps + n_step + 1 + new_mean = tl.where(mask, mean + (curr_input - mean) / step, mean) + new_var = tl.where( + mask, var + (curr_input - new_mean) * (curr_input - mean), var + ) + cnt += mask.to(tl.int32) + mean = new_mean + var = new_var + + final_mean = tl.sum(mean * cnt) / (batch_dim * spatial_dim) + var = tl.sum(var + cnt * (mean - final_mean) * (mean - final_mean)) / ( + batch_dim * spatial_dim + ) + inv_std = rsqrt(var + eps) + mean = final_mean + + if save_stats: + tl.store(feat_pid + mean_pointer, mean) + tl.store(feat_pid + inv_std_pointer, inv_std) + + running_mean_pointer += feat_pid + running_var_pointer += feat_pid + + running_mean = tl.load(running_mean_pointer) + running_var = tl.load(running_var_pointer) + + n = batch_dim * spatial_dim + tl.store(running_mean_pointer, (1 - momentum) * running_mean + momentum * mean) + tl.store( + running_var_pointer, + (1 - momentum) * running_var + momentum * var * n / (n - 1), + ) + + else: + mean = tl.load(feat_pid + running_mean_pointer) + inv_std = rsqrt(tl.load(feat_pid + running_var_pointer) + eps) + + if affine: + weight = tl.load(feat_pid + weight_pointer) + bias = tl.load(feat_pid + bias_pointer) + + else: + weight = 1.0 + bias = 0.0 + + for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)): + for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)): + batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M) + batch_mask = batch_offset < batch_dim + + spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N) + spatial_mask = spatial_offset < spatial_dim + + curr_input_pointer = ( + 
+        final_mean = tl.sum(mean * cnt) / (batch_dim * spatial_dim)
+        var = tl.sum(var + cnt * (mean - final_mean) * (mean - final_mean)) / (
+            batch_dim * spatial_dim
+        )
+        inv_std = rsqrt(var + eps)
+        mean = final_mean
+
+        if save_stats:
+            tl.store(feat_pid + mean_pointer, mean)
+            tl.store(feat_pid + inv_std_pointer, inv_std)
+
+        running_mean_pointer += feat_pid
+        running_var_pointer += feat_pid
+
+        running_mean = tl.load(running_mean_pointer)
+        running_var = tl.load(running_var_pointer)
+
+        # The running variance uses the unbiased estimate, hence n / (n - 1).
+        n = batch_dim * spatial_dim
+        tl.store(running_mean_pointer, (1 - momentum) * running_mean + momentum * mean)
+        tl.store(
+            running_var_pointer,
+            (1 - momentum) * running_var + momentum * var * n / (n - 1),
+        )
+
+    else:
+        mean = tl.load(feat_pid + running_mean_pointer)
+        inv_std = rsqrt(tl.load(feat_pid + running_var_pointer) + eps)
+
+    if affine:
+        weight = tl.load(feat_pid + weight_pointer)
+        bias = tl.load(feat_pid + bias_pointer)
+
+    else:
+        weight = 1.0
+        bias = 0.0
+
+    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
+        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
+            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
+            batch_mask = batch_offset < batch_dim
+
+            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
+            spatial_mask = spatial_offset < spatial_dim
+
+            curr_input_pointer = (
+                input_pointer
+                + input_feat_stride * feat_pid
+                + input_batch_stride * batch_offset[:, None]
+                + input_spatial_stride * spatial_offset[None, :]
+            )
+            curr_output_pointer = (
+                output_pointer
+                + output_feat_stride * feat_pid
+                + output_batch_stride * batch_offset[:, None]
+                + output_spatial_stride * spatial_offset[None, :]
+            )
+
+            curr_input = tl.load(
+                curr_input_pointer, mask=batch_mask[:, None] & spatial_mask[None, :]
+            ).to(tl.float32)
+            output = weight * (curr_input - mean) * inv_std + bias
+
+            tl.store(
+                curr_output_pointer,
+                output,
+                mask=batch_mask[:, None] & spatial_mask[None, :],
+            )
+
+
+@libentry()
+@triton.autotune(
+    configs=runtime.get_tuned_config("batch_norm"),
+    key=["batch_dim", "spatial_dim"],
+)
+@triton.heuristics(runtime.get_heuristic_config("batch_norm"))
+@triton.jit
+def batch_norm_backward_kernel(
+    output_grad_pointer,
+    input_pointer,
+    mean_pointer,
+    inv_std_pointer,
+    weight_pointer,
+    input_grad_pointer,
+    weight_grad_pointer,
+    bias_grad_pointer,
+    batch_dim,
+    spatial_dim,
+    output_grad_batch_stride,
+    output_grad_feat_stride,
+    output_grad_spatial_stride,
+    input_batch_stride,
+    input_feat_stride,
+    input_spatial_stride,
+    input_grad_batch_stride,
+    input_grad_feat_stride,
+    input_grad_spatial_stride,
+    affine: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    feat_pid = tl.program_id(axis=0)
+
+    mean = tl.load(feat_pid + mean_pointer).to(tl.float32)
+    inv_std = tl.load(feat_pid + inv_std_pointer).to(tl.float32)
+
+    # First pass: reduce term1 = sum(dy * x_hat) and term2 = sum(dy) over
+    # this feature, where x_hat = (x - mean) * inv_std.
+    term1 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+    term2 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+
+    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
+        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
+            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
+            batch_mask = batch_offset < batch_dim
+
+            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
+            spatial_mask = spatial_offset < spatial_dim
+
+            curr_output_grad_pointer = (
+                output_grad_pointer
+                + output_grad_feat_stride * feat_pid
+                + output_grad_batch_stride * batch_offset[:, None]
+                + output_grad_spatial_stride * spatial_offset[None, :]
+            )
+            curr_input_pointer = (
+                input_pointer
+                + input_feat_stride * feat_pid
+                + input_batch_stride * batch_offset[:, None]
+                + input_spatial_stride * spatial_offset[None, :]
+            )
+
+            mask = batch_mask[:, None] & spatial_mask[None, :]
+            curr_input = tl.load(curr_input_pointer, mask=mask).to(tl.float32)
+
+            curr_pre_lin = (curr_input - mean) * inv_std
+            curr_output_grad = tl.load(curr_output_grad_pointer, mask=mask).to(
+                tl.float32
+            )
+
+            term1 += curr_pre_lin * curr_output_grad
+            term2 += curr_output_grad
+
+    term1 = tl.sum(term1)
+    term2 = tl.sum(term2)
+
+    if affine:
+        weight = tl.load(feat_pid + weight_pointer)
+        weight_grad_acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+        bias_grad_acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+
+    else:
+        weight = 1.0
+
+    count = batch_dim * spatial_dim
+
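+    # Second pass: with x_hat = (x - mean) * inv_std, the input gradient is
+    # dx = weight * inv_std * (dy - (term1 * x_hat + term2) / count),
+    # i.e. the standard batch norm backward formula.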
+    for m_step in range(0, tl.cdiv(batch_dim, BLOCK_M)):
+        for n_step in range(0, tl.cdiv(spatial_dim, BLOCK_N)):
+            batch_offset = m_step * BLOCK_M + tl.arange(0, BLOCK_M)
+            batch_mask = batch_offset < batch_dim
+
+            spatial_offset = n_step * BLOCK_N + tl.arange(0, BLOCK_N)
+            spatial_mask = spatial_offset < spatial_dim
+
+            curr_output_grad_pointer = (
+                output_grad_pointer
+                + output_grad_feat_stride * feat_pid
+                + output_grad_batch_stride * batch_offset[:, None]
+                + output_grad_spatial_stride * spatial_offset[None, :]
+            )
+            curr_input_pointer = (
+                input_pointer
+                + input_feat_stride * feat_pid
+                + input_batch_stride * batch_offset[:, None]
+                + input_spatial_stride * spatial_offset[None, :]
+            )
+            curr_input_grad_pointer = (
+                input_grad_pointer
+                + input_grad_feat_stride * feat_pid
+                + input_grad_batch_stride * batch_offset[:, None]
+                + input_grad_spatial_stride * spatial_offset[None, :]
+            )
+
+            curr_input = tl.load(
+                curr_input_pointer, mask=batch_mask[:, None] & spatial_mask[None, :]
+            ).to(tl.float32)
+            curr_pre_lin = (curr_input - mean) * inv_std
+            curr_output_grad = tl.load(
+                curr_output_grad_pointer,
+                mask=batch_mask[:, None] & spatial_mask[None, :],
+            ).to(tl.float32)
+            curr_input_grad = (
+                inv_std
+                * weight
+                * (curr_output_grad - (term1 * curr_pre_lin + term2) / count)
+            )
+            tl.store(
+                curr_input_grad_pointer,
+                curr_input_grad,
+                mask=batch_mask[:, None] & spatial_mask[None, :],
+            )
+
+            if affine:
+                weight_grad_acc += curr_pre_lin * curr_output_grad
+                bias_grad_acc += curr_output_grad
+
+    if affine:
+        tl.store(feat_pid + weight_grad_pointer, tl.sum(weight_grad_acc))
+        tl.store(feat_pid + bias_grad_pointer, tl.sum(bias_grad_acc))
+
+
+class BatchNorm(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        input: Tensor,
+        weight=None,
+        bias=None,
+        running_mean=None,  # self.running_mean if not self.training or self.track_running_stats else None
+        running_var=None,
+        training=False,  # (self.running_mean is None) and (self.running_var is None)
+        momentum=0.1,
+        eps=1e-05,
+        cudnn_enabled=True,
+    ):
+        logging.debug("GEMS BATCHNORM FORWARD")
+
+        input_3d = make_3d_for_bn(input)
+
+        affine = weight is not None and bias is not None
+        requires_grad = (
+            input.requires_grad
+            or (affine and weight.requires_grad)
+            or (affine and bias.requires_grad)
+        )
+
+        batch_dim, feat_dim, spatial_dim = input_3d.shape
+        output = torch.empty_like(input_3d)
+
+        if requires_grad:
+            acc_type = get_accumulator_dtype(input.dtype)
+            mean = torch.empty(feat_dim, device=input.device, dtype=acc_type)
+            inv_std = torch.empty(feat_dim, device=input.device, dtype=acc_type)
+
+        else:
+            mean = inv_std = None
+
+        # When running stats are not tracked, hand the kernel scratch buffers
+        # so that its running-stat loads and stores never touch ``input``.
+        if running_mean is None:
+            running_mean = torch.empty(feat_dim, device=input.device, dtype=input.dtype)
+        if running_var is None:
+            running_var = torch.empty(feat_dim, device=input.device, dtype=input.dtype)
+
+        # Launches a 1D grid where each program operates over one feature.
+        with torch_device_fn.device(input.device):
+            batch_norm_forward_kernel[(feat_dim,)](
+                input_3d,
+                weight,
+                bias,
+                mean,
+                inv_std,
+                output,
+                running_mean,
+                running_var,
+                batch_dim,
+                spatial_dim,
+                *input_3d.stride(),
+                *output.stride(),
+                momentum,
+                eps,
+                affine=affine,
+                save_stats=requires_grad,
+                is_train=training,
+            )
+
+        ctx.affine = affine
+        if requires_grad:
+            ctx.save_for_backward(input, mean, inv_std, weight)
+
+        return output.view_as(input)
+
+    @staticmethod
+    def backward(ctx, output_grad):
+        logging.debug("GEMS BATCHNORM BACKWARD")
+        (input, mean, inv_std, weight) = ctx.saved_tensors
+        input_3d = make_3d_for_bn(input)
+        output_grad_3d = make_3d_for_bn(output_grad)
+
+        batch_dim, feat_dim, spatial_dim = input_3d.shape
+        input_grad = torch.empty_like(input_3d)
+
+        if ctx.affine:
+            weight_grad = torch.empty((feat_dim,), device=input.device)
+            bias_grad = torch.empty_like(weight_grad)
+
+        else:
+            weight_grad = bias_grad = None
+
+        # Launches a 1D grid where each program operates over one feature.
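+        # feat_dim programs run in total; each reduces over every batch and
+        # spatial position of its own feature, so no cross-program reduction
+        # or atomics are needed.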
+        with torch_device_fn.device(input.device):
+            batch_norm_backward_kernel[(feat_dim,)](
+                output_grad_3d,
+                input_3d,
+                mean,
+                inv_std,
+                weight,
+                input_grad,
+                weight_grad,
+                bias_grad,
+                batch_dim,
+                spatial_dim,
+                *output_grad_3d.stride(),
+                *input_3d.stride(),
+                *input_grad.stride(),
+                affine=ctx.affine,
+            )
+
+        # Pad the result with Nones because autograd expects one gradient
+        # (or None) per argument of forward().
+        return (
+            input_grad.view_as(input),
+            weight_grad,
+            bias_grad,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+def batch_norm(
+    input,
+    weight=None,
+    bias=None,
+    running_mean=None,
+    running_var=None,
+    training=False,
+    momentum=0.1,
+    eps=1e-05,
+    cudnn_enabled=True,
+):
+    return BatchNorm.apply(
+        input,
+        weight,
+        bias,
+        running_mean,
+        running_var,
+        training,
+        momentum,
+        eps,
+        cudnn_enabled,
+    )
diff --git a/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py b/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py
index f3a9c482..2c3e3c14 100644
--- a/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py
+++ b/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py
@@ -195,6 +195,17 @@ def upsample_nearest2d_SAME_W(args):
     return args["OW"] == args["IW"]


+def batch_norm_heur_block_m(args):
+    return min(2048, triton.next_power_of_2(args["batch_dim"]))
+
+
+def batch_norm_heur_block_n(args):
+    # At most 2**14 = 16384 elements are loaded per tile.
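+    # E.g. for an (8, 32, 224, 224) input: batch_dim = 8, spatial_dim = 50176,
+    # so BLOCK_M = min(2048, 8) = 8 and BLOCK_N = min(65536, 16384 // 8) = 2048.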
+    BLOCK_M = batch_norm_heur_block_m(args)
+    BLOCK_N = triton.next_power_of_2(args["spatial_dim"])
+    return min(BLOCK_N, max(1, 2**14 // BLOCK_M))
+
+
 HEURISTICS_CONFIGS = {
     "argmax": {
         "BLOCK_M": argmax_heur_block_m,
@@ -262,4 +273,8 @@ def upsample_nearest2d_SAME_W(args):
     "var_mean": {
         "BLOCK_N": var_mean_heur_block_n,
     },
+    "batch_norm": {
+        "BLOCK_M": batch_norm_heur_block_m,
+        "BLOCK_N": batch_norm_heur_block_n,
+    },
 }
diff --git a/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml b/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml
index 6293316a..8e657a03 100644
--- a/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml
+++ b/src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml
@@ -950,3 +950,15 @@ var_mean:
   - 4
   - 8
   - 16
+batch_norm:
+- gen: true
+  param_map:
+    META: {}
+    num_warps: warps
+  warps:
+  - 1
+  - 2
+  - 4
+  - 8
+  - 16
+  - 32
diff --git a/tests/test_norm_ops.py b/tests/test_norm_ops.py
index 6d79bfc6..b1992cf1 100644
--- a/tests/test_norm_ops.py
+++ b/tests/test_norm_ops.py
@@ -1,3 +1,5 @@
+import math
+
 import pytest
 import torch

@@ -469,3 +471,102 @@ def test_accuracy_vectornorm(shape, ord, dim, keepdim, dtype):
     res_out = torch.linalg.vector_norm(inp, ord, dim, keepdim)

     gems_assert_close(res_out, ref_out, dtype)
+
+
+@pytest.mark.batch_norm
+@pytest.mark.parametrize(
+    "shape",
+    [
+        (16, 3),
+        (32, 32, 32),
+        (8, 32, 224, 224),
+        (2050, 16, 32, 32),
+        (8, 16, 3, 224, 224),
+    ],
+)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+@pytest.mark.parametrize("affine", [True, False])
+@pytest.mark.parametrize("require_grad", [True, False])
+def test_accuracy_batch_norm(shape, dtype, affine, require_grad):
+    C = shape[1]
+    inp = torch.randn(
+        size=shape, dtype=dtype, device=flag_gems.device, requires_grad=require_grad
+    )
+    weight = (
+        torch.randn(
+            size=(C,), dtype=dtype, device=flag_gems.device, requires_grad=require_grad
+        )
+        if affine
+        else None
+    )
+    bias = (
+        torch.randn(
+            size=(C,), dtype=dtype, device=flag_gems.device, requires_grad=require_grad
+        )
+        if affine
+        else None
+    )
+
+    running_mean = torch.zeros(size=(C,), dtype=dtype, device=flag_gems.device)
+    running_var = torch.ones(size=(C,), dtype=dtype, device=flag_gems.device)
+
+    eps = 1e-5
+
+    ref_inp = to_reference(inp, True)
+    ref_weight = to_reference(weight, True)
+    ref_bias = to_reference(bias, True)
+    ref_running_mean = to_reference(running_mean, True)
+    ref_running_var = to_reference(running_var, True)
+
+    training = require_grad
+
+    ref_out = torch.nn.functional.batch_norm(
+        ref_inp,
+        ref_running_mean,
+        ref_running_var,
+        weight=ref_weight,
+        bias=ref_bias,
+        training=training,
+        eps=eps,
+    )
+
+    with flag_gems.use_gems():
+        res_out = torch.nn.functional.batch_norm(
+            inp,
+            running_mean,
+            running_var,
+            weight=weight,
+            bias=bias,
+            training=training,
+            eps=eps,
+        )
+
+    gems_assert_close(res_out, ref_out, dtype)
+    gems_assert_close(running_mean, ref_running_mean, dtype)
+    gems_assert_close(running_var, ref_running_var, dtype)
+
+    if not require_grad:
+        return
+
+    out_grad = torch.randn_like(inp)
+    ref_grad = to_reference(out_grad, True)
+    reduce_dim = math.prod(shape) // C
+
+    if affine:
+        (ref_in_grad, ref_weight_grad, ref_bias_grad) = torch.autograd.grad(
+            ref_out, (ref_inp, ref_weight, ref_bias), ref_grad
+        )
+        (res_in_grad, res_weight_grad, res_bias_grad) = torch.autograd.grad(
+            res_out, (inp, weight, bias), out_grad
+        )
+
+        gems_assert_close(res_in_grad, ref_in_grad, dtype, reduce_dim=reduce_dim)
+        gems_assert_close(
+            res_weight_grad, ref_weight_grad, dtype, reduce_dim=reduce_dim
+        )
+        gems_assert_close(res_bias_grad, ref_bias_grad, dtype, reduce_dim=reduce_dim)
+    else:
+        (ref_in_grad,) = torch.autograd.grad(ref_out, (ref_inp,), ref_grad)
+        (res_in_grad,) = torch.autograd.grad(res_out, (inp,), out_grad)
+
+        gems_assert_close(res_in_grad, ref_in_grad, dtype, reduce_dim=reduce_dim)
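
Usage note: a minimal end-to-end sketch of the op added above (shapes are
illustrative; flag_gems.use_gems() and flag_gems.device are the library's
existing entry points, as exercised in the tests):

    import torch
    import flag_gems

    x = torch.randn(8, 16, 32, 32, device=flag_gems.device, requires_grad=True)
    weight = torch.randn(16, device=flag_gems.device, requires_grad=True)
    bias = torch.randn(16, device=flag_gems.device, requires_grad=True)
    running_mean = torch.zeros(16, device=flag_gems.device)
    running_var = torch.ones(16, device=flag_gems.device)

    with flag_gems.use_gems():
        # Dispatches to the Triton-backed batch_norm registered above.
        out = torch.nn.functional.batch_norm(
            x, running_mean, running_var, weight=weight, bias=bias,
            training=True, momentum=0.1, eps=1e-5,
        )
        out.sum().backward()  # exercises batch_norm_backward_kernel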