Skip to content

Commit

Permalink
Add hasattr for tensor variable (pytorch#131008)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#131008
Approved by: https://github.com/anijain2305
ghstack dependencies: pytorch#131007
  • Loading branch information
mlazos authored and pytorchmergebot committed Jul 19, 2024
1 parent 1f961ad commit 1b72cf0
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 0 deletions.
14 changes: 14 additions & 0 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,20 @@ def fn(x):
r2 = opt_fn(i)
self.assertEqual(r1, r2)

def test_tensor_hasattr(self):
@torch.compile(fullgraph=True)
def fn(x):
if hasattr(x, "test"):
return x + 2
else:
return x + 1

self.assertEqual(torch.ones(2, 2) + 1, fn(torch.ones(2, 2)))

inp = torch.ones(2, 2)
inp.test = None
self.assertEqual(torch.ones(2, 2) + 2, fn(inp))

def test_shape_unpack(self):
def fn(x):
a, b = x.size()
Expand Down
Empty file.
Empty file.
Empty file.
21 changes: 21 additions & 0 deletions torch/_dynamo/variables/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,27 @@ def method_attr__version(self, tx):
tx, [self], {}
)

def call_hasattr(self, tx, name):
from . import GetAttrVariable
from .builtin import BuiltinVariable

try:
var = BuiltinVariable(getattr).call_function(
tx, [self, ConstantVariable(name)], {}
)
# in the event that TensorVariable returns NotImplemented
# BuiltinVariable.call_getattr returns GetAttrVariable
ret_val = not isinstance(var, GetAttrVariable)
except AttributeError:
ret_val = False

if self.source:
install_guard(
AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
)

return ConstantVariable(ret_val)

def var_getattr(self, tx, name):
from . import UserDefinedClassVariable

Expand Down

0 comments on commit 1b72cf0

Please sign in to comment.