Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 00:20:18 +01:00)
Export should use aot_export_joint_with_descriptors (#165931)
This diff moves export run_decompositions to use aot_export_joint_with_descriptors instead of aot_export_module. Doing so, I ran into two main bugs:
1) aot_export_joint_with_descriptors did not correctly pass in the record_nn_module_stack flag, which is needed to switch the internal tracer so that nn_module_stack gets populated.
2) When creating SymInts for negative inputs, we need to pass positive=False. This didn't matter before because aot_autograd returned integer inputs directly instead of creating SymInts.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165931
Approved by: https://github.com/zhxchen17
This commit is contained in:
parent f6951cb8ea
commit 6096c0fc74
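To make bug (2) concrete, here is a minimal sketch of the SymInt creation that the process_inputs hunk below fixes. It only uses ShapeEnv.create_symbol / create_symintnode as they appear in this diff; the concrete value and the ConstantSource name are illustrative, not taken from the PR.

from torch._dynamo.source import ConstantSource
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
x = -3  # a plain negative int input; aot_autograd used to return such ints as-is
source = ConstantSource("sym_0")
# Without positive=False the symbol is assumed positive, which contradicts the
# negative hint; passing positive=x >= 0 keeps the value range consistent.
symint = shape_env.create_symintnode(
    shape_env.create_symbol(x, source, positive=x >= 0),
    hint=x,
    source=source,
)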
@@ -13910,16 +13910,28 @@ def forward(self, x, b_t, y):
inps = (torch.ones(5),)

ep = torch.export.export(M(), inps).run_decompositions({})
self.assertExpectedInline(
str(ep.graph_module.code.strip()),
"""\
if IS_FBCODE:
self.assertExpectedInline(
str(ep.graph_module.code.strip()),
"""\
def forward(self, x):
cos = torch.ops.aten.cos.default(x)
auto_functionalized = torch.ops.higher_order.auto_functionalized(torch.ops.testlib.foo.default, x = x, z = cos); x = cos = None
getitem_3 = auto_functionalized[3]; auto_functionalized = None
cos_1 = torch.ops.aten.cos.default(getitem_3)
return (getitem_3, getitem_3, cos_1)""",
)
)
else:
self.assertExpectedInline(
str(ep.graph_module.code.strip()),
"""\
def forward(self, x):
cos = torch.ops.aten.cos.default(x)
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.testlib.foo.default, _x_base_index = 0, _z_base_index = 1, _all_bases = [x, cos]); x = cos = None
getitem_3 = auto_functionalized_v2[3]; auto_functionalized_v2 = None
cos_1 = torch.ops.aten.cos.default(getitem_3)
return (getitem_3, getitem_3, cos_1)""",
)

def test_custom_op_auto_warn_pre_dispatch(self):
class M(torch.nn.Module):

@@ -13932,9 +13944,10 @@ def forward(self, x):
inps = (torch.ones(5),)

ep = torch.export.export(M(), inps).run_decompositions()
self.assertExpectedInline(
str(ep.graph_module.code.strip()),
"""\
if IS_FBCODE:
self.assertExpectedInline(
str(ep.graph_module.code.strip()),
"""\
def forward(self, x):
cos = torch.ops.aten.cos.default(x)
cos_1 = torch.ops.aten.cos.default(x); x = None

@@ -13942,7 +13955,19 @@ def forward(self, x):
getitem_3 = auto_functionalized[3]; auto_functionalized = None
cos_2 = torch.ops.aten.cos.default(getitem_3); getitem_3 = None
return (cos_2,)""",
)
)
else:
self.assertExpectedInline(
str(ep.graph_module.code.strip()),
"""\
def forward(self, x):
cos = torch.ops.aten.cos.default(x)
cos_1 = torch.ops.aten.cos.default(x); x = None
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.testlib.foo.default, _x_base_index = 0, _z_base_index = 1, _all_bases = [cos, cos_1]); cos = cos_1 = None
getitem_3 = auto_functionalized_v2[3]; auto_functionalized_v2 = None
cos_2 = torch.ops.aten.cos.default(getitem_3); getitem_3 = None
return (cos_2,)""",
)

ep = torch.export._trace._export(M(), inps, pre_dispatch=True)
self.assertExpectedInline(

@@ -15338,9 +15363,10 @@ graph():
decomp_table,
)

self.assertExpectedInline(
str(ep.graph_module.code).strip(),
"""\
if IS_FBCODE:
self.assertExpectedInline(
str(ep.graph_module.code).strip(),
"""\
def forward(self, x):
foo_functional = torch.ops.testlib.foo_functional.default(x); x = None
cos = torch.ops.aten.cos.default(foo_functional)

@@ -15348,7 +15374,19 @@ def forward(self, x):
getitem_3 = auto_functionalized[3]; auto_functionalized = None
cos_1 = torch.ops.aten.cos.default(getitem_3)
return (getitem_3, cos_1)""",
)
)
else:
self.assertExpectedInline(
str(ep.graph_module.code).strip(),
"""\
def forward(self, x):
foo_functional = torch.ops.testlib.foo_functional.default(x); x = None
cos = torch.ops.aten.cos.default(foo_functional)
auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.testlib.foo.default, _x_base_index = 0, _z_base_index = 1, _all_bases = [foo_functional, cos]); foo_functional = cos = None
getitem_3 = auto_functionalized_v2[3]; auto_functionalized_v2 = None
cos_1 = torch.ops.aten.cos.default(getitem_3)
return (getitem_3, cos_1)""",
)

def test_run_decompositions_keep_metadata(self):
"""Make sure the metadata is kept after exported program run_decompositions."""
@@ -10,6 +10,7 @@ behavior, including:
import sys
from typing import Any, TYPE_CHECKING

from torch._environment import is_fbcode
from torch.utils._config_module import install_config_module

@@ -27,6 +28,11 @@ detect_non_strict_fake_tensor_leaks = False
# that we don't know how to proxy, resulting in untracked fake tensors
error_on_lifted_constant_tensors = True

# enable auto_functionalized_v2 in export
# We turn this off in fbcode due to downstream users not
# being ready to handle auto_functionalized_v2.
enable_auto_functionalized_v2_for_export = not is_fbcode()

if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403
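A hedged usage sketch of the new flag: torch._export.config is a config module (install_config_module above), so the flag can be read and flipped directly. Apart from the flag name and module path, which come from this diff, the snippet is illustrative.

import torch._export.config as export_config

print(export_config.enable_auto_functionalized_v2_for_export)  # True outside fbcode

# Temporarily force the fbcode behavior (auto_functionalized instead of v2):
export_config.enable_auto_functionalized_v2_for_export = False
try:
    ...  # e.g. torch.export.export(mod, args).run_decompositions({})
finally:
    export_config.enable_auto_functionalized_v2_for_export = True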
@@ -166,9 +166,6 @@ def run_functionalized_fw_and_collect_metadata(
# Note: this is guaranteed to be set when running under dynamo
static_input_indices: Optional[list[int]] = None,
pre_dispatch: bool = False,
# is_export is technically only needed to avoid using functionalization V2
# during analysis
is_export: bool = False,
) -> Callable[..., ViewAndMutationMeta]:
memo: dict[Tensor, Tensor] = {}

@@ -200,7 +197,7 @@ def run_functionalized_fw_and_collect_metadata(
# It doesn't matter if we run this under predispatch or not because it is
# only for figuring out metadata
mode = FunctionalTensorMode(_allow_token_discovery=True, export=is_export)
mode = FunctionalTensorMode(_allow_token_discovery=True)
suppress_pending = contextlib.nullcontext()
fake_mode = detect_fake_mode()
if fake_mode and (shape_env := fake_mode.shape_env):
@@ -41,7 +41,7 @@ def process_inputs(
return x
source = ConstantSource(f"sym_{idx}")
return shape_env.create_symintnode(
shape_env.create_symbol(x, source),
shape_env.create_symbol(x, source, positive=x >= 0),
hint=x,
source=source,
)
@@ -573,7 +573,6 @@ def create_aot_state(
keep_input_mutations=aot_config.keep_inference_input_mutations,
is_train=needs_autograd,
pre_dispatch=aot_config.pre_dispatch,
is_export=aot_config.is_export,
)(*_dup_fake_script_obj(fake_flat_args))

req_subclass_dispatch = requires_subclass_dispatch(

@@ -905,6 +904,7 @@ def prepare_aot_module_simplified(
*,
force_non_lazy_backward_lowering: bool = False,
disable_functionalization: bool = False,
_record_nn_module_stack: bool = False,
):
if not flatten:
assert kwargs is None

@@ -931,7 +931,13 @@ def prepare_aot_module_simplified(
# NB: This doesn't change the in/out convention, except adding the
# parameters as explicit arguments
functional_call = create_functional_call(
mod, params_buffers_spec, params_len + buffers_len, strict_out_tuple=not flatten
mod,
params_buffers_spec,
params_len + buffers_len,
strict_out_tuple=not flatten,
# We need this for export to run ModuleStackTracer
# instead of PythonKeyTracer
store_orig_mod=_record_nn_module_stack,
)

full_args = [*params_flat, *buffers_flat, *args]

@@ -1175,6 +1181,7 @@ def aot_export_joint_with_descriptors(
keep_inference_input_mutations=False,
ignore_shape_env=False,
disable_functionalization=False,
_record_nn_module_stack=False,
) -> JointWithDescriptors:
"""
This API captures the joint graph for an nn.Module. However, unlike

@@ -1265,6 +1272,7 @@ def aot_export_joint_with_descriptors(
# context.
force_non_lazy_backward_lowering=True,
disable_functionalization=disable_functionalization,
_record_nn_module_stack=_record_nn_module_stack,
)

# TODO: Maybe this should be in create_aot_state? Not sure, that would
@@ -145,7 +145,7 @@ class FunctionalTensor(torch.Tensor):
out.elem = elem

if (
not mode.export
torch._export.config.enable_auto_functionalized_v2_for_export
and torch.is_inference_mode_enabled()
and torch._inductor.config.enable_auto_functionalized_v2
):

@@ -449,12 +449,18 @@ class FunctionalTensorMode(TorchDispatchMode):
) and not torch._C._dispatch_has_kernel_for_dispatch_key(
func.name(), torch._C.DispatchKey.Functionalize
):
import torch._export.config as export_config
import torch._inductor.config as inductor_config

if self.export or not inductor_config.enable_auto_functionalized_v2:
if torch.compiler.is_exporting():
if export_config.enable_auto_functionalized_v2_for_export:
return do_auto_functionalize_v2(self, func, args, kwargs)

return do_auto_functionalize(self, func, args, kwargs)
else:

if inductor_config.enable_auto_functionalized_v2:
return do_auto_functionalize_v2(self, func, args, kwargs)
return do_auto_functionalize(self, func, args, kwargs)

from torch._higher_order_ops.effects import handle_effects, has_effects
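To summarize the new dispatch choice above: during export, the export config flag decides between the v1 and v2 auto-functionalization HOPs, while outside export the inductor flag keeps deciding as before. The helper below is only an illustrative distillation of that branch, not code from the PR; the two flag names are the real config options.

def _should_use_auto_functionalized_v2(is_exporting: bool) -> bool:
    import torch._export.config as export_config
    import torch._inductor.config as inductor_config

    if is_exporting:  # e.g. torch.compiler.is_exporting()
        return export_config.enable_auto_functionalized_v2_for_export
    return inductor_config.enable_auto_functionalized_v2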
@@ -9,7 +9,7 @@ import sys
import time
import warnings
from collections.abc import Callable
from contextlib import contextmanager, nullcontext
from contextlib import contextmanager, ExitStack, nullcontext
from itertools import chain
from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union

@@ -67,7 +67,7 @@ from torch._functorch._aot_autograd.utils import (
)
from torch._functorch.aot_autograd import (
_detect_attribute_assignment,
aot_export_module,
aot_export_joint_with_descriptors,
)
from torch._guards import detect_fake_mode, tracing, TracingContext
from torch._library.fake_class_registry import FakeScriptObject
@@ -855,6 +855,54 @@ def _export_to_torch_ir(
return gm_torch_level


def _aot_export_joint_with_descriptors(
stack,
mod,
args,
*,
kwargs,
decompositions,
fake_params_buffers,
_record_nn_module_stack=True,
):
from torch._functorch._aot_autograd.graph_compile import aot_stage2_export
from torch._functorch._aot_autograd.input_output_analysis import (
create_graph_signature,
)

joint_with_descriptors = aot_export_joint_with_descriptors(
stack,
mod,
args,
kwargs=kwargs,
decompositions=decompositions,
_record_nn_module_stack=_record_nn_module_stack,
)
# Convert JointWithDescriptors to graph module and ViewAndMutationMeta
gm, fw_metadata = aot_stage2_export(
joint_with_descriptors._aot_state,
joint_with_descriptors._aot_graph_capture,
)

assert isinstance(gm, torch.fx.GraphModule)

# Create GraphSignature from the metadata
graph_signature = create_graph_signature(
gm,
fw_metadata,
joint_with_descriptors.in_spec,
joint_with_descriptors.out_spec,
user_args_flat=pytree.tree_leaves((args, kwargs)),
params_and_buffers_flat=list(fake_params_buffers.values()),
param_names=joint_with_descriptors.params_spec,
buffer_names=joint_with_descriptors.buffers_spec,
trace_joint=False,
num_user_fw_outs=None,
loss_index=None,
)
return gm, graph_signature


def _export_to_aten_ir(
mod: torch.nn.Module,
fake_args,
@@ -877,25 +925,29 @@ def _export_to_aten_ir(
# This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
# otherwise aot_export_module will error out because it sees a mix of fake_modes.
# And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
with (
torch.nn.utils.stateless._reparametrize_module(
mod,
fake_params_buffers,
tie_weights=True,
strict=True,
stack_weights=True,
),
_ignore_backend_decomps(),
_compiling_state_context(),
custom_triton_ops_decomposition_ctx(),
):
gm, graph_signature = transform(aot_export_module)(
with ExitStack() as stack:
stack.enter_context(
torch.nn.utils.stateless._reparametrize_module(
mod,
fake_params_buffers,
tie_weights=True,
strict=True,
stack_weights=True,
)
)
stack.enter_context(_ignore_backend_decomps())
stack.enter_context(_compiling_state_context())
stack.enter_context(custom_triton_ops_decomposition_ctx())
stack.enter_context(torch.no_grad())

gm, graph_signature = transform(_aot_export_joint_with_descriptors)(
stack,
mod,
fake_args,
trace_joint=False,
pre_dispatch=pre_dispatch,
decompositions=decomp_table,
kwargs=fake_kwargs,
decompositions=decomp_table,
fake_params_buffers=fake_params_buffers,
_record_nn_module_stack=True,
)

def _maybe_fixup_gm_and_output_node_meta(old_gm, new_gm):
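One reason for the switch from a parenthesized with-statement to an ExitStack in the hunk above: the stack object is passed down into _aot_export_joint_with_descriptors, so the callee can register additional context managers that stay active until capture finishes and are unwound together when the caller exits. A small, self-contained illustration of that pattern, with made-up names and nothing beyond the standard library:

from contextlib import ExitStack, contextmanager

@contextmanager
def patch(name):
    print(f"enter {name}")
    try:
        yield
    finally:
        print(f"exit {name}")

def capture(stack: ExitStack):
    # the callee extends the caller's stack; this context outlives the call
    stack.enter_context(patch("inner"))
    return "graph"

with ExitStack() as stack:
    stack.enter_context(patch("outer"))
    gm = capture(stack)
    # both "outer" and "inner" are still active here
# exits run here, innermost first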
@@ -1578,7 +1630,7 @@ def _export_to_aten_ir_make_fx(
produce_guards_callback=None,
transform=lambda x: x,
) -> ATenExportArtifact:
def _make_fx_helper(mod, args, kwargs, **flags):
def _make_fx_helper(stack, mod, args, kwargs, **flags):
kwargs = kwargs or {}

named_parameters = dict(mod.named_parameters(remove_duplicate=False))

@@ -1796,18 +1848,20 @@ def _export_to_aten_ir_make_fx(
# This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
# otherwise aot_export_module will error out because it sees a mix of fake_modes.
# And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
with (
torch.nn.utils.stateless._reparametrize_module(
mod,
fake_params_buffers,
tie_weights=True,
strict=True,
stack_weights=True,
),
_ignore_backend_decomps(),
_compiling_state_context(),
):
with ExitStack() as stack:
stack.enter_context(
torch.nn.utils.stateless._reparametrize_module(
mod,
fake_params_buffers,
tie_weights=True,
strict=True,
stack_weights=True,
)
)
stack.enter_context(_ignore_backend_decomps())
stack.enter_context(_compiling_state_context())
gm, graph_signature = transform(_make_fx_helper)(
stack,
mod,
fake_args,
trace_joint=False,

@@ -1892,7 +1946,7 @@ def _non_strict_export(
module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = {}

def _tuplify_outputs(aot_export):
def _aot_export_non_strict(mod, args, kwargs=None, **flags):
def _aot_export_non_strict(stack, mod, args, *, kwargs=None, **flags):
kwargs = kwargs or {}

class Wrapper(torch.nn.Module):

@@ -1936,8 +1990,8 @@ def _non_strict_export(
wrapped_mod, new_preserved_call_signatures, module_call_specs
)
with ctx:
gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
log.debug("Exported program from AOTAutograd:\n%s", gm)
gm, sig = aot_export(stack, wrapped_mod, args, kwargs=kwargs, **flags)
log.debug("Exported program from AOTAutograd:\n%s", gm)

sig.parameters = pytree.tree_map(_strip_root, sig.parameters)
sig.buffers = pytree.tree_map(_strip_root, sig.buffers)

@@ -2016,7 +2070,9 @@ def _non_strict_export(
_fakify_module_inputs(fake_args, fake_kwargs, fake_mode),
_override_builtin_ops(),
):
aten_export_artifact = _to_aten_func(  # type: ignore[operator]
# _to_aten_func is _export_to_aten_ir when using the default non-strict export
# We need to pass positional args correctly
aten_export_artifact = _to_aten_func(
patched_mod,
new_fake_args,
new_fake_kwargs,