use sourceless builder for builtin getattr (#113340)

In TorchVision we use the following (simplified) dispatch mechanism:

```python
import torch

def kernel1(tensor):
    return tensor + 2

def dispatcher1(input):
    kernel = get_kernel(dispatcher1, type(input))
    return kernel(input)

def kernel2(tensor):
    return tensor - 2

def dispatcher2(input):
    kernel = get_kernel(dispatcher2, type(input))
    return kernel(input)

# We actually use the function and type themselves as keys, rather than
# their names. However, this is currently not supported, but should be easy
# to add after https://github.com/pytorch/pytorch/pull/111196
REGISTRY = {
    "dispatcher1": {"Tensor": kernel1},
    "dispatcher2": {"Tensor": kernel2},
}

def get_kernel(dispatcher, input_type):
    dispatcher_registry = REGISTRY[dispatcher.__name__]
    # Walk the MRO of the input type and take the first kernel. In this
    # simplified version we break immediately, i.e. we always use the kernel
    # registered for the input type itself.
    for cls in input_type.__mro__:
        kernel = dispatcher_registry[cls.__name__]
        break
    return kernel
```
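
To see what the `__mro__` walk buys us, here is a sketch with a hypothetical `MyTensor` subclass that has no kernel of its own. It adds the membership check that the simplified loop above omits, so the lookup falls back to the kernel registered for `Tensor`:

```python
class MyTensor(torch.Tensor):  # hypothetical subclass, for illustration only
    pass

def get_kernel_with_fallback(dispatcher, input_type):
    # Same lookup as get_kernel above, but returning the kernel of the
    # first class in the MRO that actually has one registered.
    dispatcher_registry = REGISTRY[dispatcher.__name__]
    for cls in input_type.__mro__:
        if cls.__name__ in dispatcher_registry:
            return dispatcher_registry[cls.__name__]
    raise TypeError(f"no kernel registered for {input_type.__name__}")

# MyTensor.__mro__ is (MyTensor, Tensor, object), so we fall back to kernel1.
assert get_kernel_with_fallback(dispatcher1, MyTensor) is kernel1
```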

Each dispatcher can be compiled without graph breaks:

```python
cfn = torch.compile(dispatcher1, fullgraph=True)
torch.testing.assert_close(int(cfn(torch.tensor(3))), 5)

cfn = torch.compile(dispatcher2, fullgraph=True)
torch.testing.assert_close(int(cfn(torch.tensor(3))), 1)
```

However, if we start chaining these calls, we hit some issues:

```python
class Pipeline(torch.nn.Module):
    def forward(self, input):
        input = dispatcher1(input)
        input = dispatcher2(input)
        return input

cfn = torch.compile(Pipeline(), fullgraph=True)
torch.testing.assert_close(int(cfn(torch.tensor(3))), 3)
```

This fails with:

```
Can't access members of type(obj) for a generated custom object. Please use __class__ instead
```

The error message is not really helpful here. What actually happens is this: while compiling `dispatcher1`, `get_kernel` is inlined. That means that when we hit `dispatcher2`, the `type` call no longer happens on an input that has a *source*. (In Dynamo, a source records where a value came from, e.g. a graph input or a global, so that guards can be installed on it; values created locally during tracing have none.) Thus, on the first dispatch we hit the top branch, while on the second we hit the bottom one:

`addb8e29cd/torch/_dynamo/variables/builtin.py` (L1264-L1268)
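
For reference, the branch in question looks roughly like this (a paraphrase, not a verbatim quote; `VariableBuilder`, `TypeSource`, and `ConstantVariable` are the names that also appear in the diff at the bottom):

```python
# Paraphrased sketch of the type() handling at commit addb8e29cd:
if py_type is not None and obj.source:
    # top branch: the object has a source, so the type is wrapped via
    # VariableBuilder with a TypeSource, and proper guards are installed
    return VariableBuilder(tx, TypeSource(obj.source))(py_type)
if py_type is not None:
    # bottom branch: no source; the type was treated as a constant, which
    # later surfaces as the unhelpful error quoted above
    return ConstantVariable.create(py_type)
```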

The error message quoted above originates from the type being treated as a constant. This PR builds the type with a `SourcelessBuilder` instead.

With that fix in place, we hit another error, this time pointing to `input_type.__mro__`:

```
AssertionError: Consider SourcelessBuilder for ephemeral objects, usually objects created locally.
```

The fix is similar: instead of unconditionally using a `VariableBuilder` here, we use a `SourcelessBuilder` when there is no `source`:

`addb8e29cd/torch/_dynamo/variables/builtin.py` (L1167-L1168)
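
Concretely, the member handling now becomes (the relevant branch from the diff below, with the preceding `elif` chain elided):

```python
# After this PR: only use VariableBuilder when the member has a source,
# otherwise fall back to SourcelessBuilder.
if source is not None:
    return VariableBuilder(tx, source)(member)
else:
    return SourcelessBuilder()(tx, member)
```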

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113340
Approved by: https://github.com/peterbell10, https://github.com/lezcano
Author: Philip Meier (committed by PyTorch MergeBot)
Date: 2023-11-13 14:29:17 +00:00
Commit: d64bc8f0f8 (parent: 115da02432)

2 changed files with 15 additions and 32 deletions

In the `torch._dynamo.export` test suite:

```diff
@@ -2851,7 +2851,8 @@ def forward(self, x):
         with self.assertRaisesRegex(RuntimeError, "Shape must be more than 4"):
             gm(torch.randn(3, 4, 5))
 
-    def test_access_class_method_from_user_class(self):
+    @common_utils.parametrize("type_fn", [type, lambda obj: obj.__class__])
+    def test_access_class_method_from_user_class(self, type_fn):
         class A:
             @classmethod
             def func(cls):
@@ -2859,19 +2860,11 @@ def forward(self, x):
         def f(x):
             a = A()
-            return x.sum() + type(a).func().sum()
+            return x.sum() + type_fn(a).func().sum()
 
         gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4))
         self.assertEqual(f(torch.ones(6, 4)), gm(torch.ones(6, 4)))
 
-        def f_correct(x):
-            a = A()
-            return x.sum() + a.__class__.func().sum()
-
-        gm, _ = torch._dynamo.export(f_correct, aten_graph=True)(torch.ones(6, 4))
-        self.assertEqual(f_correct(torch.ones(6, 4)), gm(torch.ones(6, 4)))
-
     def test_not_functionalize(self):
         class Foo(torch.nn.Module):
             def __init__(self):
```

In `torch/_dynamo/variables/builtin.py`:

```diff
@@ -11,7 +11,6 @@ import torch
 from torch import sym_float, sym_int
 
 from .. import config, polyfill, variables
-from ..allowed_functions import is_allowed
 from ..exc import (
     AttributeMutationError,
     unimplemented,
@@ -1177,12 +1176,10 @@ class BuiltinVariable(VariableTracker):
                 return build_checkpoint_variable(**options)
             elif trace_rules.lookup(member) is not None:
                 return trace_rules.lookup(member)(member, **options)
-            elif is_allowed(member):
-                return TorchVariable(member, **options)
-            elif ConstantVariable.is_literal(member):
-                return ConstantVariable.create(member, **options)
-            else:
+            elif source is not None:
                 return VariableBuilder(tx, source)(member)
+            else:
+                return SourcelessBuilder()(tx, member)
         elif isinstance(obj, (PythonModuleVariable, DummyModule)):
             member = obj.value.__dict__[name]
@@ -1278,24 +1275,17 @@ class BuiltinVariable(VariableTracker):
         try:
             py_type = obj.python_type()
-        except NotImplementedError:
-            py_type = None
-
-        if istype(obj, variables.TupleVariable):
-            return BuiltinVariable(py_type)
-
-        if py_type is not None and obj.source:
-            return VariableBuilder(tx, TypeSource(obj.source))(py_type)
-
-        if py_type is not None:
-            return SourcelessBuilder()(tx, py_type)
-
-        raise UserError(
-            UserErrorType.ANTI_PATTERN,
-            f"Can't call type() on generated custom object {obj}. "
-            "Please use __class__ instead",
-            case_name="type_reflection_method",
-        )
+        except NotImplementedError as error:
+            raise UserError(
+                UserErrorType.INVALID_INPUT,
+                str(error),
+                case_name="unknown_python_type",
+            ) from None
+
+        if obj.source is None:
+            return SourcelessBuilder()(tx, py_type)
+        else:
+            return VariableBuilder(tx, TypeSource(obj.source))(py_type)
 
     def call_reversed(self, tx, obj: VariableTracker):
         if obj.has_unpack_var_sequence(tx):
```