[aotd] Fix freezing API for subclasses (#136265)

Original issue:
https://github.com/pytorch/ao/issues/890

The problem:

TracingContext.params_flat holds the original parameters, with tensor subclasses still wrapped (not desugared), while the inductor freezing API operates on AOT graphs, in which subclasses have already been desugared into their inner tensors.

params_flat is used only for this logic, so additionally storing the desugared (unwrapped) subclass tensors on the TracingContext, together with a mapping back to the original flat-parameter indices, fixes the issue.
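
For illustration (not part of the commit), a minimal sketch of the mismatch, assuming the TwoTensor test subclass that the new test uses:
```python
# Minimal sketch of the mismatch, assuming the TwoTensor test subclass.
import torch
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

w = TwoTensor(torch.randn(3, 4), torch.randn(3, 4))

# params_flat sees the wrapped parameter: a single TwoTensor.
params_flat = [w]

# The AOT graph instead sees the desugared form: the inner dense tensors.
assert is_traceable_wrapper_subclass(w)
attrs, _ctx = w.__tensor_flatten__()            # e.g. ["a", "b"]
aot_graph_params = [getattr(w, attr) for attr in attrs]

print(len(params_flat), len(aot_graph_params))  # 1 vs. 2 -> indices no longer line up
```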

Testing:
```
python test/functorch/test_aotdispatch.py -k test_inductor_freezing_with_subclasses
```
Original Torch AO failure:
```
python test/integration/test_integration.py -k test_int8_weight_only_quant_with_freeze
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136265
Approved by: https://github.com/bdhirsh
IvanKobzarev 2024-09-23 14:03:05 -07:00 committed by PyTorch MergeBot
parent f048569c24
commit 342c031f0e
6 changed files with 82 additions and 6 deletions

View File

@@ -5969,6 +5969,26 @@ class TestAOTModuleSimplified(AOTTestCase):
        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(inp)
        self.assertEqual(ref_out, out)

    @torch._inductor.config.patch({"freezing": True})
    def test_inductor_freezing_with_subclasses(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.w = TwoTensor(torch.randn(3, 4), torch.randn(3, 4))

            def forward(self, x):
                return (
                    x.index_select(
                        dim=0, index=torch.tensor([0, 2, 1], dtype=torch.int64)
                    )
                    + self.w
                )

        m = M()
        inp = torch.randn(3, 4)
        with torch.no_grad():
            torch.compile(m, fullgraph=True)(inp)

# entries in here don't work and need to be fixed.
# Each one of these is a bug (or needs to be investigated)

View File

@@ -136,6 +136,18 @@ def unwrap_tensor_subclasses(wrapped_args, *, is_joint_structure: bool):
    return unwrapped_args


def unwrap_tensor_subclasses_with_indices_to_original(wrapped_args):
    ret_unwrapped = []
    ret_indices_to_original = []
    for i, a in enumerate(wrapped_args):
        a_unwrapped = unwrap_tensor_subclasses([a], is_joint_structure=False)
        ret_unwrapped.extend(a_unwrapped)
        n = len(a_unwrapped)
        ret_indices_to_original.extend([i] * n)

    return ret_unwrapped, ret_indices_to_original


def remap_unwrapped_subclass_arg_indices(wrapped_args, static_input_indices):
    static_input_indices = set(static_input_indices)
    new_ind = 0
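
As an illustration (not part of the diff), this is roughly what the new helper returns for a mix of a subclass parameter and a plain tensor, assuming the TwoTensor test subclass and the subclass_utils import path used elsewhere in this commit:
```python
# Illustration only; assumes the TwoTensor test subclass and the
# torch._functorch._aot_autograd.subclass_utils path shown in this commit.
import torch
from torch.testing._internal.two_tensor import TwoTensor
from torch._functorch._aot_autograd.subclass_utils import (
    unwrap_tensor_subclasses_with_indices_to_original,
)

a, b, c = torch.randn(2), torch.randn(2), torch.randn(2)
params = [TwoTensor(a, b), c]

unwrapped, idx_to_original = unwrap_tensor_subclasses_with_indices_to_original(params)
# The TwoTensor desugars into its two inner tensors; the plain tensor maps to itself:
#   unwrapped       ~ [a, b, c]
#   idx_to_original ~ [0, 0, 1]  (original flat-param index of each unwrapped tensor)
print(len(unwrapped), idx_to_original)
```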

View File

@@ -22,6 +22,7 @@ from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils.weak import TensorWeakRef
static_inputs_log = torch._logging.getArtifactLogger(
@@ -99,6 +100,7 @@ from ._aot_autograd.subclass_utils import ( # noqa: F401
    create_metadata_for_subclass,
    requires_subclass_dispatch,
    unwrap_tensor_subclasses,
    unwrap_tensor_subclasses_with_indices_to_original,
    wrap_tensor_subclasses,
    wrap_tensor_subclasses_maybe_joint,
)
@@ -978,6 +980,13 @@ def aot_module_simplified(
    if tracing_context := torch._guards.TracingContext.try_get():
        tracing_context.params_flat = params_flat

        (
            params_flat_unwrap_subclasses,
            tracing_context.params_unwrapped_to_flat_index,
        ) = unwrap_tensor_subclasses_with_indices_to_original(params_flat)
        tracing_context.params_flat_unwrap_subclasses = [
            TensorWeakRef(p) for p in params_flat_unwrap_subclasses
        ]

    aot_autograd_arg_pos_to_source = None
    # Then, the params 1:1 mapped sources, if relevant.

View File

@@ -643,6 +643,8 @@ class TracingContext:
        # this is only set after aot_autograd
        self.aot_graph_name = None
        self.params_flat = None
        self.params_flat_unwrap_subclasses = None
        self.params_unwrapped_to_flat_index = None
        # this is for extended return calling convention from backend
        # compiler to aot_autograd
        # Per output, what the compiler specified stride of the output is,

View File

@@ -33,6 +33,7 @@ from torch._dynamo.utils import (
    lazy_format_graph_code,
)
from torch._functorch import config as functorch_config
from torch._functorch._aot_autograd.subclass_utils import unwrap_tensor_subclasses
from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
from torch._inductor.codecache import (
    _StrideExprStr,
@@ -1186,14 +1187,40 @@ def fw_compiler_freezing(
    )

    static_input_idxs = list(range(num_fixed))
    wrapper_new_args_unwrapped_indices: List[int] = []
    # constant params will be real tensors, not fake
    tracing_context = torch._guards.TracingContext.try_get()
    unwrapped_args_offsets = [0]
    max_offset_idx = 0
    if tracing_context is not None:
-       params_flat = tracing_context.params_flat
-       assert params_flat is not None
-       for i in range(len(params_flat)):
        assert tracing_context.params_flat_unwrap_subclasses is not None
        params_flat_unwrap = [
            r() for r in tracing_context.params_flat_unwrap_subclasses
        ]
        assert params_flat_unwrap is not None
        max_offset_idx = max(0, len(params_flat_unwrap) - 1)
        preserved_indices_params_flat = set()
        unwrapped_idxs = tracing_context.params_unwrapped_to_flat_index
        assert unwrapped_idxs is not None
        current_offset = 0
        if len(params_flat_unwrap) > 0:
            unwrapped_args_offsets = []

        for i in range(len(params_flat_unwrap)):
            if i not in preserved_arg_indices:
-               params_flat[i] = None
                params_flat_unwrap[i] = None
                if i > 0 and unwrapped_idxs[i] == unwrapped_idxs[i - 1]:
                    current_offset += 1
            else:
                preserved_indices_params_flat.add(unwrapped_idxs[i])
            unwrapped_args_offsets.append(current_offset)

        # Deallocate wrapped params, if all subelements were deallocated
        assert tracing_context.params_flat is not None
        for i in range(len(tracing_context.params_flat)):
            if i not in preserved_indices_params_flat:
                tracing_context.params_flat[i] = None

        if tracing_context.fw_metadata:
            static_input_idxs += tracing_context.fw_metadata.static_input_indices
@@ -1217,7 +1244,12 @@ def fw_compiler_freezing(
        return optimized_function

    def wrapper(args):
-       args_new = [args[i] for i in preserved_arg_indices]
        args_unwrapped = unwrap_tensor_subclasses(args, is_joint_structure=False)
        args_new = [
            args_unwrapped[i - unwrapped_args_offsets[min(i, max_offset_idx)]]
            for i in preserved_arg_indices
        ]
        args_unwrapped.clear()
        args.clear()
        return optimized_function(args_new)

View File

@@ -90,7 +90,8 @@ def freeze(
    if tracing_context := torch._guards.TracingContext.try_get():
        fw_metadata = tracing_context.fw_metadata
-       params_flat = tracing_context.params_flat
        assert tracing_context.params_flat_unwrap_subclasses is not None
        params_flat = [r() for r in tracing_context.params_flat_unwrap_subclasses]

        assert fw_metadata is not None and params_flat is not None

        preserved_arg_indices = replace_params_with_constants(