diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index f8fd76f562cf10..ac7ed07d4b3e22 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1364,6 +1364,20 @@ def fn(x): r2 = opt_fn(i) self.assertEqual(r1, r2) + def test_tensor_hasattr(self): + @torch.compile(fullgraph=True) + def fn(x): + if hasattr(x, "test"): + return x + 2 + else: + return x + 1 + + self.assertEqual(torch.ones(2, 2) + 1, fn(torch.ones(2, 2))) + + inp = torch.ones(2, 2) + inp.test = None + self.assertEqual(torch.ones(2, 2) + 2, fn(inp)) + def test_shape_unpack(self): def fn(x): a, b = x.size() diff --git a/test/dynamo_expected_failures/TestTorchDeviceTypeCPU.test_broadcast_fn_copy_cpu b/test/dynamo_expected_failures/TestTorchDeviceTypeCPU.test_broadcast_fn_copy_cpu deleted file mode 100644 index e69de29bb2d1d6..00000000000000 diff --git a/test/dynamo_expected_failures/TestTorchDeviceTypeCPU.test_broadcast_fn_map2_cpu b/test/dynamo_expected_failures/TestTorchDeviceTypeCPU.test_broadcast_fn_map2_cpu deleted file mode 100644 index e69de29bb2d1d6..00000000000000 diff --git a/test/dynamo_expected_failures/TestTorchDeviceTypeCPU.test_broadcast_fn_map_cpu b/test/dynamo_expected_failures/TestTorchDeviceTypeCPU.test_broadcast_fn_map_cpu deleted file mode 100644 index e69de29bb2d1d6..00000000000000 diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 9077f00c037f81..f4deea9f383615 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -333,6 +333,27 @@ def method_attr__version(self, tx): tx, [self], {} ) + def call_hasattr(self, tx, name): + from . import GetAttrVariable + from .builtin import BuiltinVariable + + try: + var = BuiltinVariable(getattr).call_function( + tx, [self, ConstantVariable(name)], {} + ) + # in the event that TensorVariable returns NotImplemented + # BuiltinVariable.call_getattr returns GetAttrVariable + ret_val = not isinstance(var, GetAttrVariable) + except AttributeError: + ret_val = False + + if self.source: + install_guard( + AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR) + ) + + return ConstantVariable(ret_val) + def var_getattr(self, tx, name): from . import UserDefinedClassVariable