diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 67ab47dfac9..f5cbfb8760c 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -1683,6 +1683,22 @@ class FunctionTests(torch._dynamo.test_case.TestCase): tmp = {1: "D", 10: "B", 3: "E", 0: "F"} return x + 1, sorted(tmp), sorted(tmp, reverse=True) + def test_dict_hasattr(self): + def fn(x): + if hasattr(x, "to"): + return x.to("cpu") + if hasattr(x, "items"): + return torch.cos(x["a"]) + return x + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + x = dict(a=torch.randn(3)) + self.assertEqual(fn(x), opt_fn(x)) + + x = torch.randn(4) + self.assertEqual(fn(x), opt_fn(x)) + @make_test def test_list_clear(a, b): tmp = [a + 1, a + 2] diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 2d4c015bf3a..f9b74bb2f20 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -355,6 +355,15 @@ class ConstDictVariable(VariableTracker): def unpack_var_sequence(self, tx): return [x.vt for x in self.items.keys()] + def call_hasattr(self, tx, name): + # dict not allow setting arbitrary attributes. To check for hasattr, we can just check the __dict__ of the dict. + # OrderedDict though requires side effects tracking because it supports arbitrary setattr. + if self.user_cls is dict: + if name in self.user_cls.__dict__: + return ConstantVariable.create(True) + return ConstantVariable.create(False) + unimplemented(f"hasattr on {self.user_cls} is not supported") + class DefaultDictVariable(ConstDictVariable): def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None: