[subclasses] Do not fakeTensor const prop subclass args (#134855)

The issue:

Const propagation only checks whether the arguments contain a FakeTensor. If an argument is a tensor subclass, it passes this check.

As a result, const propagation runs without FakeTensorMode, and any tensor factory called inside Subclass.__torch_dispatch__ produces a tensor that is never fakified.
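
To illustrate (a minimal standalone sketch, not part of this commit): a factory such as torch.ones only returns a FakeTensor while FakeTensorMode is active, so if the subclass's __torch_dispatch__ runs outside the mode, the factory output stays a real tensor.

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

mode = FakeTensorMode()
with mode:
    fakified = torch.ones(3, 5)  # factory runs under the mode -> FakeTensor
real = torch.ones(3, 5)          # factory runs outside the mode -> plain tensor

assert isinstance(fakified, FakeTensor)
assert not isinstance(real, FakeTensor)  # the "not fakified" tensor from the issue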

Solution:

If any argument is a tensor subclass, do not treat const propagation as doable.
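
Roughly (a simplified sketch of the changed condition, not the actual FakeTensorMode code), the eligibility check gains one more guard:

import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

def const_prop_eligible(flat_args):
    # Before: only "no fake tensors among the args" was required.
    no_fake_tensors = not any(isinstance(a, FakeTensor) for a in flat_args)
    # After: traceable wrapper subclasses also disqualify const propagation,
    # so their __torch_dispatch__ never runs outside FakeTensorMode.
    no_subclasses = not any(is_traceable_wrapper_subclass(a) for a in flat_args)
    return no_fake_tensors and no_subclasses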

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134855
Approved by: https://github.com/zou3519
IvanKobzarev 2024-08-30 07:49:06 -07:00 committed by PyTorch MergeBot
parent 2a49296d75
commit 33ba952e31
3 changed files with 62 additions and 0 deletions


@@ -6010,6 +6010,18 @@ class TestAOTModuleSimplified(AOTTestCase):
        with self.assertRaisesRegex(AssertionError, "Unexpected fake"):
            aot_module_simplified(MockModule(), (fake_x,), nop)

    def test_aot_test_subclasses_with_tensor_factories(self):
        from torch.testing._internal.common_subclass import SubclassWithTensorFactory

        inp = SubclassWithTensorFactory(torch.zeros(3, 5))

        def fn(x):
            return 2 * x

        ref_out = fn(inp)
        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(inp)
        self.assertEqual(ref_out, out)


# entries in here don't work and need to be fixed.
# Each one of these is a bug (or needs to be investigated)


@@ -1719,6 +1719,7 @@ class FakeTensorMode(TorchDispatchMode):
        has_symbolic_sizes = any(
            i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors
        ) or any(isinstance(a, SymInt) for a in flat_args)
        has_subclasses = any(is_traceable_wrapper_subclass(a) for a in flat_args)

        converter = self.fake_tensor_converter
@@ -1736,6 +1737,7 @@ class FakeTensorMode(TorchDispatchMode):
            should_allow_numbers_as_tensors(func)
            and not has_symbolic_sizes
            and not flat_arg_fake_tensors
            and not has_subclasses
        ):
            assert all(
                t.constant is not None for t in flat_arg_fake_tensors


@@ -3,6 +3,8 @@
import torch
from copy import deepcopy
from torch.utils._pytree import tree_map

import torch.utils._pytree as pytree

# TODO: Move LoggingTensor here.
from torch.testing._internal.logging_tensor import LoggingTensor
@@ -216,3 +218,49 @@ subclass_db = {
        closed_under_ops=False  # sparse semantics
    ),
}


class SubclassWithTensorFactory(torch.Tensor):
    @staticmethod
    def __new__(cls, src):
        shape = src.shape
        kwargs = {}
        kwargs["strides"] = src.stride()
        kwargs["storage_offset"] = src.storage_offset()
        kwargs["device"] = src.device
        kwargs["layout"] = src.layout
        kwargs["requires_grad"] = src.requires_grad
        kwargs["dtype"] = src.dtype
        out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
        return out

    def __init__(self, src):
        self.src = src

    def __repr__(self):
        return f"{self.__class__.__name__}"

    def __tensor_flatten__(self):
        return ["src"], None

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, meta, outer_size, outer_stride):
        src = inner_tensors["src"]
        return cls(src)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        if kwargs is None:
            kwargs = {}

        def _fn(x):
            return x.src * torch.ones(x.src.shape) if x.src.dtype == torch.float32 else x.src

        _args = pytree.tree_map_only(cls, _fn, args)
        _kwargs = pytree.tree_map_only(cls, _fn, kwargs)

        _out = func(*_args, **_kwargs)

        _out_flat, _out_spec = pytree.tree_flatten(_out)
        out_flat = [cls(o) if isinstance(o, torch.Tensor) else o for o in _out_flat]
        return pytree.tree_unflatten(out_flat, _out_spec)