mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[subclasses] Do not fakeTensor const prop subclass args (#134855)
The issue: Const propagation checks only if arguments do not have FakeTensor. If argument is Subclass, it will pass this condition. As a result Const Propogation execution happens without FakeTensorMode and having tensor factories inside Subclass.__torch_dispatch__ results that this Tensor is not Fakified. Solution: If we have subclasses arguments, do not count that const propagation is doable Pull Request resolved: https://github.com/pytorch/pytorch/pull/134855 Approved by: https://github.com/zou3519
This commit is contained in:
parent
2a49296d75
commit
33ba952e31
|
|
@ -6010,6 +6010,18 @@ class TestAOTModuleSimplified(AOTTestCase):
|
|||
with self.assertRaisesRegex(AssertionError, "Unexpected fake"):
|
||||
aot_module_simplified(MockModule(), (fake_x,), nop)
|
||||
|
||||
def test_aot_test_subclasses_with_tensor_factories(self):
|
||||
from torch.testing._internal.common_subclass import SubclassWithTensorFactory
|
||||
|
||||
inp = SubclassWithTensorFactory(torch.zeros(3, 5))
|
||||
|
||||
def fn(x):
|
||||
return 2 * x
|
||||
|
||||
ref_out = fn(inp)
|
||||
out = torch.compile(fn, backend="aot_eager", fullgraph=True)(inp)
|
||||
self.assertEqual(ref_out, out)
|
||||
|
||||
|
||||
# entries in here don't work and need to be fixed.
|
||||
# Each one of these is a bug (or needs to be investigated)
|
||||
|
|
|
|||
|
|
@ -1719,6 +1719,7 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
has_symbolic_sizes = any(
|
||||
i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors
|
||||
) or any(isinstance(a, SymInt) for a in flat_args)
|
||||
has_subclasses = any(is_traceable_wrapper_subclass(a) for a in flat_args)
|
||||
|
||||
converter = self.fake_tensor_converter
|
||||
|
||||
|
|
@ -1736,6 +1737,7 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
should_allow_numbers_as_tensors(func)
|
||||
and not has_symbolic_sizes
|
||||
and not flat_arg_fake_tensors
|
||||
and not has_subclasses
|
||||
):
|
||||
assert all(
|
||||
t.constant is not None for t in flat_arg_fake_tensors
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@
|
|||
import torch
|
||||
from copy import deepcopy
|
||||
from torch.utils._pytree import tree_map
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
|
||||
# TODO: Move LoggingTensor here.
|
||||
from torch.testing._internal.logging_tensor import LoggingTensor
|
||||
|
|
@ -216,3 +218,49 @@ subclass_db = {
|
|||
closed_under_ops=False # sparse semantics
|
||||
),
|
||||
}
|
||||
|
||||
class SubclassWithTensorFactory(torch.Tensor):
|
||||
@staticmethod
|
||||
def __new__(cls, src):
|
||||
shape = src.shape
|
||||
kwargs = {}
|
||||
kwargs["strides"] = src.stride()
|
||||
kwargs["storage_offset"] = src.storage_offset()
|
||||
kwargs["device"] = src.device
|
||||
kwargs["layout"] = src.layout
|
||||
kwargs["requires_grad"] = src.requires_grad
|
||||
kwargs["dtype"] = src.dtype
|
||||
out = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)
|
||||
return out
|
||||
|
||||
def __init__(self, src):
|
||||
self.src = src
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}"
|
||||
|
||||
def __tensor_flatten__(self):
|
||||
return ["src"], None
|
||||
|
||||
@classmethod
|
||||
def __tensor_unflatten__(cls, inner_tensors, meta, outer_size, outer_stride):
|
||||
src = inner_tensors["src"]
|
||||
return cls(src)
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args, kwargs):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
||||
def _fn(x):
|
||||
return x.src * torch.ones(x.src.shape) if x.src.dtype == torch.float32 else x.src
|
||||
|
||||
_args = pytree.tree_map_only(cls, _fn, args)
|
||||
_kwargs = pytree.tree_map_only(cls, _fn, kwargs)
|
||||
|
||||
_out = func(*_args, **_kwargs)
|
||||
|
||||
_out_flat, _out_spec = pytree.tree_flatten(_out)
|
||||
|
||||
out_flat = [cls(o) if isinstance(o, torch.Tensor) else o for o in _out_flat]
|
||||
return pytree.tree_unflatten(out_flat, _out_spec)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user