diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index b12355db110109..864b4a97c2d00a 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -6372,14 +6372,19 @@ def fn(x, y): x = torch.randn((2, 9), dtype=torch.bfloat16) y = torch.randn((2, 9), dtype=torch.bfloat16) - with config.patch({"cpp.simdlen": None}): - torch._dynamo.reset() - metrics.reset() - traced = make_fx(fn)(x, y) - compiled = compile_fx_inner(traced, [x, y]) - assert same(fn(x, y)[0], compiled([x, y])[0], equal_nan=True, tol=1e-2) - if codecache.valid_vec_isa_list(): - assert metrics.generated_cpp_vec_kernel_count == 1 + for torch_compile_debug in [True, False]: + with config.patch( + {"trace.enabled": torch_compile_debug, "cpp.simdlen": None} + ): + torch._dynamo.reset() + metrics.reset() + traced = make_fx(fn)(x, y) + compiled = compile_fx_inner(traced, [x, y]) + assert same( + fn(x, y)[0], compiled([x, y])[0], equal_nan=True, tol=1e-2 + ) + if codecache.valid_vec_isa_list(): + assert metrics.generated_cpp_vec_kernel_count == 1 @unittest.skipIf( not codecache.valid_vec_isa_list(), "Does not support vectorization"