mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add hasattr for tensor variable (#131008)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131008 Approved by: https://github.com/anijain2305 ghstack dependencies: #131007
This commit is contained in:
parent
1f961ad495
commit
1b72cf0b09
|
|
@ -1364,6 +1364,20 @@ utils_device.CURRENT_DEVICE == None""".split(
|
||||||
r2 = opt_fn(i)
|
r2 = opt_fn(i)
|
||||||
self.assertEqual(r1, r2)
|
self.assertEqual(r1, r2)
|
||||||
|
|
||||||
|
def test_tensor_hasattr(self):
|
||||||
|
@torch.compile(fullgraph=True)
|
||||||
|
def fn(x):
|
||||||
|
if hasattr(x, "test"):
|
||||||
|
return x + 2
|
||||||
|
else:
|
||||||
|
return x + 1
|
||||||
|
|
||||||
|
self.assertEqual(torch.ones(2, 2) + 1, fn(torch.ones(2, 2)))
|
||||||
|
|
||||||
|
inp = torch.ones(2, 2)
|
||||||
|
inp.test = None
|
||||||
|
self.assertEqual(torch.ones(2, 2) + 2, fn(inp))
|
||||||
|
|
||||||
def test_shape_unpack(self):
|
def test_shape_unpack(self):
|
||||||
def fn(x):
|
def fn(x):
|
||||||
a, b = x.size()
|
a, b = x.size()
|
||||||
|
|
|
||||||
|
|
@ -333,6 +333,27 @@ class TensorVariable(VariableTracker):
|
||||||
tx, [self], {}
|
tx, [self], {}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def call_hasattr(self, tx, name):
|
||||||
|
from . import GetAttrVariable
|
||||||
|
from .builtin import BuiltinVariable
|
||||||
|
|
||||||
|
try:
|
||||||
|
var = BuiltinVariable(getattr).call_function(
|
||||||
|
tx, [self, ConstantVariable(name)], {}
|
||||||
|
)
|
||||||
|
# in the event that TensorVariable returns NotImplemented
|
||||||
|
# BuiltinVariable.call_getattr returns GetAttrVariable
|
||||||
|
ret_val = not isinstance(var, GetAttrVariable)
|
||||||
|
except AttributeError:
|
||||||
|
ret_val = False
|
||||||
|
|
||||||
|
if self.source:
|
||||||
|
install_guard(
|
||||||
|
AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
|
||||||
|
)
|
||||||
|
|
||||||
|
return ConstantVariable(ret_val)
|
||||||
|
|
||||||
def var_getattr(self, tx, name):
|
def var_getattr(self, tx, name):
|
||||||
from . import UserDefinedClassVariable
|
from . import UserDefinedClassVariable
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user