From 036012ff2b93bc4135e2b7695b5233e210dcb0d7 Mon Sep 17 00:00:00 2001
From: Adam Mainz
Date: Wed, 23 Oct 2024 14:22:52 -0700
Subject: [PATCH] adding fp32 strict and tf32x3 benchmarks for gemm

Summary: TSIA. Draft diff while I move this to its own op.

Reviewed By: danzimm

Differential Revision: D64781204

fbshipit-source-id: c3ddd956230c1e4c8166867f03b5a28e8d6586e9
---
 torchbenchmark/operators/gemm/operator.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/torchbenchmark/operators/gemm/operator.py b/torchbenchmark/operators/gemm/operator.py
index 199608d64..23b5c65c3 100644
--- a/torchbenchmark/operators/gemm/operator.py
+++ b/torchbenchmark/operators/gemm/operator.py
@@ -184,6 +184,24 @@ def triton_ops_matmul(self, a, b, bias) -> Callable:
             return lambda: kernels.matmul(a, b)
         return lambda: kernels.matmul(a, b) + bias
 
+    @register_benchmark(enabled=False, ci=False)
+    def triton_ops_with_tf32x3(self, a, b, bias) -> Callable:
+        input_precision = None
+        if self.dtype == torch.float32:
+            input_precision = "tf32x3"
+        if bias is None:
+            return lambda: kernels.matmul(a, b, None, input_precision)
+        return lambda: kernels.matmul(a, b, None, input_precision) + bias
+
+    @register_benchmark(enabled=False, ci=False)
+    def triton_ops_with_fp32_strict(self, a, b, bias) -> Callable:
+        input_precision = None
+        if self.dtype == torch.float32:
+            input_precision = "ieee"
+        if bias is None:
+            return lambda: kernels.matmul(a, b, self.dtype, input_precision)
+        return lambda: kernels.matmul(a, b, self.dtype, input_precision) + bias
+
     @register_benchmark(baseline=True)
     def aten_matmul(self, a, b, bias) -> Callable:
         if not bias == None:
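
For context on what the two new input_precision settings trade off: "ieee" forces strict IEEE-754 fp32 multiply-accumulate, while "tf32x3" emulates fp32 accuracy by splitting each fp32 operand and issuing three TF32 tensor-core matmuls, recovering most of the precision at a fraction of the strict-fp32 cost. The sketch below is illustrative only and not part of the patch; it uses PyTorch's public torch.set_float32_matmul_precision knob, which is analogous to (but separate from) the kernel-level input_precision argument the benchmarks pass, to show how relaxed fp32 matmul modes change numerical error against an fp64 reference.

    # Illustrative sketch only -- not part of the patch. PyTorch's public
    # fp32-matmul precision knob plays a role similar to the kernel-level
    # `input_precision` argument above ("ieee" ~ "highest").
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.randn(2048, 2048, device=device, dtype=torch.float32)
    b = torch.randn(2048, 2048, device=device, dtype=torch.float32)
    ref = (a.double() @ b.double()).float()  # fp64 reference for error measurement

    for mode in ("highest", "high", "medium"):
        # "highest": strict IEEE fp32 accumulation; "high": TF32 (or split-fp32
        # emulation) allowed; "medium": bf16 allowed. Only affects fp32 matmuls.
        torch.set_float32_matmul_precision(mode)
        err = ((a @ b) - ref).abs().max().item()
        print(f"{mode}: max abs error vs fp64 = {err:.3e}")

Note that both new variants are registered with enabled=False and ci=False, so they are skipped by default and would have to be selected explicitly when running the gemm operator benchmark (e.g. via the harness's benchmark filter, if available in your checkout).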