From 9c635bf86e9dcfd52fc0e32384012a1cb1f075f5 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sun, 3 Nov 2024 17:09:35 +0800 Subject: [PATCH 1/3] support numpy scalar as input type case --- python/paddle/tensor/creation.py | 2 +- test/legacy_test/test_eye_op.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index c08babbfe69457..1a3d5e574c784f 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1363,7 +1363,7 @@ def _check_attr(attr, message): assert len(attr.shape) == 0 or ( len(attr.shape) == 1 and attr.shape[0] in [1, -1] ) - elif not isinstance(attr, int) or attr < 0: + elif not np.isscalar(attr) or attr < 0: raise TypeError(f"{message} should be a non-negative int.") _check_attr(num_rows, "num_rows") diff --git a/test/legacy_test/test_eye_op.py b/test/legacy_test/test_eye_op.py index 0ce01913e12855..7dadbb253b7a56 100644 --- a/test/legacy_test/test_eye_op.py +++ b/test/legacy_test/test_eye_op.py @@ -91,6 +91,22 @@ def test_check_output(self): self.check_output(check_pir=True) +class TestEyeOp3(OpTest): + def setUp(self): + ''' + Test eye op with specified shape + ''' + self.python_api = paddle.eye + self.op_type = "eye" + + self.inputs = {} + self.attrs = {'num_rows': np.int32(99), 'num_columns': np.int32(1)} + self.outputs = {'Out': np.eye(99, 1, dtype=float)} + + def test_check_output(self): + self.check_output(check_pir=True) + + class API_TestTensorEye(unittest.TestCase): def test_static_out(self): From 73be02d7e027c19bb2fa29d8814ea00db9a68535 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Mon, 4 Nov 2024 10:52:35 +0800 Subject: [PATCH 2/3] update docstring --- test/legacy_test/test_eye_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/legacy_test/test_eye_op.py b/test/legacy_test/test_eye_op.py index 7dadbb253b7a56..02a36f4a630c89 100644 --- a/test/legacy_test/test_eye_op.py +++ b/test/legacy_test/test_eye_op.py @@ -94,7 +94,7 @@ def test_check_output(self): class TestEyeOp3(OpTest): def setUp(self): ''' - Test eye op with specified shape + Test eye op with np.int32 scalar ''' self.python_api = paddle.eye self.op_type = "eye" From d176130329f532a72fdae571e853083d6bbfdfde Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Mon, 4 Nov 2024 14:37:46 +0800 Subject: [PATCH 3/3] fix --- python/paddle/tensor/creation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 1a3d5e574c784f..3a190fa0b06340 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1363,7 +1363,7 @@ def _check_attr(attr, message): assert len(attr.shape) == 0 or ( len(attr.shape) == 1 and attr.shape[0] in [1, -1] ) - elif not np.isscalar(attr) or attr < 0: + elif not isinstance(attr, (int, np.integer)) or attr < 0: raise TypeError(f"{message} should be a non-negative int.") _check_attr(num_rows, "num_rows")