From 517a432d6e2feea29b08e8dca1b223c61c96ac00 Mon Sep 17 00:00:00 2001
From: "Wang, Eikan"
Date: Tue, 21 Mar 2023 14:18:21 +0000
Subject: [PATCH] [Inductor] Enable CppWrapper to support BF16 (#97089)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97089
Approved by: https://github.com/jgong5, https://github.com/jansel
---
 test/inductor/test_torchinductor.py | 36 +++++++++++++++++------------
 torch/_inductor/codegen/wrapper.py  | 24 ++++++++-----------
 torch/_inductor/graph.py            |  2 +-
 3 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 864b4a97c2d00..1147942257de1 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -6125,15 +6125,17 @@ def fn(value, mask):
         value = torch.randn((2, 17), dtype=dtype)
         mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8)
         with config.patch({"cpp.simdlen": None}):
-            torch._dynamo.reset()
-            metrics.reset()
-            opt_fn = torch._dynamo.optimize("inductor")(fn)
-            opt_fn(value, mask)
-
-            real_out = fn(value, mask)
-            compiled_out = opt_fn(value, mask)
-            assert same(real_out, compiled_out, equal_nan=True)
-            assert metrics.generated_cpp_vec_kernel_count >= 1
+            for cpp_wrapper_flag in [True, False]:
+                with config.patch({"cpp_wrapper": cpp_wrapper_flag}):
+                    torch._dynamo.reset()
+                    metrics.reset()
+                    opt_fn = torch._dynamo.optimize("inductor")(fn)
+                    opt_fn(value, mask)
+
+                    real_out = fn(value, mask)
+                    compiled_out = opt_fn(value, mask)
+                    assert same(real_out, compiled_out, equal_nan=True)
+                    assert metrics.generated_cpp_vec_kernel_count >= 1
 
     def test_load_same_bool_tensor_twice(self):
         @torch._dynamo.optimize("inductor")
@@ -6229,12 +6231,16 @@ def fn(x):
             tol = 1e-2 if dtype == torch.bfloat16 else 1e-4
 
             with config.patch({"cpp.simdlen": None}):
-                torch._dynamo.reset()
-                metrics.reset()
-                traced = make_fx(fn)(x)
-                compiled = compile_fx_inner(traced, [x])
-                assert same(fn(x)[0], compiled([x])[0], equal_nan=True, tol=tol)
-                assert metrics.generated_cpp_vec_kernel_count == 1
+                for cpp_wrapper_flag in [True, False]:
+                    with config.patch({"cpp_wrapper": cpp_wrapper_flag}):
+                        torch._dynamo.reset()
+                        metrics.reset()
+                        traced = make_fx(fn)(x)
+                        compiled = compile_fx_inner(traced, [x])
+                        assert same(
+                            fn(x)[0], compiled([x])[0], equal_nan=True, tol=tol
+                        )
+                        assert metrics.generated_cpp_vec_kernel_count == 1
 
     @unittest.skipIf(
         not codecache.valid_vec_isa_list(), "Does not support vectorization"
diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index 776364f4fe9d8..dec9f62500804 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -704,6 +704,8 @@ def write_prefix(self):
                 #include <dlfcn.h>
                 #include <assert.h>
 
+                typedef at::BFloat16 bfloat16;
+
                 template <typename KernelFunc>
                 KernelFunc load_cpp_kernel(const char* so_filename) {
                     KernelFunc kernel_cpp;
@@ -720,10 +722,7 @@ def write_wrapper_decl(self):
         inputs_len = len(V.graph.graph_inputs.keys())
         output_refs = self.get_output_refs()
         if output_refs:
-            if len(output_refs) == 1:
-                output_types = "at::Tensor"
-            else:
-                output_types = "std::vector<at::Tensor>"
+            output_types = "std::vector<at::Tensor>"
         else:
             output_types = "void"
 
@@ -785,17 +784,12 @@ def return_end_str(self):
 
     def generate_return(self, output_refs):
         if output_refs:
-            if len(output_refs) == 1:
-                self.wrapper_call.writeline(
-                    f"return {output_refs[0]};{self.return_end_str()}"
-                )
-            else:
-                self.wrapper_call.writeline(
-                    "return std::vector<at::Tensor>({"
-                    + ", ".join(output_refs)
-                    + "});"
-                    + self.return_end_str()
-                )
+            self.wrapper_call.writeline(
+                "return std::vector<at::Tensor>({"
+                + ", ".join(output_refs)
+                + "});"
+                + self.return_end_str()
+            )
         else:
             self.wrapper_call.writeline(f"return;{self.return_end_str()}")
 
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 7f4563f00a6ca..737bf42855a84 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -61,8 +61,8 @@ def supported_dtype_of_cpp_wrapper(dtype):
         torch.int8,
         torch.uint8,
         torch.bool,
+        torch.bfloat16,
         # torch.float16, # TODO: implement this
-        # torch.bfloat16, # TODO: implement this
     }
     return dtype in supported_dtype
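
Note (not part of the patch): below is a minimal standalone sketch of the
usage pattern the new tests exercise, i.e. compiling a bf16 function on CPU
with the C++ wrapper codegen enabled. The toy function, input shape, and
tolerances are illustrative assumptions; the config knobs and entry points
are the ones used in the tests above, assuming a PyTorch tree this patch
applies to.

import torch
import torch._dynamo
from torch._inductor import config, metrics

def fn(x):
    # Hypothetical bf16 workload; any simple pointwise graph works here.
    return torch.relu(x) + 1.0

x = torch.randn(2, 17, dtype=torch.bfloat16)

with config.patch({"cpp_wrapper": True, "cpp.simdlen": None}):
    torch._dynamo.reset()
    metrics.reset()
    opt_fn = torch._dynamo.optimize("inductor")(fn)
    # Before this patch, bfloat16 was rejected by
    # supported_dtype_of_cpp_wrapper(); with it, the generated C++ wrapper
    # returns bf16 outputs as std::vector<at::Tensor>.
    assert torch.allclose(fn(x), opt_fn(x), atol=1e-2, rtol=1e-2)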