[Inductor] Enable CppWrapper to support BF16 (pytorch#97089)
EikanWang authored and pytorchmergebot committed Mar 22, 2023
1 parent 573b2de commit 517a432
Showing 3 changed files with 31 additions and 31 deletions.
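What this enables, in one picture: with cpp_wrapper on, Inductor emits a C++ entry point for the compiled graph, and after this commit that path accepts bfloat16. A minimal usage sketch (the function and shapes are illustrative; the config flag and optimize() call mirror the tests below):

    import torch
    import torch._dynamo
    from torch._inductor import config

    def fn(x):
        return torch.softmax(x, dim=-1)

    x = torch.randn(2, 17, dtype=torch.bfloat16)

    # Compile with the Inductor C++ wrapper enabled; before this commit,
    # bfloat16 graphs were rejected by supported_dtype_of_cpp_wrapper.
    with config.patch({"cpp_wrapper": True}):
        opt_fn = torch._dynamo.optimize("inductor")(fn)
        assert torch.allclose(opt_fn(x).float(), fn(x).float(), atol=1e-2)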
36 changes: 21 additions & 15 deletions test/inductor/test_torchinductor.py
@@ -6125,15 +6125,17 @@ def fn(value, mask):
             value = torch.randn((2, 17), dtype=dtype)
             mask = torch.randint(0, 1, size=(2, 17), dtype=torch.uint8)
             with config.patch({"cpp.simdlen": None}):
-                torch._dynamo.reset()
-                metrics.reset()
-                opt_fn = torch._dynamo.optimize("inductor")(fn)
-                opt_fn(value, mask)
-
-                real_out = fn(value, mask)
-                compiled_out = opt_fn(value, mask)
-                assert same(real_out, compiled_out, equal_nan=True)
-                assert metrics.generated_cpp_vec_kernel_count >= 1
+                for cpp_wrapper_flag in [True, False]:
+                    with config.patch({"cpp_wrapper": cpp_wrapper_flag}):
+                        torch._dynamo.reset()
+                        metrics.reset()
+                        opt_fn = torch._dynamo.optimize("inductor")(fn)
+                        opt_fn(value, mask)
+
+                        real_out = fn(value, mask)
+                        compiled_out = opt_fn(value, mask)
+                        assert same(real_out, compiled_out, equal_nan=True)
+                        assert metrics.generated_cpp_vec_kernel_count >= 1

     def test_load_same_bool_tensor_twice(self):
         @torch._dynamo.optimize("inductor")
@@ -6229,12 +6231,16 @@ def fn(x):
             tol = 1e-2 if dtype == torch.bfloat16 else 1e-4

             with config.patch({"cpp.simdlen": None}):
-                torch._dynamo.reset()
-                metrics.reset()
-                traced = make_fx(fn)(x)
-                compiled = compile_fx_inner(traced, [x])
-                assert same(fn(x)[0], compiled([x])[0], equal_nan=True, tol=tol)
-                assert metrics.generated_cpp_vec_kernel_count == 1
+                for cpp_wrapper_flag in [True, False]:
+                    with config.patch({"cpp_wrapper": cpp_wrapper_flag}):
+                        torch._dynamo.reset()
+                        metrics.reset()
+                        traced = make_fx(fn)(x)
+                        compiled = compile_fx_inner(traced, [x])
+                        assert same(
+                            fn(x)[0], compiled([x])[0], equal_nan=True, tol=tol
+                        )
+                        assert metrics.generated_cpp_vec_kernel_count == 1

     @unittest.skipIf(
         not codecache.valid_vec_isa_list(), "Does not support vectorization"
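This second test drives Inductor below the dynamo layer. For reference, a minimal sketch of that lower-level path under the new flag, using the imports the test file relies on (make_fx from torch.fx.experimental.proxy_tensor, compile_fx_inner from torch._inductor.compile_fx); the function and input are illustrative:

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._inductor import config
    from torch._inductor.compile_fx import compile_fx_inner

    def fn(x):
        return (torch.relu(x) + 1.0,)  # compile_fx_inner works on tuple outputs

    x = torch.randn(64, dtype=torch.bfloat16)
    with config.patch({"cpp_wrapper": True}):
        traced = make_fx(fn)(x)                   # trace to an FX graph
        compiled = compile_fx_inner(traced, [x])  # lower through Inductor
        out = compiled([x])[0]                    # compiled callable takes a list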
24 changes: 9 additions & 15 deletions torch/_inductor/codegen/wrapper.py
@@ -704,6 +704,8 @@ def write_prefix(self):
                 #include <dlfcn.h>
                 #include <assert.h>

+                typedef at::BFloat16 bfloat16;
+
                 template <typename KernelFunc>
                 KernelFunc load_cpp_kernel(const char* so_filename) {
                     KernelFunc kernel_cpp;
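The new typedef is needed because generated kernel code spells the BF16 scalar type as plain bfloat16, while the ATen type the wrapper compiles against is at::BFloat16. A small sketch of that correspondence, assuming the DTYPE_TO_CPP mapping in torch/_inductor/codegen/cpp.py from this era of the codebase:

    import torch
    from torch._inductor.codegen.cpp import DTYPE_TO_CPP

    # The cpp codegen names BF16 values "bfloat16" in kernel signatures...
    print(DTYPE_TO_CPP[torch.bfloat16])  # bfloat16
    # ...and the typedef added above makes that name resolve to at::BFloat16
    # when the wrapper source is compiled.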
@@ -720,10 +722,7 @@ def write_wrapper_decl(self):
         inputs_len = len(V.graph.graph_inputs.keys())
         output_refs = self.get_output_refs()
         if output_refs:
-            if len(output_refs) == 1:
-                output_types = "at::Tensor"
-            else:
-                output_types = "std::vector<at::Tensor>"
+            output_types = "std::vector<at::Tensor>"
         else:
             output_types = "void"

@@ -785,17 +784,12 @@ def return_end_str(self):

     def generate_return(self, output_refs):
         if output_refs:
-            if len(output_refs) == 1:
-                self.wrapper_call.writeline(
-                    f"return {output_refs[0]};{self.return_end_str()}"
-                )
-            else:
-                self.wrapper_call.writeline(
-                    "return std::vector<at::Tensor>({"
-                    + ", ".join(output_refs)
-                    + "});"
-                    + self.return_end_str()
-                )
+            self.wrapper_call.writeline(
+                "return std::vector<at::Tensor>({"
+                + ", ".join(output_refs)
+                + "});"
+                + self.return_end_str()
+            )
         else:
             self.wrapper_call.writeline(f"return;{self.return_end_str()}")

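Together with the write_wrapper_decl change above, the entry point now always returns std::vector<at::Tensor> when the graph has outputs, so generate_return needs only one form; a single-output graph simply returns a one-element vector. A hypothetical illustration of the C++ line it writes for two outputs named buf0 and buf1:

    output_refs = ["buf0", "buf1"]
    line = "return std::vector<at::Tensor>({" + ", ".join(output_refs) + "});"
    print(line)  # return std::vector<at::Tensor>({buf0, buf1});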
2 changes: 1 addition & 1 deletion torch/_inductor/graph.py
@@ -61,8 +61,8 @@ def supported_dtype_of_cpp_wrapper(dtype):
         torch.int8,
         torch.uint8,
         torch.bool,
+        torch.bfloat16,
         # torch.float16, # TODO: implement this
-        # torch.bfloat16, # TODO: implement this
     }
     return dtype in supported_dtype

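With bfloat16 moved into the allowlist, BF16 graphs pass this gate and are lowered through the C++ wrapper, while float16 remains a TODO. A quick sketch of the resulting behavior, assuming the module-level function shown above:

    import torch
    from torch._inductor.graph import supported_dtype_of_cpp_wrapper

    assert supported_dtype_of_cpp_wrapper(torch.bfloat16)      # newly supported
    assert not supported_dtype_of_cpp_wrapper(torch.float16)   # still a TODO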
