From e755b1e580f6ecadb50d53dafd410d1cfed694d1 Mon Sep 17 00:00:00 2001 From: Feng Yuan Date: Mon, 25 Mar 2024 19:34:28 +0800 Subject: [PATCH] Align with PyTorch test_ops case suites (#50) Signed-off-by: Feng Yuan Co-authored-by: Yutao Xu Co-authored-by: chuanqiw --- .github/ci_commit_pins/pytorch.txt | 1 - .github/workflows/pull.yml | 9 +- .../examples => examples}/test_binary.py | 6 +- .../examples => examples}/test_compare.py | 12 +- .../python/examples => examples}/test_copy.py | 0 .../examples => examples}/test_loops.py | 6 +- .../python/examples => examples}/test_rand.py | 2 +- .../examples => examples}/test_resize.py | 4 +- .../test_tensor_factory.py | 0 .../examples => examples}/test_unary.py | 18 +- src/aten/BinaryOps.cpp | 8 +- test/xpu/test_ops.py | 391 ++++++++++++++++++ 12 files changed, 424 insertions(+), 33 deletions(-) delete mode 100644 .github/ci_commit_pins/pytorch.txt rename {test/python/examples => examples}/test_binary.py (99%) rename {test/python/examples => examples}/test_compare.py (97%) rename {test/python/examples => examples}/test_copy.py (100%) rename {test/python/examples => examples}/test_loops.py (99%) rename {test/python/examples => examples}/test_rand.py (99%) rename {test/python/examples => examples}/test_resize.py (99%) rename {test/python/examples => examples}/test_tensor_factory.py (100%) rename {test/python/examples => examples}/test_unary.py (98%) create mode 100644 test/xpu/test_ops.py diff --git a/.github/ci_commit_pins/pytorch.txt b/.github/ci_commit_pins/pytorch.txt deleted file mode 100644 index 5f1f5db0c..000000000 --- a/.github/ci_commit_pins/pytorch.txt +++ /dev/null @@ -1 +0,0 @@ -3600778edeb1ac4eefe36cedb6facf02c58ce0d4 diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 8b0b41112..c33c62db9 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -26,8 +26,8 @@ jobs: run: | pwd cd ../ && rm -rf pytorch - git clone https://github.com/pytorch/pytorch - cd pytorch && git checkout `cat ../torch-xpu-ops/.github/ci_commit_pins/pytorch.txt` && git submodule sync && git submodule update --init --recursive + git clone -b nightly https://github.com/pytorch/pytorch + cd pytorch && git log -n 1 && git submodule sync && git submodule update --init --recursive rm -rf third_party/torch-xpu-ops && cp -r ../torch-xpu-ops third_party/ # Workaround for torch-xpu-ops ci test sed -i "s/checkout --quiet \${TORCH_XPU_OPS_COMMIT}/log -n 1/g" caffe2/CMakeLists.txt @@ -52,11 +52,12 @@ jobs: cd examples timeout 8000 pytest -v 2>&1 | tee torch_xpu_ops_example.log - name: Run XPU OP UT - if: ${{ hashFiles('test/python/') != '' }} + if: ${{ hashFiles('test/xpu/') != '' }} run: | source /opt/intel/oneapi/compiler/latest/env/vars.sh source activate xpu_op_${ZE_AFFINITY_MASK} - cd test/python + export PYTORCH_ENABLE_XPU_FALLBACK=1 + cd test/xpu timeout 8000 pytest -v 2>&1 | tee torch_xpu_ops_ut.log - name: Run Torch XPU UT run: | diff --git a/test/python/examples/test_binary.py b/examples/test_binary.py similarity index 99% rename from test/python/examples/test_binary.py rename to examples/test_binary.py index fc21d94f7..5ac332c6d 100644 --- a/test/python/examples/test_binary.py +++ b/examples/test_binary.py @@ -59,7 +59,7 @@ def test_div_int(self, dtype=torch.float): c_cpu = a_cpu / b_cpu self.assertEqual(c_cpu.dtype, c_xpu.dtype) # assume float self.assertEqual(c_cpu, c_xpu.to(cpu_device)) - + def test_binary_div_channels_last(self, dtype=torch.float): shapes = [ (1, 2, 3, 4), @@ -222,7 +222,7 @@ def 
test_binary_div_channels_last(self, dtype=torch.float): self.assertEqual( a_xpu.is_contiguous(memory_format=torch.channels_last), False ) - + def test_pow(self, dtype=torch.float): x_cpu = torch.tensor(([2.5, 3.1, 1.3]), dtype=torch.float, device=cpu_device) x_xpu = torch.tensor( @@ -235,7 +235,7 @@ def test_pow(self, dtype=torch.float): self.assertEqual(torch.pow(x_cpu, y_cpu), torch.pow(x_xpu, y_xpu).cpu()) self.assertEqual(x_cpu.pow(y_cpu), x_xpu.pow(y_xpu).cpu()) self.assertEqual(x_cpu.pow_(y_cpu), x_xpu.pow_(y_xpu).cpu()) - + def test_binary_op(self, dtype=torch.float): x_cpu = torch.randn(5) diff --git a/test/python/examples/test_compare.py b/examples/test_compare.py similarity index 97% rename from test/python/examples/test_compare.py rename to examples/test_compare.py index 7a8569fcc..1522426cb 100644 --- a/test/python/examples/test_compare.py +++ b/examples/test_compare.py @@ -26,21 +26,21 @@ def _test_compare_fn(self, fn, dtype): y_xpu.zero_() fn(x1_xpu, x2_xpu, out=y_xpu) self.assertEqual(y_xpu.cpu(), y) - + def test_eq(self, dtype=torch.float): self._test_compare_fn(torch.eq, dtype) - + def test_ne(self, dtype=torch.float): self._test_compare_fn(torch.ne, dtype) - + def test_lt(self, dtype=torch.float): self._test_compare_fn(torch.lt, dtype) - + def test_le(self, dtype=torch.float): self._test_compare_fn(torch.le, dtype) - + def test_gt(self, dtype=torch.float): self._test_compare_fn(torch.gt, dtype) - + def test_ge(self, dtype=torch.float): self._test_compare_fn(torch.ge, dtype) diff --git a/test/python/examples/test_copy.py b/examples/test_copy.py similarity index 100% rename from test/python/examples/test_copy.py rename to examples/test_copy.py diff --git a/test/python/examples/test_loops.py b/examples/test_loops.py similarity index 99% rename from test/python/examples/test_loops.py rename to examples/test_loops.py index f06292575..05cd8347d 100644 --- a/test/python/examples/test_loops.py +++ b/examples/test_loops.py @@ -34,13 +34,13 @@ def _test_loops(self, dtype=torch.float): c = a + b + 1 c_xpu = a_xpu + b_xpu + 1 self.assertEqual(c, c_xpu.cpu()) - + def test_loops_float(self): self._test_loops(torch.float) - + def test_loops_half(self): self._test_loops(torch.half) - + def test_loops_bfloat16(self): self._test_loops(torch.bfloat16) diff --git a/test/python/examples/test_rand.py b/examples/test_rand.py similarity index 99% rename from test/python/examples/test_rand.py rename to examples/test_rand.py index 157f76702..7dfd5efdb 100644 --- a/test/python/examples/test_rand.py +++ b/examples/test_rand.py @@ -19,7 +19,7 @@ def test_distribution_normal(self, dtype=torch.float): torch.normal(mean=-3.0, std=1.2, size=(10000,), device=xpu_device, dtype=dtype, out=x_xpu) self.assertEqual(x_xpu.cpu().mean(), -3.0, rtol=tol, atol=tol) self.assertEqual(x_xpu.cpu().std(), 1.2, rtol=tol, atol=tol) - + def test_distribution_uniform(self, dtype=torch.float): tol = 1e-2 x_xpu = torch.tensor(list(range(10000)), device=xpu_device, dtype=dtype) diff --git a/test/python/examples/test_resize.py b/examples/test_resize.py similarity index 99% rename from test/python/examples/test_resize.py rename to examples/test_resize.py index 03f3e6e3e..7622d6c65 100644 --- a/test/python/examples/test_resize.py +++ b/examples/test_resize.py @@ -28,14 +28,14 @@ def test_view(self, dtype=torch.float): assert b_xpu.shape[2] == 2 self.assertEqual(c_cpu, b_xpu.to(cpu_device)) - + def test_view_as_real(self, dtype=torch.cfloat): a_cpu = torch.randn(2, 3, 4, dtype=dtype) a_xpu = a_cpu.to(xpu_device) b_cpu = 
torch.view_as_real(a_cpu) b_xpu = torch.view_as_real(a_xpu) self.assertEqual(b_cpu, b_xpu.to(cpu_device)) - + def test_view_as_complex(self, dtype=torch.float): a_cpu = torch.randn(109, 2, dtype=dtype) a_xpu = a_cpu.to(xpu_device) diff --git a/test/python/examples/test_tensor_factory.py b/examples/test_tensor_factory.py similarity index 100% rename from test/python/examples/test_tensor_factory.py rename to examples/test_tensor_factory.py diff --git a/test/python/examples/test_unary.py b/examples/test_unary.py similarity index 98% rename from test/python/examples/test_unary.py rename to examples/test_unary.py index e8864263f..996090f73 100644 --- a/test/python/examples/test_unary.py +++ b/examples/test_unary.py @@ -14,7 +14,7 @@ class Dtypes(object): def __init__(self, include_dtypes, exclude_dtypes=[]): self.include_dtypes = include_dtypes self.exclude_dtypes = exclude_dtypes - + def __call__(self, fn): def fn_out(*args, **kwargs): for dtype in self.include_dtypes: @@ -38,23 +38,23 @@ def _test_unary_out_ops(self, fn_str, dtype): d_cpu = eval(f"torch.{fn_str}(a_cpu, out=c_cpu)") d_xpu = eval(f"torch.{fn_str}(a_xpu, out=c_xpu)") self.assertEqual(c_cpu, c_xpu.cpu(), atol=1e-4, rtol=1e-4) - + @Dtypes(floating_types) def test_abs_out(self, dtype): self._test_unary_out_ops('abs', dtype) - + @Dtypes(floating_and_complex_types) def test_sin_out(self, dtype): self._test_unary_out_ops('sin', dtype) - + @Dtypes(floating_and_complex_types) def test_cos_out(self, dtype): self._test_unary_out_ops('cos', dtype) - + @Dtypes(floating_and_complex_types) def test_log_out(self, dtype): self._test_unary_out_ops('log', dtype) - + @Dtypes(floating_and_complex_types) def test_sqrt_out(self, dtype): self._test_unary_out_ops('sqrt', dtype) @@ -62,15 +62,15 @@ def test_sqrt_out(self, dtype): @Dtypes(floating_and_complex_types) def test_rsqrt_out(self, dtype): self._test_unary_out_ops('rsqrt', dtype) - + @Dtypes(floating_and_complex_types) def test_tanh_out(self, dtype): self._test_unary_out_ops('tanh', dtype) - + @Dtypes(all_basic_and_complex_types, [torch.bool]) def test_neg_out(self, dtype): self._test_unary_out_ops('neg', dtype) - + @Dtypes(floating_and_complex_types) def test_reciprocal_out(self, dtype): self._test_unary_out_ops('reciprocal', dtype) diff --git a/src/aten/BinaryOps.cpp b/src/aten/BinaryOps.cpp index 3f7372aba..503a1eddf 100644 --- a/src/aten/BinaryOps.cpp +++ b/src/aten/BinaryOps.cpp @@ -210,7 +210,7 @@ Tensor XPUNativeFunctions::rsub( const Tensor& self, const Tensor& other, const Scalar& alpha) { - return XPUNativeFunctions::sub(self, other, alpha); + return XPUNativeFunctions::sub(other, self, alpha); } Tensor& XPUNativeFunctions::rsub_out( @@ -218,7 +218,7 @@ Tensor& XPUNativeFunctions::rsub_out( const Tensor& other, const Scalar& alpha, Tensor& out) { - return XPUNativeFunctions::sub_out(self, other, alpha, out); + return XPUNativeFunctions::sub_out(other, self, alpha, out); } Tensor XPUNativeFunctions::rsub( @@ -226,7 +226,7 @@ Tensor XPUNativeFunctions::rsub( const Scalar& other, const Scalar& alpha) { return XPUNativeFunctions::sub( - self, native::wrapped_scalar_tensor(other), alpha); + native::wrapped_scalar_tensor(other), self, alpha); } Tensor& XPUNativeFunctions::rsub_out( @@ -235,7 +235,7 @@ Tensor& XPUNativeFunctions::rsub_out( const Scalar& alpha, Tensor& out) { return XPUNativeFunctions::sub_out( - self, native::wrapped_scalar_tensor(other), alpha, out); + native::wrapped_scalar_tensor(other), self, alpha, out); } Tensor XPUNativeFunctions::remainder(const Tensor& self, const 
Tensor& other) { diff --git a/test/xpu/test_ops.py b/test/xpu/test_ops.py new file mode 100644 index 000000000..ae6bb73eb --- /dev/null +++ b/test/xpu/test_ops.py @@ -0,0 +1,391 @@ +# Owner(s): ["module: intel"] + +import sys +import unittest + +import torch +from torch.testing._internal.common_dtype import ( + floating_and_complex_types_and, + all_types_and_complex_and, + integral_types_and, +) +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, + onlyXPU, + OpDTypes, + ops, +) +from torch.testing._internal.common_methods_invocations import ( + ops_and_refs, + python_ref_db, +) +from torch.testing._internal.common_utils import ( + NoTest, + run_tests, + suppress_warnings, + TEST_WITH_UBSAN, + TEST_XPU, + TestCase, +) + +if not TEST_XPU: + print("XPU not available, skipping tests", file=sys.stderr) + TestCase = NoTest # noqa: F811 + +any_common_cpu_xpu_one = OpDTypes.any_common_cpu_cuda_one +_xpu_computation_op_list = [ + "fill", + "zeros", + "zeros_like", + "clone", + "view_as_real", + "view_as_complex", + "view", + "resize_", + "resize_as_", + "add", + "sub", + "mul", + "div", + "abs", + "rsub", + "remainder", + "fmod", + "eq", + "ne", + "lt", + "le", + "gt", + "ge", + "sin", + "cos", + "log", + "sqrt", + "rsqrt", + "tanh", + "neg", + "reciprocal", + "pow", + "unfold", +] +_xpu_tensor_factory_op_list = [ + "normal", + "uniform", + "as_strided", + "empty", + "empty_strided", +] +_xpu_all_op_list = _xpu_computation_op_list + _xpu_tensor_factory_op_list +_xpu_all_ops = [op for op in ops_and_refs if op.name in _xpu_all_op_list] + +# Exclusive ops list +_xpu_not_test_dtype_op_list = [ + "resize_", # Skipped by CPU + "resize_as_", # Skipped by CPU + "abs", # Not aligned dtype +] +_xpu_float_only_op_list = [ + "reciprocal", # Align with CUDA impl. Only float and complex supported in CUDA native. 
+] + +# test_compare_cpu +_xpu_computation_ops = [ + op for op in ops_and_refs if op.name in _xpu_computation_op_list +] + +# test_non_standard_bool_values +_xpu_non_standard_bool_values_op_list = _xpu_computation_op_list.copy() +for op in _xpu_float_only_op_list: + _xpu_non_standard_bool_values_op_list.remove(op) +_xpu_non_standard_bool_values_ops = [op for op in ops_and_refs if op.name in _xpu_non_standard_bool_values_op_list] + +# test_dtypes +_xpu_dtype_op_list = _xpu_all_op_list.copy() +for op in _xpu_not_test_dtype_op_list: + _xpu_dtype_op_list.remove(op) +for op in _xpu_float_only_op_list: + _xpu_dtype_op_list.remove(op) +_xpu_dtype_ops = [op for op in ops_and_refs if op.name in _xpu_dtype_op_list] + +# test_promotes_int_to_float +_xpu_promotes_int_to_float_list = _xpu_all_op_list.copy() +for op in _xpu_float_only_op_list: + _xpu_promotes_int_to_float_list.remove(op) +_xpu_promotes_int_to_float_ops = [op for op in ops_and_refs if op.name in _xpu_promotes_int_to_float_list] + + +class TestXpu(TestCase): + + @onlyXPU + @suppress_warnings + @ops(_xpu_computation_ops, dtypes=any_common_cpu_xpu_one) + def test_compare_cpu(self, device, dtype, op): + def to_cpu(arg): + if isinstance(arg, torch.Tensor): + return arg.to(device="cpu") + return arg + + samples = op.reference_inputs(device, dtype) + + for sample in samples: + cpu_sample = sample.transform(to_cpu) + xpu_results = op(sample.input, *sample.args, **sample.kwargs) + cpu_results = op(cpu_sample.input, *cpu_sample.args, **cpu_sample.kwargs) + + xpu_results = sample.output_process_fn_grad(xpu_results) + cpu_results = cpu_sample.output_process_fn_grad(cpu_results) + + # Lower tolerance because we are running this as a `@slowTest` + # Don't want the periodic tests to fail frequently + self.assertEqual(xpu_results, cpu_results, atol=1e-4, rtol=1e-4) + + @onlyXPU + @ops(_xpu_non_standard_bool_values_ops, allowed_dtypes=(torch.bool,)) + @unittest.skipIf(TEST_WITH_UBSAN, "Test uses undefined behavior") + def test_non_standard_bool_values(self, device, dtype, op): + # Test boolean values other than 0x00 and 0x01 (gh-54789) + def convert_boolean_tensors(x): + if not isinstance(x, torch.Tensor) or x.dtype != torch.bool: + return x + + # Map False -> 0 and True -> Random value in [2, 255] + true_vals = torch.randint( + 2, 255, x.shape, dtype=torch.uint8, device=x.device + ) + false_vals = torch.zeros((), dtype=torch.uint8, device=x.device) + x_int = torch.where(x, true_vals, false_vals) + + ret = x_int.view(torch.bool) + self.assertEqual(ret, x) + return ret + + for sample in op.sample_inputs(device, dtype): + expect = op(sample.input, *sample.args, **sample.kwargs) + + transformed = sample.transform(convert_boolean_tensors) + actual = op(transformed.input, *transformed.args, **transformed.kwargs) + + self.assertEqual(expect, actual) + + @onlyXPU + @ops(_xpu_dtype_ops, dtypes=OpDTypes.none) + def test_dtypes(self, device, op): + # Check complex32 support only if the op claims. + # TODO: Once the complex32 support is better, we should add check for complex32 unconditionally. 
+ device_type = torch.device(device).type + include_complex32 = ( + (torch.complex32,) + if op.supports_dtype(torch.complex32, device_type) + else () + ) + + # dtypes to try to backward in + allowed_backward_dtypes = floating_and_complex_types_and( + *((torch.half, torch.bfloat16) + include_complex32) + ) + + # lists for (un)supported dtypes + supported_dtypes = set() + unsupported_dtypes = set() + supported_backward_dtypes = set() + unsupported_backward_dtypes = set() + dtype_error: Dict[torch.dtype, Exception] = dict() + + def unsupported(dtype, e): + dtype_error[dtype] = e + unsupported_dtypes.add(dtype) + if dtype in allowed_backward_dtypes: + unsupported_backward_dtypes.add(dtype) + + for dtype in all_types_and_complex_and( + *((torch.half, torch.bfloat16, torch.bool) + include_complex32) + ): + # tries to acquire samples - failure indicates lack of support + requires_grad = dtype in allowed_backward_dtypes + try: + samples = tuple( + op.sample_inputs(device, dtype, requires_grad=requires_grad) + ) + except Exception as e: + unsupported(dtype, e) + continue + + for sample in samples: + # tries to call operator with the sample - failure indicates + # lack of support + try: + result = op(sample.input, *sample.args, **sample.kwargs) + supported_dtypes.add(dtype) + except Exception as e: + # NOTE: some ops will fail in forward if their inputs + # require grad but they don't support computing the gradient + # in that type! This is a bug in the op! + unsupported(dtype, e) + continue + + # Checks for backward support in the same dtype, if the input has + # one or more tensors requiring grad + def _tensor_requires_grad(x): + if isinstance(x, dict): + for v in x.values(): + if _tensor_requires_grad(v): + return True + if isinstance(x, (list, tuple)): + for a in x: + if _tensor_requires_grad(a): + return True + if isinstance(x, torch.Tensor) and x.requires_grad: + return True + + return False + + requires_grad = ( + _tensor_requires_grad(sample.input) + or _tensor_requires_grad(sample.args) + or _tensor_requires_grad(sample.kwargs) + ) + if not requires_grad: + continue + + try: + result = sample.output_process_fn_grad(result) + if isinstance(result, torch.Tensor): + backward_tensor = result + elif isinstance(result, Sequence) and isinstance( + result[0], torch.Tensor + ): + backward_tensor = result[0] + else: + continue + + # Note: this grad may not have the same dtype as dtype + # For functions like complex (float -> complex) or abs + # (complex -> float) the grad tensor will have a + # different dtype than the input. + # For simplicity, this is still modeled as these ops + # supporting grad in the input dtype. 
+ grad = torch.randn_like(backward_tensor) + backward_tensor.backward(grad) + supported_backward_dtypes.add(dtype) + except Exception as e: + dtype_error[dtype] = e + unsupported_backward_dtypes.add(dtype) + + # Checks that dtypes are listed correctly and generates an informative + # error message + + supported_forward = supported_dtypes - unsupported_dtypes + partially_supported_forward = supported_dtypes & unsupported_dtypes + unsupported_forward = unsupported_dtypes - supported_dtypes + supported_backward = supported_backward_dtypes - unsupported_backward_dtypes + partially_supported_backward = ( + supported_backward_dtypes & unsupported_backward_dtypes + ) + unsupported_backward = unsupported_backward_dtypes - supported_backward_dtypes + + device_type = torch.device(device).type + + claimed_forward = set(op.supported_dtypes(device_type)) + supported_but_unclaimed_forward = supported_forward - claimed_forward + claimed_but_unsupported_forward = claimed_forward & unsupported_forward + + claimed_backward = set(op.supported_backward_dtypes(device_type)) + supported_but_unclaimed_backward = supported_backward - claimed_backward + claimed_but_unsupported_backward = claimed_backward & unsupported_backward + + # Partially supporting a dtype is not an error, but we print a warning + if (len(partially_supported_forward) + len(partially_supported_backward)) > 0: + msg = f"Some dtypes for {op.name} on device type {device_type} are only partially supported!\n" + if len(partially_supported_forward) > 0: + msg = ( + msg + + "The following dtypes only worked on some samples during forward: {}.\n".format( + partially_supported_forward + ) + ) + if len(partially_supported_backward) > 0: + msg = ( + msg + + "The following dtypes only worked on some samples during backward: {}.\n".format( + partially_supported_backward + ) + ) + print(msg) + + if ( + len(supported_but_unclaimed_forward) + + len(claimed_but_unsupported_forward) + + len(supported_but_unclaimed_backward) + + len(claimed_but_unsupported_backward) + ) == 0: + return + + # Reference operators often support additional dtypes, and that's OK + if op in python_ref_db: + if ( + len(claimed_but_unsupported_forward) + + len(claimed_but_unsupported_backward) + ) == 0: + return + + # Generates error msg + msg = f"The supported dtypes for {op.name} on device type {device_type} are incorrect!\n" + if len(supported_but_unclaimed_forward) > 0: + msg = ( + msg + + "The following dtypes worked in forward but are not listed by the OpInfo: {}.\n".format( + supported_but_unclaimed_forward + ) + ) + if len(supported_but_unclaimed_backward) > 0: + msg = ( + msg + + "The following dtypes worked in backward but are not listed by the OpInfo: {}.\n".format( + supported_but_unclaimed_backward + ) + ) + if len(claimed_but_unsupported_forward) > 0: + msg = ( + msg + + "The following dtypes did not work in forward but are listed by the OpInfo: {}.\n".format( + claimed_but_unsupported_forward + ) + ) + if len(claimed_but_unsupported_backward) > 0: + msg = ( + msg + + "The following dtypes did not work in backward but are listed by the OpInfo: {}.\n".format( + claimed_but_unsupported_backward + ) + ) + + all_claimed_but_unsupported = set.union( + claimed_but_unsupported_backward, claimed_but_unsupported_forward + ) + if all_claimed_but_unsupported: + msg += "Unexpected failures raised the following errors:\n" + for dtype in all_claimed_but_unsupported: + msg += f"{dtype} - {dtype_error[dtype]}\n" + + self.fail(msg) + + # Validates that each OpInfo that sets 
promotes_int_to_float=True does as it says + @onlyXPU + @ops( + (op for op in _xpu_promotes_int_to_float_ops if op.promotes_int_to_float), + allowed_dtypes=integral_types_and(torch.bool), + ) + def test_promotes_int_to_float(self, device, dtype, op): + for sample in op.sample_inputs(device, dtype): + output = op(sample.input, *sample.args, **sample.kwargs) + if not output.dtype.is_floating_point: + self.fail( + f"The OpInfo sets `promotes_int_to_float=True`, but {dtype} was promoted to {output.dtype}." + ) + + +instantiate_device_type_tests(TestXpu, globals(), only_for="xpu") + + +if __name__ == "__main__": + run_tests()
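
The operand swap in the rsub overloads above follows from the definition of reversed subtraction: torch.rsub(self, other, alpha) computes other - alpha * self, which is exactly sub(other, self, alpha), not sub(self, other, alpha). A minimal Python sketch of that identity (illustrative only, not part of the patch):

# Why XPUNativeFunctions::rsub must delegate to sub with the operands swapped:
# torch.rsub(self, other, alpha) is defined as other - alpha * self.
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([10.0, 20.0, 30.0])
alpha = 2.0

expected = b - alpha * a                                      # reversed subtraction
assert torch.equal(torch.rsub(a, b, alpha=alpha), expected)   # rsub(self, other, alpha)
assert torch.equal(torch.sub(b, a, alpha=alpha), expected)    # sub(other, self, alpha), as in the fix
# The previous code called sub(self, other, alpha), which computes a - alpha * b instead.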