From d64bc8f0f81bd9b514eb1a5ee6f5b03094e4e6e9 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 13 Nov 2023 14:29:17 +0000 Subject: [PATCH] use sourceless builder for builtin getattr (#113340) In TorchVision we use the following (simplified) dispatch mechanism: ```python import torch def kernel1(tensor): return tensor + 2 def dispatcher1(input): kernel = get_kernel(dispatcher1, type(input)) return kernel(input) def kernel2(tensor): return tensor - 2 def dispatcher2(input): kernel = get_kernel(dispatcher2, type(input)) return kernel(input) # We actually use the function and type as keys, rather than their names. # However, this is currently not supported, but should be easy to add after # https://github.com/pytorch/pytorch/pull/111196 REGISTRY = { "dispatcher1": {"Tensor": kernel1}, "dispatcher2": {"Tensor": kernel2}, } def get_kernel(dispatcher, input_type): dispatcher_registry = REGISTRY[dispatcher.__name__] for cls in input_type.__mro__: kernel = dispatcher_registry[cls.__name__] break return kernel ``` This can be compiled without graph breaks: ```python cfn = torch.compile(dispatcher1, fullgraph=True) torch.testing.assert_close(int(cfn(torch.tensor(3))), 5) cfn = torch.compile(dispatcher2, fullgraph=True) torch.testing.assert_close(int(cfn(torch.tensor(3))), 1) ``` However, if we start chaining these calls, we hit some issues: ```python class Pipeline(torch.nn.Module): def forward(self, input): input = dispatcher1(input) input = dispatcher2(input) return input cfn = torch.compile(Pipeline(), fullgraph=True) torch.testing.assert_close(int(cfn(torch.tensor(3))), 3) ``` ``` Can't access members of type(obj) for a generated custom object. Please use __class__ instead ``` The error message is not really helpful here. The following happens: when compiling `dispatcher1`, `get_kernel` gets inlined. That means when hitting `dispatcher2`, the `type` call no longer happens on an input with a source. 
Thus, in the first iteration we hit the top branch, while in the second we hit the bottom: https://github.com/pytorch/pytorch/blob/addb8e29cd842e1a290cb0b55662ee0423ab2498/torch/_dynamo/variables/builtin.py#L1264-L1268 And the error message I posted above originates from the type being treated as constant. This PR replaces this with a `SourcelessBuilder` instead. With that fix in place, we hit another error pointing to `input_type.__mro__` ``` AssertionError: Consider SourcelessBuilder for ephemeral objects, usually objects created locally. ``` The fix is similar: instead of using a `VariableBuilder` here, we use a `SourcelessBuilder` in case we have no `source`: https://github.com/pytorch/pytorch/blob/addb8e29cd842e1a290cb0b55662ee0423ab2498/torch/_dynamo/variables/builtin.py#L1167-L1168 Pull Request resolved: https://github.com/pytorch/pytorch/pull/113340 Approved by: https://github.com/peterbell10, https://github.com/lezcano --- test/dynamo/test_export.py | 13 +++--------- torch/_dynamo/variables/builtin.py | 34 +++++++++++------------------- 2 files changed, 15 insertions(+), 32 deletions(-) diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 5938b25e5de..2653f08b4f8 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -2851,7 +2851,8 @@ def forward(self, x): with self.assertRaisesRegex(RuntimeError, "Shape must be more than 4"): gm(torch.randn(3, 4, 5)) - def test_access_class_method_from_user_class(self): + @common_utils.parametrize("type_fn", [type, lambda obj: obj.__class__]) + def test_access_class_method_from_user_class(self, type_fn): class A: @classmethod def func(cls): @@ -2859,19 +2860,11 @@ def forward(self, x): def f(x): a = A() - return x.sum() + type(a).func().sum() + return x.sum() + type_fn(a).func().sum() gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4)) self.assertEqual(f(torch.ones(6, 4)), gm(torch.ones(6, 4))) - def f_correct(x): - a = A() - return x.sum() + a.__class__.func().sum() - - 
gm, _ = torch._dynamo.export(f_correct, aten_graph=True)(torch.ones(6, 4)) - - self.assertEqual(f_correct(torch.ones(6, 4)), gm(torch.ones(6, 4))) - def test_not_functionalize(self): class Foo(torch.nn.Module): def __init__(self): diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index c369f1681e9..0cb43d4be9d 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -11,7 +11,6 @@ import torch from torch import sym_float, sym_int from .. import config, polyfill, variables -from ..allowed_functions import is_allowed from ..exc import ( AttributeMutationError, unimplemented, @@ -1177,12 +1176,10 @@ class BuiltinVariable(VariableTracker): return build_checkpoint_variable(**options) elif trace_rules.lookup(member) is not None: return trace_rules.lookup(member)(member, **options) - elif is_allowed(member): - return TorchVariable(member, **options) - elif ConstantVariable.is_literal(member): - return ConstantVariable.create(member, **options) - else: + elif source is not None: return VariableBuilder(tx, source)(member) + else: + return SourcelessBuilder()(tx, member) elif isinstance(obj, (PythonModuleVariable, DummyModule)): member = obj.value.__dict__[name] @@ -1278,24 +1275,17 @@ class BuiltinVariable(VariableTracker): try: py_type = obj.python_type() - except NotImplementedError: - py_type = None + except NotImplementedError as error: + raise UserError( + UserErrorType.INVALID_INPUT, + str(error), + case_name="unknown_python_type", + ) from None - if istype(obj, variables.TupleVariable): - return BuiltinVariable(py_type) - - if py_type is not None and obj.source: - return VariableBuilder(tx, TypeSource(obj.source))(py_type) - - if py_type is not None: + if obj.source is None: return SourcelessBuilder()(tx, py_type) - - raise UserError( - UserErrorType.ANTI_PATTERN, - f"Can't call type() on generated custom object {obj}. 
" - "Please use __class__ instead", - case_name="type_reflection_method", - ) + else: + return VariableBuilder(tx, TypeSource(obj.source))(py_type) def call_reversed(self, tx, obj: VariableTracker): if obj.has_unpack_var_sequence(tx):