[aotd] Fix freezing API for subclasses (#136265)
Original issue: https://github.com/pytorch/ao/issues/890

The problem: `TracingContext.params_flat` holds the original parameters, with tensor subclasses not yet desugared, while the inductor freezing API operates on AOT graphs, where subclasses have already been desugared into their inner tensors. `params_flat` is used only for this logic, so storing the desugared (unwrapped) subclass tensors there fixes the issue.

Testing:
```
python test/functorch/test_aotdispatch.py -k test_inductor_freezing_with_subclasses
```

Original Torch AO failure:
```
python test/integration/test_integration.py -k test_int8_weight_only_quant_with_freeze
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136265
Approved by: https://github.com/bdhirsh
commit 342c031f0e (parent f048569c24)
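Background for the fix: a traceable tensor subclass parameter desugars into its inner dense tensors when AOTAutograd traces the module, so one wrapped param can become several graph placeholders. A minimal sketch of that one-to-many blow-up, assuming the `TwoTensor` test helper from `torch.testing._internal.two_tensor` (the same subclass the new test uses):

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._python_dispatch import is_traceable_wrapper_subclass

# One "parameter" that is a wrapper subclass holding two dense tensors.
w = TwoTensor(torch.randn(3, 4), torch.randn(3, 4))
assert is_traceable_wrapper_subclass(w)

# Desugaring flattens the subclass into its inner tensors, so a single
# wrapped param corresponds to two placeholders in the AOT graph.
attrs, _ctx = w.__tensor_flatten__()
inner = [getattr(w, a) for a in attrs]
print(len(inner))  # 2
```

Because freezing indexes parameters by their position in the already-desugared AOT graph, keeping the wrapped params in `TracingContext.params_flat` leaves those indices misaligned, which is what this change corrects.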
test/functorch/test_aotdispatch.py

```diff
@@ -5969,6 +5969,26 @@ class TestAOTModuleSimplified(AOTTestCase):
         out = torch.compile(fn, backend="aot_eager", fullgraph=True)(inp)
         self.assertEqual(ref_out, out)

+    @torch._inductor.config.patch({"freezing": True})
+    def test_inductor_freezing_with_subclasses(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w = TwoTensor(torch.randn(3, 4), torch.randn(3, 4))
+
+            def forward(self, x):
+                return (
+                    x.index_select(
+                        dim=0, index=torch.tensor([0, 2, 1], dtype=torch.int64)
+                    )
+                    + self.w
+                )
+
+        m = M()
+        inp = torch.randn(3, 4)
+        with torch.no_grad():
+            torch.compile(m, fullgraph=True)(inp)
+

 # entries in here don't work and need to be fixed.
 # Each one of these is a bug (or needs to be investigated)
```
torch/_functorch/_aot_autograd/subclass_utils.py

```diff
@@ -136,6 +136,18 @@ def unwrap_tensor_subclasses(wrapped_args, *, is_joint_structure: bool):
     return unwrapped_args


+def unwrap_tensor_subclasses_with_indices_to_original(wrapped_args):
+    ret_unwrapped = []
+    ret_indices_to_original = []
+    for i, a in enumerate(wrapped_args):
+        a_unwrapped = unwrap_tensor_subclasses([a], is_joint_structure=False)
+        ret_unwrapped.extend(a_unwrapped)
+        n = len(a_unwrapped)
+        ret_indices_to_original.extend([i] * n)
+
+    return ret_unwrapped, ret_indices_to_original
+
+
 def remap_unwrapped_subclass_arg_indices(wrapped_args, static_input_indices):
     static_input_indices = set(static_input_indices)
     new_ind = 0
```
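As a usage sketch of the new helper (hypothetical values; the expected output assumes a `TwoTensor` unwraps into exactly two inner tensors): the second return value records, for every unwrapped tensor, the index of the original wrapped parameter it came from.

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor
# After this change, the helper lives next to unwrap_tensor_subclasses:
from torch._functorch._aot_autograd.subclass_utils import (
    unwrap_tensor_subclasses_with_indices_to_original,
)

params = [
    torch.randn(2),                                   # 0: plain tensor
    TwoTensor(torch.randn(3, 4), torch.randn(3, 4)),  # 1: subclass -> two inner tensors
    torch.randn(5),                                    # 2: plain tensor
]
unwrapped, idx_to_original = unwrap_tensor_subclasses_with_indices_to_original(params)
print(len(unwrapped))    # expected: 4
print(idx_to_original)   # expected: [0, 1, 1, 2]
```

`fw_compiler_freezing` uses this index list (stored as `params_unwrapped_to_flat_index`) to decide which original wrapped param can be freed once all of its unwrapped pieces have been folded into constants.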
torch/_functorch/aot_autograd.py

```diff
@@ -22,6 +22,7 @@ from torch._subclasses import FakeTensor, FakeTensorMode
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+from torch.utils.weak import TensorWeakRef


 static_inputs_log = torch._logging.getArtifactLogger(
@@ -99,6 +100,7 @@ from ._aot_autograd.subclass_utils import ( # noqa: F401
     create_metadata_for_subclass,
     requires_subclass_dispatch,
     unwrap_tensor_subclasses,
+    unwrap_tensor_subclasses_with_indices_to_original,
     wrap_tensor_subclasses,
     wrap_tensor_subclasses_maybe_joint,
 )
@@ -978,6 +980,13 @@ def aot_module_simplified(

     if tracing_context := torch._guards.TracingContext.try_get():
         tracing_context.params_flat = params_flat
+        (
+            params_flat_unwrap_subclasses,
+            tracing_context.params_unwrapped_to_flat_index,
+        ) = unwrap_tensor_subclasses_with_indices_to_original(params_flat)
+        tracing_context.params_flat_unwrap_subclasses = [
+            TensorWeakRef(p) for p in params_flat_unwrap_subclasses
+        ]

     aot_autograd_arg_pos_to_source = None
     # Then, the params 1:1 mapped sources, if relevant.
```
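Note on the weak references above: `params_flat_unwrap_subclasses` stores `TensorWeakRef`s rather than the tensors themselves, so holding the unwrapped view does not keep parameters alive once freezing frees them; dereferencing with `r()` returns the tensor, or `None` after it has been collected. A minimal sketch of that behavior (variable names are illustrative):

```python
import torch
from torch.utils.weak import TensorWeakRef

p = torch.randn(3, 4)
ref = TensorWeakRef(p)   # does not keep `p` alive
assert ref() is p        # calling the ref gives back the tensor

del p                    # once the last strong reference is gone ...
assert ref() is None     # ... the weak ref dereferences to None
```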
torch/_guards.py

```diff
@@ -643,6 +643,8 @@ class TracingContext:
         # this is only set after aot_autograd
         self.aot_graph_name = None
         self.params_flat = None
+        self.params_flat_unwrap_subclasses = None
+        self.params_unwrapped_to_flat_index = None
         # this is for extended return calling convention from backend
         # compiler to aot_autograd
         # Per output, what the compiler specified stride of the output is,
```
torch/_inductor/compile_fx.py

```diff
@@ -33,6 +33,7 @@ from torch._dynamo.utils import (
     lazy_format_graph_code,
 )
 from torch._functorch import config as functorch_config
+from torch._functorch._aot_autograd.subclass_utils import unwrap_tensor_subclasses
 from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
 from torch._inductor.codecache import (
     _StrideExprStr,
@@ -1186,14 +1187,40 @@ def fw_compiler_freezing(
     )

     static_input_idxs = list(range(num_fixed))
+    wrapper_new_args_unwrapped_indices: List[int] = []
     # constant params will be real tensors, not fake
     tracing_context = torch._guards.TracingContext.try_get()
+    unwrapped_args_offsets = [0]
+    max_offset_idx = 0
     if tracing_context is not None:
-        params_flat = tracing_context.params_flat
-        assert params_flat is not None
-        for i in range(len(params_flat)):
+        assert tracing_context.params_flat_unwrap_subclasses is not None
+        params_flat_unwrap = [
+            r() for r in tracing_context.params_flat_unwrap_subclasses
+        ]
+        assert params_flat_unwrap is not None
+        max_offset_idx = max(0, len(params_flat_unwrap) - 1)
+        assert params_flat_unwrap is not None
+        preserved_indices_params_flat = set()
+        unwrapped_idxs = tracing_context.params_unwrapped_to_flat_index
+        assert unwrapped_idxs is not None
+        current_offset = 0
+        if len(params_flat_unwrap) > 0:
+            unwrapped_args_offsets = []
+
+        for i in range(len(params_flat_unwrap)):
             if i not in preserved_arg_indices:
-                params_flat[i] = None
+                params_flat_unwrap[i] = None
+                if i > 0 and unwrapped_idxs[i] == unwrapped_idxs[i - 1]:
+                    current_offset += 1
+            else:
+                preserved_indices_params_flat.add(unwrapped_idxs[i])
+            unwrapped_args_offsets.append(current_offset)
+
+        # Deallocate wrapped params, if all subelements were deallocated
+        assert tracing_context.params_flat is not None
+        for i in range(len(tracing_context.params_flat)):
+            if i not in preserved_indices_params_flat:
+                tracing_context.params_flat[i] = None

         if tracing_context.fw_metadata:
             static_input_idxs += tracing_context.fw_metadata.static_input_indices
@@ -1217,7 +1244,12 @@
         return optimized_function

     def wrapper(args):
-        args_new = [args[i] for i in preserved_arg_indices]
+        args_unwrapped = unwrap_tensor_subclasses(args, is_joint_structure=False)
+        args_new = [
+            args_unwrapped[i - unwrapped_args_offsets[min(i, max_offset_idx)]]
+            for i in preserved_arg_indices
+        ]
+        args_unwrapped.clear()
         args.clear()
         return optimized_function(args_new)

```
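A toy re-run of the offset bookkeeping above (plain Python with made-up values, mirroring the loop in `fw_compiler_freezing`; an illustrative sketch, not a claim about the exact runtime calling convention):

```python
# Three flat params, where param 1 is a subclass that desugars into two
# inner tensors, and only unwrapped positions 0 and 3 survive freezing.
unwrapped_idxs = [0, 1, 1, 2]        # params_unwrapped_to_flat_index
preserved_arg_indices = {0, 3}       # unwrapped positions kept after freezing

unwrapped_args_offsets = []
preserved_indices_params_flat = set()
current_offset = 0
for i in range(len(unwrapped_idxs)):
    if i not in preserved_arg_indices:
        # a dropped entry that shares its original param with its
        # predecessor shifts every later preserved index left by one
        if i > 0 and unwrapped_idxs[i] == unwrapped_idxs[i - 1]:
            current_offset += 1
    else:
        preserved_indices_params_flat.add(unwrapped_idxs[i])
    unwrapped_args_offsets.append(current_offset)

print(unwrapped_args_offsets)         # [0, 0, 1, 1]
print(preserved_indices_params_flat)  # {0, 2} -> original params 0 and 2 stay alive

# The wrapper then reads args_unwrapped[i - unwrapped_args_offsets[min(i, 3)]]
# for each preserved index i, i.e. positions 0 and 2 in this toy case.
print([i - unwrapped_args_offsets[min(i, len(unwrapped_idxs) - 1)]
       for i in sorted(preserved_arg_indices)])  # [0, 2]
```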
torch/_inductor/freezing.py

```diff
@@ -90,7 +90,8 @@ def freeze(

     if tracing_context := torch._guards.TracingContext.try_get():
         fw_metadata = tracing_context.fw_metadata
-        params_flat = tracing_context.params_flat
+        assert tracing_context.params_flat_unwrap_subclasses is not None
+        params_flat = [r() for r in tracing_context.params_flat_unwrap_subclasses]
         assert fw_metadata is not None and params_flat is not None

         preserved_arg_indices = replace_params_with_constants(
```