[Operator] Add batch_norm (FlagOpen#362)
* add batch_norm ops

* add batch_norm ops

* add batch_norm forward

* small fix

* add batch_norm ops

* add batch_norm ops

* add unit test

* add batch_norm ops

* add batch_norm perf

* add batch_norm perf

* add note

* add libentry

* fix rsqrt ci error

* update unit tests

* update unit test

* update perf and tune config

* Resolved a portion of the review comments

* resolve runtime call error

* resolve review suggestion
2niuhe authored Jan 9, 2025
1 parent 8f84a44 commit ac17cd1
Showing 7 changed files with 613 additions and 0 deletions.
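
For orientation before the diffs: the Triton kernels themselves live in src/flag_gems/ops/batch_norm.py, one of the changed files whose diff did not load here. A plain-PyTorch sketch of the training-mode batch-norm semantics such a kernel implements; names and structure below are illustrative, not the kernel's:

import torch

def batch_norm_reference(inp, weight, bias, running_mean, running_var,
                         training, momentum, eps):
    # Normalize over every dim except the channel dim (dim 1).
    dims = [d for d in range(inp.dim()) if d != 1]
    if training:
        mean = inp.mean(dim=dims)
        var = inp.var(dim=dims, unbiased=False)  # biased variance normalizes the batch
        if running_mean is not None:
            n = inp.numel() // inp.size(1)
            # Running stats are updated with the unbiased variance estimate.
            running_mean.mul_(1 - momentum).add_(mean, alpha=momentum)
            running_var.mul_(1 - momentum).add_(var * n / (n - 1), alpha=momentum)
    else:
        mean, var = running_mean, running_var
    shape = [1, -1] + [1] * (inp.dim() - 2)
    out = (inp - mean.view(shape)) / torch.sqrt(var.view(shape) + eps)
    if weight is not None:
        out = out * weight.view(shape)
    if bias is not None:
        out = out + bias.view(shape)
    return out
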
25 changes: 25 additions & 0 deletions benchmark/test_norm_perf.py
@@ -73,6 +73,25 @@ def instancenorm_input_fn(shape, dtype, device):
yield inp, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled


def batchnorm_input_fn(shape, dtype, device):
C = shape[1]
inp = torch.randn(shape, dtype=dtype, device=device)
weight = torch.randn((C,), dtype=dtype, device=device)
bias = torch.randn((C,), dtype=dtype, device=device)
running_mean = None
running_var = None
training = True
momentum = 0.1
eps = 1e-5
cudnn_enabled = True
yield inp, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled

if Config.bench_level == BenchLevel.COMPREHENSIVE:
running_mean = torch.randn((C,), dtype=dtype, device=device)
running_var = torch.randn((C,), dtype=dtype, device=device)
yield inp, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled


@pytest.mark.parametrize(
"op_name, torch_op, input_fn",
[
@@ -94,6 +113,12 @@ def instancenorm_input_fn(shape, dtype, device):
instancenorm_input_fn,
marks=pytest.mark.instance_norm,
),
pytest.param(
"batch_norm",
torch.batch_norm,
batchnorm_input_fn,
marks=pytest.mark.batch_norm,
),
],
)
def test_group_and_layer_and_instance_norm_benchmark(op_name, torch_op, input_fn):
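
The benchmark harness presumably unpacks each yielded tuple positionally into the parametrized torch_op; for batch_norm the yield order matches torch.batch_norm's positional signature. A minimal standalone call, assuming a CUDA device:

import torch

inp = torch.randn(16, 32, 8, 8, device="cuda", dtype=torch.float16)
weight = torch.randn(32, device="cuda", dtype=torch.float16)
bias = torch.randn(32, device="cuda", dtype=torch.float16)
# training=True: statistics come from inp itself, so running stats may be None.
out = torch.batch_norm(inp, weight, bias, None, None, True, 0.1, 1e-5, True)
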
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -24,6 +24,7 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
("arange.start_step", arange_start, Autograd.disable),
("arange.start", arange_start, Autograd.disable),
("arange", arange, Autograd.disable),
("batch_norm", batch_norm, Autograd.enable),
("bitwise_and.Tensor", bitwise_and_tensor, Autograd.disable),
("bitwise_and.Scalar", bitwise_and_scalar, Autograd.disable),
("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor, Autograd.disable),
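
Because the op is registered with Autograd.enable, it is wired up with a backward pass as well. A usage sketch once FlagGems is enabled (flag_gems.enable() is the entry point shown in the hunk above); assumes a CUDA device:

import torch
import flag_gems

flag_gems.enable()  # patch aten ops, batch_norm included, with Triton implementations
x = torch.randn(8, 16, 32, 32, device="cuda", requires_grad=True)
bn = torch.nn.BatchNorm2d(16, device="cuda")
y = bn(x)            # dispatches to the registered batch_norm
y.sum().backward()   # gradient flows through the op's autograd function
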
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -7,6 +7,7 @@
from .arange import arange, arange_start
from .argmax import argmax
from .attention import scaled_dot_product_attention
from .batch_norm import batch_norm
from .bitwise_and import (
bitwise_and_scalar,
bitwise_and_scalar_tensor,
@@ -139,6 +140,7 @@
"addmm",
"arange",
"arange_start",
"batch_norm",
"bitwise_and_tensor",
"bitwise_and_scalar",
"bitwise_and_scalar_tensor",
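
With the import and __all__ entry in place, the op is also callable directly from flag_gems.ops without patching aten. A direct-call sketch; the argument order is assumed to follow aten::batch_norm (check batch_norm.py for the exact signature):

import torch
from flag_gems.ops import batch_norm

inp = torch.randn(4, 8, 16, 16, device="cuda")
weight = torch.ones(8, device="cuda")
bias = torch.zeros(8, device="cuda")
# Assumed signature: (input, weight, bias, running_mean, running_var,
#                     training, momentum, eps, cudnn_enabled)
out = batch_norm(inp, weight, bias, None, None, True, 0.1, 1e-5, True)
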
The remaining 4 changed files, including src/flag_gems/ops/batch_norm.py, did not load.
