Commit

Align with PyTorch test_ops case suites (#50)
Signed-off-by: Feng Yuan <[email protected]>
Co-authored-by: Yutao Xu <[email protected]>
Co-authored-by: chuanqiw <[email protected]>
3 people authored Mar 25, 2024
1 parent 9ce96cf commit e755b1e
Showing 12 changed files with 424 additions and 33 deletions.
1 change: 0 additions & 1 deletion .github/ci_commit_pins/pytorch.txt

This file was deleted.

9 changes: 5 additions & 4 deletions .github/workflows/pull.yml
@@ -26,8 +26,8 @@ jobs:
run: |
pwd
cd ../ && rm -rf pytorch
-git clone https://github.com/pytorch/pytorch
-cd pytorch && git checkout `cat ../torch-xpu-ops/.github/ci_commit_pins/pytorch.txt` && git submodule sync && git submodule update --init --recursive
+git clone -b nightly https://github.com/pytorch/pytorch
+cd pytorch && git log -n 1 && git submodule sync && git submodule update --init --recursive
rm -rf third_party/torch-xpu-ops && cp -r ../torch-xpu-ops third_party/
# Workaround for torch-xpu-ops ci test
sed -i "s/checkout --quiet \${TORCH_XPU_OPS_COMMIT}/log -n 1/g" caffe2/CMakeLists.txt
@@ -52,11 +52,12 @@ jobs:
cd examples
timeout 8000 pytest -v 2>&1 | tee torch_xpu_ops_example.log
- name: Run XPU OP UT
-if: ${{ hashFiles('test/python/') != '' }}
+if: ${{ hashFiles('test/xpu/') != '' }}
run: |
source /opt/intel/oneapi/compiler/latest/env/vars.sh
source activate xpu_op_${ZE_AFFINITY_MASK}
-cd test/python
+export PYTORCH_ENABLE_XPU_FALLBACK=1
+cd test/xpu
timeout 8000 pytest -v 2>&1 | tee torch_xpu_ops_ut.log
- name: Run Torch XPU UT
run: |
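The "Run XPU OP UT" step above now runs from test/xpu and exports PYTORCH_ENABLE_XPU_FALLBACK=1, presumably so that operators without an XPU kernel can fall back to CPU instead of failing the suite. A minimal local sketch of that behavior, assuming a PyTorch build with the XPU backend; setting the variable from Python before importing torch, and the choice of op, are illustrative assumptions (the CI job exports it in the shell before invoking pytest):

import os
os.environ["PYTORCH_ENABLE_XPU_FALLBACK"] = "1"  # assumed to take effect before torch initializes the XPU backend

import torch

if hasattr(torch, "xpu") and torch.xpu.is_available():
    x = torch.randn(8, device="xpu")
    # With the fallback enabled, an op lacking an XPU kernel is expected to run on
    # CPU and copy its result back rather than raising an error.
    print(torch.sin(x).device)
else:
    print("No XPU device detected; the flag only matters for XPU runs.")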
@@ -59,7 +59,7 @@ def test_div_int(self, dtype=torch.float):
c_cpu = a_cpu / b_cpu
self.assertEqual(c_cpu.dtype, c_xpu.dtype) # assume float
self.assertEqual(c_cpu, c_xpu.to(cpu_device))

def test_binary_div_channels_last(self, dtype=torch.float):
shapes = [
(1, 2, 3, 4),
@@ -222,7 +222,7 @@ def test_binary_div_channels_last(self, dtype=torch.float):
self.assertEqual(
a_xpu.is_contiguous(memory_format=torch.channels_last), False
)

def test_pow(self, dtype=torch.float):
x_cpu = torch.tensor(([2.5, 3.1, 1.3]), dtype=torch.float, device=cpu_device)
x_xpu = torch.tensor(
@@ -235,7 +235,7 @@ def test_pow(self, dtype=torch.float):
self.assertEqual(torch.pow(x_cpu, y_cpu), torch.pow(x_xpu, y_xpu).cpu())
self.assertEqual(x_cpu.pow(y_cpu), x_xpu.pow(y_xpu).cpu())
self.assertEqual(x_cpu.pow_(y_cpu), x_xpu.pow_(y_xpu).cpu())

def test_binary_op(self, dtype=torch.float):
x_cpu = torch.randn(5)

@@ -26,21 +26,21 @@ def _test_compare_fn(self, fn, dtype):
y_xpu.zero_()
fn(x1_xpu, x2_xpu, out=y_xpu)
self.assertEqual(y_xpu.cpu(), y)

def test_eq(self, dtype=torch.float):
self._test_compare_fn(torch.eq, dtype)

def test_ne(self, dtype=torch.float):
self._test_compare_fn(torch.ne, dtype)

def test_lt(self, dtype=torch.float):
self._test_compare_fn(torch.lt, dtype)

def test_le(self, dtype=torch.float):
self._test_compare_fn(torch.le, dtype)

def test_gt(self, dtype=torch.float):
self._test_compare_fn(torch.gt, dtype)

def test_ge(self, dtype=torch.float):
self._test_compare_fn(torch.ge, dtype)
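Only the tail of _test_compare_fn is visible in the hunk above. A self-contained sketch of how such a compare helper is typically structured; the tensor shapes, TestCase base class, and everything outside the visible fragment are assumptions, not the file's actual code:

import torch
from torch.testing._internal.common_utils import TestCase, run_tests

cpu_device = torch.device("cpu")
xpu_device = torch.device("xpu")

class TestCompareSketch(TestCase):
    def _test_compare_fn(self, fn, dtype):
        x1 = torch.randn(5, dtype=dtype)
        x2 = torch.randn(5, dtype=dtype)
        y = fn(x1, x2)                      # CPU reference result (bool tensor)

        x1_xpu = x1.to(xpu_device)
        x2_xpu = x2.to(xpu_device)
        self.assertEqual(fn(x1_xpu, x2_xpu).cpu(), y)

        # out= variant, matching the fragment shown above
        y_xpu = torch.empty_like(y, device=xpu_device)
        y_xpu.zero_()
        fn(x1_xpu, x2_xpu, out=y_xpu)
        self.assertEqual(y_xpu.cpu(), y)

    def test_eq(self, dtype=torch.float):
        self._test_compare_fn(torch.eq, dtype)

if __name__ == "__main__":
    run_tests()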
File renamed without changes.
6 changes: 3 additions & 3 deletions test/python/examples/test_loops.py → examples/test_loops.py
@@ -34,13 +34,13 @@ def _test_loops(self, dtype=torch.float):
c = a + b + 1
c_xpu = a_xpu + b_xpu + 1
self.assertEqual(c, c_xpu.cpu())

def test_loops_float(self):
self._test_loops(torch.float)

def test_loops_half(self):
self._test_loops(torch.half)

def test_loops_bfloat16(self):
self._test_loops(torch.bfloat16)

@@ -19,7 +19,7 @@ def test_distribution_normal(self, dtype=torch.float):
torch.normal(mean=-3.0, std=1.2, size=(10000,), device=xpu_device, dtype=dtype, out=x_xpu)
self.assertEqual(x_xpu.cpu().mean(), -3.0, rtol=tol, atol=tol)
self.assertEqual(x_xpu.cpu().std(), 1.2, rtol=tol, atol=tol)

def test_distribution_uniform(self, dtype=torch.float):
tol = 1e-2
x_xpu = torch.tensor(list(range(10000)), device=xpu_device, dtype=dtype)
@@ -28,14 +28,14 @@ def test_view(self, dtype=torch.float):
assert b_xpu.shape[2] == 2

self.assertEqual(c_cpu, b_xpu.to(cpu_device))

def test_view_as_real(self, dtype=torch.cfloat):
a_cpu = torch.randn(2, 3, 4, dtype=dtype)
a_xpu = a_cpu.to(xpu_device)
b_cpu = torch.view_as_real(a_cpu)
b_xpu = torch.view_as_real(a_xpu)
self.assertEqual(b_cpu, b_xpu.to(cpu_device))

def test_view_as_complex(self, dtype=torch.float):
a_cpu = torch.randn(109, 2, dtype=dtype)
a_xpu = a_cpu.to(xpu_device)
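The view tests above round-trip between complex tensors and their real views. A small CPU-only illustration of what torch.view_as_real and torch.view_as_complex do (the shapes are arbitrary):

import torch

z = torch.randn(2, 3, dtype=torch.cfloat)
r = torch.view_as_real(z)         # shape (2, 3, 2): last dim holds (real, imag)
z2 = torch.view_as_complex(r)     # round-trips back to the complex tensor
print(r.shape, torch.equal(z, z2))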
File renamed without changes.
18 changes: 9 additions & 9 deletions test/python/examples/test_unary.py → examples/test_unary.py
@@ -14,7 +14,7 @@ class Dtypes(object):
def __init__(self, include_dtypes, exclude_dtypes=[]):
self.include_dtypes = include_dtypes
self.exclude_dtypes = exclude_dtypes

def __call__(self, fn):
def fn_out(*args, **kwargs):
for dtype in self.include_dtypes:
@@ -38,39 +38,39 @@ def _test_unary_out_ops(self, fn_str, dtype):
d_cpu = eval(f"torch.{fn_str}(a_cpu, out=c_cpu)")
d_xpu = eval(f"torch.{fn_str}(a_xpu, out=c_xpu)")
self.assertEqual(c_cpu, c_xpu.cpu(), atol=1e-4, rtol=1e-4)

@Dtypes(floating_types)
def test_abs_out(self, dtype):
self._test_unary_out_ops('abs', dtype)

@Dtypes(floating_and_complex_types)
def test_sin_out(self, dtype):
self._test_unary_out_ops('sin', dtype)

@Dtypes(floating_and_complex_types)
def test_cos_out(self, dtype):
self._test_unary_out_ops('cos', dtype)

@Dtypes(floating_and_complex_types)
def test_log_out(self, dtype):
self._test_unary_out_ops('log', dtype)

@Dtypes(floating_and_complex_types)
def test_sqrt_out(self, dtype):
self._test_unary_out_ops('sqrt', dtype)

@Dtypes(floating_and_complex_types)
def test_rsqrt_out(self, dtype):
self._test_unary_out_ops('rsqrt', dtype)

@Dtypes(floating_and_complex_types)
def test_tanh_out(self, dtype):
self._test_unary_out_ops('tanh', dtype)

@Dtypes(all_basic_and_complex_types, [torch.bool])
def test_neg_out(self, dtype):
self._test_unary_out_ops('neg', dtype)

@Dtypes(floating_and_complex_types)
def test_reciprocal_out(self, dtype):
self._test_unary_out_ops('reciprocal', dtype)
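The Dtypes decorator at the top of this file parameterizes each test over a list of dtypes; its __call__ body is cut off in the hunk. A standalone sketch of the pattern, where the loop body and the exclusion handling are assumptions rather than the file's actual implementation:

import torch

class Dtypes(object):
    def __init__(self, include_dtypes, exclude_dtypes=[]):
        self.include_dtypes = include_dtypes
        self.exclude_dtypes = exclude_dtypes

    def __call__(self, fn):
        def fn_out(*args, **kwargs):
            # Assumed behavior: invoke the wrapped test once per included dtype,
            # skipping anything listed in exclude_dtypes.
            for dtype in self.include_dtypes:
                if dtype in self.exclude_dtypes:
                    continue
                fn(*args, dtype=dtype, **kwargs)
        return fn_out

@Dtypes([torch.float32, torch.float16], exclude_dtypes=[torch.float16])
def check_abs(dtype):
    x = torch.tensor([-1.5, 2.0], dtype=dtype)
    assert torch.equal(torch.abs(x), torch.tensor([1.5, 2.0], dtype=dtype))

check_abs()  # runs only for float32; float16 is excluded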
8 changes: 4 additions & 4 deletions src/aten/BinaryOps.cpp
@@ -210,23 +210,23 @@ Tensor XPUNativeFunctions::rsub(
const Tensor& self,
const Tensor& other,
const Scalar& alpha) {
-return XPUNativeFunctions::sub(self, other, alpha);
+return XPUNativeFunctions::sub(other, self, alpha);
}

Tensor& XPUNativeFunctions::rsub_out(
const Tensor& self,
const Tensor& other,
const Scalar& alpha,
Tensor& out) {
-return XPUNativeFunctions::sub_out(self, other, alpha, out);
+return XPUNativeFunctions::sub_out(other, self, alpha, out);
}

Tensor XPUNativeFunctions::rsub(
const Tensor& self,
const Scalar& other,
const Scalar& alpha) {
return XPUNativeFunctions::sub(
-self, native::wrapped_scalar_tensor(other), alpha);
+native::wrapped_scalar_tensor(other), self, alpha);
}

Tensor& XPUNativeFunctions::rsub_out(
@@ -235,7 +235,7 @@ Tensor& XPUNativeFunctions::rsub_out(
const Scalar& alpha,
Tensor& out) {
return XPUNativeFunctions::sub_out(
-self, native::wrapped_scalar_tensor(other), alpha, out);
+native::wrapped_scalar_tensor(other), self, alpha, out);
}

Tensor XPUNativeFunctions::remainder(const Tensor& self, const Tensor& other) {
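The four changed lines in this file swap self and other when rsub forwards to sub, which matches rsub's definition: rsub(self, other, alpha) computes other - alpha * self, whereas sub(self, other, alpha) computes self - alpha * other. A quick CPU-side check of that identity (the values are illustrative only):

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([10.0, 20.0, 30.0])

expected = b - 2.0 * a                                        # definition of rsub(a, b, alpha=2)
print(torch.equal(torch.rsub(a, b, alpha=2.0), expected))     # True
print(torch.equal(torch.sub(a, b, alpha=2.0), expected))      # False: a - 2*b, the pre-fix forwarding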