mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
use sourceless builder for builtin getattr (#113340)
In TorchVision we use the following (simplified) dispatch mechanism:
```python
import torch
def kernel1(tensor):
return tensor + 2
def dispatcher1(input):
kernel = get_kernel(dispatcher1, type(input))
return kernel(input)
def kernel2(tensor):
return tensor - 2
def dispatcher2(input):
kernel = get_kernel(dispatcher2, type(input))
return kernel(input)
# We actually use the function and type as keys, rather than their names.
# However, this currently not supported, but should be easy to add after
# https://github.com/pytorch/pytorch/pull/111196
REGISTRY = {
"dispatcher1": {"Tensor": kernel1},
"dispatcher2": {"Tensor": kernel2},
}
def get_kernel(dispatcher, input_type):
dispatcher_registry = REGISTRY[dispatcher.__name__]
for cls in input_type.__mro__:
kernel = dispatcher_registry[cls.__name__]
break
return kernel
```
This can be compiled without graph breaks:
```python
cfn = torch.compile(dispatcher1, fullgraph=True)
torch.testing.assert_close(int(cfn(torch.tensor(3))), 5)
cfn = torch.compile(dispatcher2, fullgraph=True)
torch.testing.assert_close(int(cfn(torch.tensor(3))), 1)
```
However, if we start chaining these calls, we hit some issues:
```python
class Pipeline(torch.nn.Module):
def forward(self, input):
input = dispatcher1(input)
input = dispatcher2(input)
return input
cfn = torch.compile(Pipeline(), fullgraph=True)
torch.testing.assert_close(int(cfn(torch.tensor(3))), 3)
```
```
Can't access members of type(obj) for a generated custom object. Please use __class__ instead
```
The error message is not really helpful here. The following happens: when compiling `dispatcher1`, `get_kernel` gets inlined. That means when hitting `dispatcher2`, the `type` call no longer happens on an input with a source. Thus, in the first iteration we hit the top branch, while in the second we hit the bottom:
addb8e29cd/torch/_dynamo/variables/builtin.py (L1264-L1268)
And the error message I posted above originates from the type being treated as constant. This PR replaces this with a `SourcelessBuilder` instead.
With that fix in place, we hit another pointing to `input_type.__mro__`
```
AssertionError: Consider SourcelessBuilder for ephemeral objects, usually objects created locally.
```
Fix is similar: instead of using a `VariableBuilder` here, we use a `SourcelessBuilder` in case we have no `source`:
addb8e29cd/torch/_dynamo/variables/builtin.py (L1167-L1168)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113340
Approved by: https://github.com/peterbell10, https://github.com/lezcano
This commit is contained in:
parent
115da02432
commit
d64bc8f0f8
|
|
@ -2851,7 +2851,8 @@ def forward(self, x):
|
|||
with self.assertRaisesRegex(RuntimeError, "Shape must be more than 4"):
|
||||
gm(torch.randn(3, 4, 5))
|
||||
|
||||
def test_access_class_method_from_user_class(self):
|
||||
@common_utils.parametrize("type_fn", [type, lambda obj: obj.__class__])
|
||||
def test_access_class_method_from_user_class(self, type_fn):
|
||||
class A:
|
||||
@classmethod
|
||||
def func(cls):
|
||||
|
|
@ -2859,19 +2860,11 @@ def forward(self, x):
|
|||
|
||||
def f(x):
|
||||
a = A()
|
||||
return x.sum() + type(a).func().sum()
|
||||
return x.sum() + type_fn(a).func().sum()
|
||||
|
||||
gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.ones(6, 4))
|
||||
self.assertEqual(f(torch.ones(6, 4)), gm(torch.ones(6, 4)))
|
||||
|
||||
def f_correct(x):
|
||||
a = A()
|
||||
return x.sum() + a.__class__.func().sum()
|
||||
|
||||
gm, _ = torch._dynamo.export(f_correct, aten_graph=True)(torch.ones(6, 4))
|
||||
|
||||
self.assertEqual(f_correct(torch.ones(6, 4)), gm(torch.ones(6, 4)))
|
||||
|
||||
def test_not_functionalize(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import torch
|
|||
from torch import sym_float, sym_int
|
||||
|
||||
from .. import config, polyfill, variables
|
||||
from ..allowed_functions import is_allowed
|
||||
from ..exc import (
|
||||
AttributeMutationError,
|
||||
unimplemented,
|
||||
|
|
@ -1177,12 +1176,10 @@ class BuiltinVariable(VariableTracker):
|
|||
return build_checkpoint_variable(**options)
|
||||
elif trace_rules.lookup(member) is not None:
|
||||
return trace_rules.lookup(member)(member, **options)
|
||||
elif is_allowed(member):
|
||||
return TorchVariable(member, **options)
|
||||
elif ConstantVariable.is_literal(member):
|
||||
return ConstantVariable.create(member, **options)
|
||||
else:
|
||||
elif source is not None:
|
||||
return VariableBuilder(tx, source)(member)
|
||||
else:
|
||||
return SourcelessBuilder()(tx, member)
|
||||
elif isinstance(obj, (PythonModuleVariable, DummyModule)):
|
||||
member = obj.value.__dict__[name]
|
||||
|
||||
|
|
@ -1278,24 +1275,17 @@ class BuiltinVariable(VariableTracker):
|
|||
|
||||
try:
|
||||
py_type = obj.python_type()
|
||||
except NotImplementedError:
|
||||
py_type = None
|
||||
except NotImplementedError as error:
|
||||
raise UserError(
|
||||
UserErrorType.INVALID_INPUT,
|
||||
str(error),
|
||||
case_name="unknown_python_type",
|
||||
) from None
|
||||
|
||||
if istype(obj, variables.TupleVariable):
|
||||
return BuiltinVariable(py_type)
|
||||
|
||||
if py_type is not None and obj.source:
|
||||
return VariableBuilder(tx, TypeSource(obj.source))(py_type)
|
||||
|
||||
if py_type is not None:
|
||||
if obj.source is None:
|
||||
return SourcelessBuilder()(tx, py_type)
|
||||
|
||||
raise UserError(
|
||||
UserErrorType.ANTI_PATTERN,
|
||||
f"Can't call type() on generated custom object {obj}. "
|
||||
"Please use __class__ instead",
|
||||
case_name="type_reflection_method",
|
||||
)
|
||||
else:
|
||||
return VariableBuilder(tx, TypeSource(obj.source))(py_type)
|
||||
|
||||
def call_reversed(self, tx, obj: VariableTracker):
|
||||
if obj.has_unpack_var_sequence(tx):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user