Export should use aot_export_joint_with_descriptors (#165931)

This diff moves export's run_decompositions to use aot_export_joint_with_descriptors instead of aot_export_module. In doing so, I ran into two main bugs:
1) aot_export_joint_with_descriptors doesn't correctly pass the record_nn_module_stack flag, which is needed to populate nn_module_stack by switching the internal tracer.
2) When creating a SymInt for a negative input, we need to pass positive=False. This didn't matter before because aot_autograd returned integer inputs directly instead of creating SymInts (see the sketch below).
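
For (2), a minimal sketch of the SymInt issue. The repro value and variable names are illustrative, not taken from this PR; the APIs are the same ones the process_inputs change below uses:

    from torch._dynamo.source import ConstantSource
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    shape_env = ShapeEnv()
    x = -3  # hypothetical negative integer input to the exported program
    source = ConstantSource("sym_0")

    # By default create_symbol assumes the symbol is positive, which contradicts
    # a negative hint; positive=(x >= 0) lets the backing symbol take negative values.
    symbol = shape_env.create_symbol(x, source, positive=x >= 0)
    symint = shape_env.create_symintnode(symbol, hint=x, source=source)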

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165931
Approved by: https://github.com/zhxchen17
Tugsbayasgalan Manlaibaatar, 2025-10-27 08:29:19 -07:00 (committed by PyTorch MergeBot)
parent f6951cb8ea
commit 6096c0fc74
7 changed files with 167 additions and 56 deletions

View File

@@ -13910,6 +13910,7 @@ def forward(self, x, b_t, y):
inps = (torch.ones(5),)
ep = torch.export.export(M(), inps).run_decompositions({})
if IS_FBCODE:
self.assertExpectedInline(
str(ep.graph_module.code.strip()),
"""\
@@ -13918,6 +13919,17 @@ def forward(self, x):
auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = x, z = cos); x = cos = None
getitem_3 = auto_functionalized[3]; auto_functionalized = None
cos_1 = torch.ops.aten.cos.default(getitem_3)
return (getitem_3, getitem_3, cos_1)""",
)
else:
self.assertExpectedInline(
str(ep.graph_module.code.strip()),
"""\
def forward(self, x):
cos = torch.ops.aten.cos.default(x)
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.testlib.foo.default, _x_base_index = 0, _z_base_index = 1, _all_bases = [x, cos]); x = cos = None
getitem_3 = auto_functionalized_v2[3]; auto_functionalized_v2 = None
cos_1 = torch.ops.aten.cos.default(getitem_3)
return (getitem_3, getitem_3, cos_1)""",
)
@@ -13932,6 +13944,7 @@ def forward(self, x):
inps = (torch.ones(5),)
ep = torch.export.export(M(), inps).run_decompositions()
if IS_FBCODE:
self.assertExpectedInline(
str(ep.graph_module.code.strip()),
"""\
@@ -13941,6 +13954,18 @@ def forward(self, x):
auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = cos, z = cos_1); cos = cos_1 = None
getitem_3 = auto_functionalized[3]; auto_functionalized = None
cos_2 = torch.ops.aten.cos.default(getitem_3); getitem_3 = None
return (cos_2,)""",
)
else:
self.assertExpectedInline(
str(ep.graph_module.code.strip()),
"""\
def forward(self, x):
cos = torch.ops.aten.cos.default(x)
cos_1 = torch.ops.aten.cos.default(x); x = None
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.testlib.foo.default, _x_base_index = 0, _z_base_index = 1, _all_bases = [cos, cos_1]); cos = cos_1 = None
getitem_3 = auto_functionalized_v2[3]; auto_functionalized_v2 = None
cos_2 = torch.ops.aten.cos.default(getitem_3); getitem_3 = None
return (cos_2,)""",
)
@@ -15338,6 +15363,7 @@ graph():
decomp_table,
)
if IS_FBCODE:
self.assertExpectedInline(
str(ep.graph_module.code).strip(),
"""\
@@ -15347,6 +15373,18 @@ def forward(self, x):
auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = foo_functional, z = cos); foo_functional = cos = None
getitem_3 = auto_functionalized[3]; auto_functionalized = None
cos_1 = torch.ops.aten.cos.default(getitem_3)
return (getitem_3, cos_1)""",
)
else:
self.assertExpectedInline(
str(ep.graph_module.code).strip(),
"""\
def forward(self, x):
foo_functional = torch.ops.testlib.foo_functional.default(x); x = None
cos = torch.ops.aten.cos.default(foo_functional)
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.testlib.foo.default, _x_base_index = 0, _z_base_index = 1, _all_bases = [foo_functional, cos]); foo_functional = cos = None
getitem_3 = auto_functionalized_v2[3]; auto_functionalized_v2 = None
cos_1 = torch.ops.aten.cos.default(getitem_3)
return (getitem_3, cos_1)""",
)

View File

@@ -10,6 +10,7 @@ behavior, including:
import sys
from typing import Any, TYPE_CHECKING
from torch._environment import is_fbcode
from torch.utils._config_module import install_config_module
@@ -27,6 +28,11 @@ detect_non_strict_fake_tensor_leaks = False
# that we don't know how to proxy, resulting in untracked fake tensors
error_on_lifted_constant_tensors = True
# enable auto_functionalized_v2 in export
# We turn this off in fbcode due to downstream users not
# being ready to handle auto_functionalized_v2.
enable_auto_functionalized_v2_for_export = not is_fbcode()
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403
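
As an illustrative usage note (not part of this diff): callers outside fbcode that still need the V1 behavior during export can flip the flag explicitly, since this is a regular config module:

    import torch._export.config as export_config

    # Fall back to the V1 auto_functionalized HOP during export,
    # matching the fbcode default set above.
    export_config.enable_auto_functionalized_v2_for_export = False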

View File

@@ -166,9 +166,6 @@ def run_functionalized_fw_and_collect_metadata(
# Note: this is guaranteed to be set when running under dynamo
static_input_indices: Optional[list[int]] = None,
pre_dispatch: bool = False,
# is_export is technically only needed to avoid using functionalization V2
# during analysis
is_export: bool = False,
) -> Callable[..., ViewAndMutationMeta]:
memo: dict[Tensor, Tensor] = {}
@@ -200,7 +197,7 @@ def run_functionalized_fw_and_collect_metadata(
# It doesn't matter if we run this under predispatch or not because it is
# only for figuring out metadata
mode = FunctionalTensorMode(_allow_token_discovery=True, export=is_export)
mode = FunctionalTensorMode(_allow_token_discovery=True)
suppress_pending = contextlib.nullcontext()
fake_mode = detect_fake_mode()
if fake_mode and (shape_env := fake_mode.shape_env):

View File

@@ -41,7 +41,7 @@ def process_inputs(
return x
source = ConstantSource(f"sym_{idx}")
return shape_env.create_symintnode(
shape_env.create_symbol(x, source),
shape_env.create_symbol(x, source, positive=x >= 0),
hint=x,
source=source,
)

View File

@@ -573,7 +573,6 @@ def create_aot_state(
keep_input_mutations=aot_config.keep_inference_input_mutations,
is_train=needs_autograd,
pre_dispatch=aot_config.pre_dispatch,
is_export=aot_config.is_export,
)(*_dup_fake_script_obj(fake_flat_args))
req_subclass_dispatch = requires_subclass_dispatch(
@@ -905,6 +904,7 @@ def prepare_aot_module_simplified(
*,
force_non_lazy_backward_lowering: bool = False,
disable_functionalization: bool = False,
_record_nn_module_stack: bool = False,
):
if not flatten:
assert kwargs is None
@@ -931,7 +931,13 @@ def prepare_aot_module_simplified(
# NB: This doesn't change the in/out convention, except adding the
# parameters as explicit arguments
functional_call = create_functional_call(
mod, params_buffers_spec, params_len + buffers_len, strict_out_tuple=not flatten
mod,
params_buffers_spec,
params_len + buffers_len,
strict_out_tuple=not flatten,
# We need this for export to run ModuleStackTracer
# instead of PythonKeyTracer
store_orig_mod=_record_nn_module_stack,
)
full_args = [*params_flat, *buffers_flat, *args]
@@ -1175,6 +1181,7 @@ def aot_export_joint_with_descriptors(
keep_inference_input_mutations=False,
ignore_shape_env=False,
disable_functionalization=False,
_record_nn_module_stack=False,
) -> JointWithDescriptors:
"""
This API captures the joint graph for an nn.Module. However, unlike
@@ -1265,6 +1272,7 @@ def aot_export_joint_with_descriptors(
# context.
force_non_lazy_backward_lowering=True,
disable_functionalization=disable_functionalization,
_record_nn_module_stack=_record_nn_module_stack,
)
# TODO: Maybe this should be in create_aot_state? Not sure, that would

View File

@@ -145,7 +145,7 @@ class FunctionalTensor(torch.Tensor):
out.elem = elem
if (
not mode.export
torch._export.config.enable_auto_functionalized_v2_for_export
and torch.is_inference_mode_enabled()
and torch._inductor.config.enable_auto_functionalized_v2
):
@@ -449,13 +449,19 @@ class FunctionalTensorMode(TorchDispatchMode):
) and not torch._C._dispatch_has_kernel_for_dispatch_key(
func.name(), torch._C.DispatchKey.Functionalize
):
import torch._export.config as export_config
import torch._inductor.config as inductor_config
if self.export or not inductor_config.enable_auto_functionalized_v2:
return do_auto_functionalize(self, func, args, kwargs)
else:
if torch.compiler.is_exporting():
if export_config.enable_auto_functionalized_v2_for_export:
return do_auto_functionalize_v2(self, func, args, kwargs)
return do_auto_functionalize(self, func, args, kwargs)
if inductor_config.enable_auto_functionalized_v2:
return do_auto_functionalize_v2(self, func, args, kwargs)
return do_auto_functionalize(self, func, args, kwargs)
from torch._higher_order_ops.effects import handle_effects, has_effects
if has_effects(func, args, kwargs):
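
Because the removed and added lines are interleaved above, here is the post-change selection logic paraphrased as a standalone sketch; the helper name is hypothetical and the do_auto_functionalize imports are assumed to come from torch._higher_order_ops.auto_functionalize (the real code is inline in FunctionalTensorMode):

    import torch
    import torch._export.config as export_config
    import torch._inductor.config as inductor_config
    from torch._higher_order_ops.auto_functionalize import (
        do_auto_functionalize,
        do_auto_functionalize_v2,
    )

    def _pick_auto_functionalize(mode, func, args, kwargs):
        # Export path: gated by the new export-only config (defaults to off in fbcode).
        if torch.compiler.is_exporting():
            if export_config.enable_auto_functionalized_v2_for_export:
                return do_auto_functionalize_v2(mode, func, args, kwargs)
            return do_auto_functionalize(mode, func, args, kwargs)
        # Compile path: unchanged, gated by the inductor config.
        if inductor_config.enable_auto_functionalized_v2:
            return do_auto_functionalize_v2(mode, func, args, kwargs)
        return do_auto_functionalize(mode, func, args, kwargs)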

View File

@@ -9,7 +9,7 @@ import sys
import time
import warnings
from collections.abc import Callable
from contextlib import contextmanager, nullcontext
from contextlib import contextmanager, ExitStack, nullcontext
from itertools import chain
from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union
@@ -67,7 +67,7 @@ from torch._functorch._aot_autograd.utils import (
)
from torch._functorch.aot_autograd import (
_detect_attribute_assignment,
aot_export_module,
aot_export_joint_with_descriptors,
)
from torch._guards import detect_fake_mode, tracing, TracingContext
from torch._library.fake_class_registry import FakeScriptObject
@@ -855,6 +855,54 @@ def _export_to_torch_ir(
return gm_torch_level
def _aot_export_joint_with_descriptors(
stack,
mod,
args,
*,
kwargs,
decompositions,
fake_params_buffers,
_record_nn_module_stack=True,
):
from torch._functorch._aot_autograd.graph_compile import aot_stage2_export
from torch._functorch._aot_autograd.input_output_analysis import (
create_graph_signature,
)
joint_with_descriptors = aot_export_joint_with_descriptors(
stack,
mod,
args,
kwargs=kwargs,
decompositions=decompositions,
_record_nn_module_stack=_record_nn_module_stack,
)
# Convert JointWithDescriptors to graph module and ViewAndMutationMeta
gm, fw_metadata = aot_stage2_export(
joint_with_descriptors._aot_state,
joint_with_descriptors._aot_graph_capture,
)
assert isinstance(gm, torch.fx.GraphModule)
# Create GraphSignature from the metadata
graph_signature = create_graph_signature(
gm,
fw_metadata,
joint_with_descriptors.in_spec,
joint_with_descriptors.out_spec,
user_args_flat=pytree.tree_leaves((args, kwargs)),
params_and_buffers_flat=list(fake_params_buffers.values()),
param_names=joint_with_descriptors.params_spec,
buffer_names=joint_with_descriptors.buffers_spec,
trace_joint=False,
num_user_fw_outs=None,
loss_index=None,
)
return gm, graph_signature
def _export_to_aten_ir(
mod: torch.nn.Module,
fake_args,
@@ -877,25 +925,29 @@ def _export_to_aten_ir(
# This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
# otherwise aot_export_module will error out because it sees a mix of fake_modes.
# And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
with (
with ExitStack() as stack:
stack.enter_context(
torch.nn.utils.stateless._reparametrize_module(
mod,
fake_params_buffers,
tie_weights=True,
strict=True,
stack_weights=True,
),
_ignore_backend_decomps(),
_compiling_state_context(),
custom_triton_ops_decomposition_ctx(),
):
gm, graph_signature = transform(aot_export_module)(
)
)
stack.enter_context(_ignore_backend_decomps())
stack.enter_context(_compiling_state_context())
stack.enter_context(custom_triton_ops_decomposition_ctx())
stack.enter_context(torch.no_grad())
gm, graph_signature = transform(_aot_export_joint_with_descriptors)(
stack,
mod,
fake_args,
trace_joint=False,
pre_dispatch=pre_dispatch,
decompositions=decomp_table,
kwargs=fake_kwargs,
decompositions=decomp_table,
fake_params_buffers=fake_params_buffers,
_record_nn_module_stack=True,
)
def _maybe_fixup_gm_and_output_node_meta(old_gm, new_gm):
@@ -1578,7 +1630,7 @@ def _export_to_aten_ir_make_fx(
produce_guards_callback=None,
transform=lambda x: x,
) -> ATenExportArtifact:
def _make_fx_helper(mod, args, kwargs, **flags):
def _make_fx_helper(stack, mod, args, kwargs, **flags):
kwargs = kwargs or {}
named_parameters = dict(mod.named_parameters(remove_duplicate=False))
@@ -1796,18 +1848,20 @@ def _export_to_aten_ir_make_fx(
# This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
# otherwise aot_export_module will error out because it sees a mix of fake_modes.
# And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
with (
with ExitStack() as stack:
stack.enter_context(
torch.nn.utils.stateless._reparametrize_module(
mod,
fake_params_buffers,
tie_weights=True,
strict=True,
stack_weights=True,
),
_ignore_backend_decomps(),
_compiling_state_context(),
):
)
)
stack.enter_context(_ignore_backend_decomps())
stack.enter_context(_compiling_state_context())
gm, graph_signature = transform(_make_fx_helper)(
stack,
mod,
fake_args,
trace_joint=False,
@@ -1892,7 +1946,7 @@ def _non_strict_export(
module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = {}
def _tuplify_outputs(aot_export):
def _aot_export_non_strict(mod, args, kwargs=None, **flags):
def _aot_export_non_strict(stack, mod, args, *, kwargs=None, **flags):
kwargs = kwargs or {}
class Wrapper(torch.nn.Module):
@@ -1936,7 +1990,7 @@ def _non_strict_export(
wrapped_mod, new_preserved_call_signatures, module_call_specs
)
with ctx:
gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
gm, sig = aot_export(stack, wrapped_mod, args, kwargs=kwargs, **flags)
log.debug("Exported program from AOTAutograd:\n%s", gm)
sig.parameters = pytree.tree_map(_strip_root, sig.parameters)
@@ -2016,7 +2070,9 @@ def _non_strict_export(
_fakify_module_inputs(fake_args, fake_kwargs, fake_mode),
_override_builtin_ops(),
):
aten_export_artifact = _to_aten_func( # type: ignore[operator]
# _to_aten_func is _export_to_aten_ir when using the default non-strict export
# We need to pass positional args correctly
aten_export_artifact = _to_aten_func(
patched_mod,
new_fake_args,
new_fake_kwargs,