Grab bag of (mostly) typing improvements (#158075)

Collects some scattershot improvements made while attempting to enable training for AOTInductor; brief illustrative sketches follow the list below. The non-typing changes are:

1. Replacing a few hand-rolled searches for the output node of an FX graph with calls to `graph.output_node()` (see the before/after sketch below).
2. Removing two unused parameters from `torch.export._unlift._unlift`.
3. Switching handles to constants in `cpp_wrapper_cpu` to C++ references for memory efficiency (each handle now aliases `constants_->at(idx)` instead of copying it).
4. Cleaning out unused, unexported imports from `torch/export/__init__.py`, and adding one missing export to `__all__`.
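
For item 1, here is a minimal before/after sketch of the pattern being replaced (the toy module and `symbolic_trace` call are just scaffolding for illustration; the hand-rolled loop is the shape removed from `torch/export/_unlift.py` in the diff below):

```python
import torch


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


# Any FX GraphModule will do; symbolic_trace is just a convenient way to get one.
gm = torch.fx.symbolic_trace(M())

# Before: scan the graph manually for the single output node.
output_node = None
for node in gm.graph.nodes:
    if node.op == "output":
        output_node = node
        break
assert output_node is not None

# After: FX exposes the same lookup directly.
assert gm.graph.output_node() is output_node
```

On the typing side, most of the changes add `typing.overload` declarations so callers see precise return types without any change to runtime behavior. A sketch of the `_normalize_bench_inputs` case, assuming the behavior described in the NOTE in the hunk below (dict inputs are consumed as `model(**example_inputs)`, everything else positionally; the real function body is elided from the diff):

```python
from collections.abc import Sequence
from typing import Any, overload, TypeVar

_D = TypeVar("_D", bound=dict[str, Any])
_T = TypeVar("_T")


@overload
def _normalize_bench_inputs(example_inputs: _D) -> tuple[tuple[()], _D]: ...
@overload
def _normalize_bench_inputs(
    example_inputs: Sequence[_T],
) -> tuple[tuple[_T, ...], dict[str, Any]]: ...
def _normalize_bench_inputs(example_inputs):
    # Dict-shaped inputs (e.g. huggingface benchmarks) are splatted as kwargs;
    # everything else is passed positionally.
    if isinstance(example_inputs, dict):
        return (), example_inputs
    return tuple(example_inputs), {}
```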

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158075
Approved by: https://github.com/Skylion007
Benjamin Glass 2025-07-21 15:42:02 +00:00 committed by PyTorch MergeBot
parent ad2dec1997
commit 22920c9138
13 changed files with 126 additions and 114 deletions

View File

@ -22,7 +22,7 @@ import sys
import time
import weakref
from contextlib import contextmanager
from typing import Any, NamedTuple, TYPE_CHECKING
from typing import Any, NamedTuple, Optional, overload, TYPE_CHECKING, TypeVar
from unittest.mock import MagicMock
import numpy as np
@ -54,6 +54,7 @@ try:
from torch._inductor.utils import fresh_cache
except ImportError:
from _dynamo.utils import clone_inputs, graph_break_reasons
from _inductor.utils import fresh_cache
import torch._functorch.config
from torch._functorch.aot_autograd import set_model_name
@ -75,7 +76,10 @@ except ImportError:
if TYPE_CHECKING:
from collections.abc import Mapping
from collections.abc import Sequence
_D = TypeVar("_D", bound=dict[str, Any])
_T = TypeVar("_T")
log = logging.getLogger(__name__)
@ -766,7 +770,17 @@ def timed(
return (time_total, result) if return_result else time_total
def _normalize_bench_inputs(example_inputs) -> tuple[tuple[Any], Mapping[str, Any]]:
@overload
def _normalize_bench_inputs(example_inputs: _D) -> tuple[tuple[()], _D]: ...
@overload
def _normalize_bench_inputs(
example_inputs: Sequence[_T],
) -> tuple[tuple[_T, ...], dict[str, Any]]: ...
def _normalize_bench_inputs(example_inputs):
# NOTE(bowbao): For huggingface benchmark, example_inputs are formatted as dictionary,
# and consumed like `model(**example_inputs)`.
# For other benchmarks, example_inputs are formatted as tuple and consumed
@ -1671,7 +1685,7 @@ class BenchmarkRunner:
self.grad_scaler = DummyGradScaler()
self.autocast = contextlib.nullcontext
self.autocast_arg = {}
self.optimizer = None
self.optimizer: Optional[torch.optim.Optimizer] = None
self._args = None
def setup_amp(self, current_device=None):

View File

@ -16,12 +16,16 @@ from parameterized import parameterized_class
import torch
from torch._inductor.codecache import get_kernel_bin_format
from torch._inductor.package import AOTICompiledModel, load_package, package_aoti
from torch._inductor.package import load_package, package_aoti
from torch._inductor.test_case import TestCase
from torch._inductor.utils import fresh_cache
from torch.export import Dim
from torch.export.experimental import _ExportPackage
from torch.export.pt2_archive._package import load_pt2, load_weights_to_pt2_contents
from torch.export.pt2_archive._package import (
AOTICompiledModel,
load_pt2,
load_weights_to_pt2_contents,
)
from torch.testing._internal.common_cuda import _get_torch_cuda_version
from torch.testing._internal.common_utils import (
IS_FBCODE,

View File

@ -102,6 +102,7 @@ if typing.TYPE_CHECKING:
Iterable,
Iterator,
KeysView,
Sequence,
ValuesView,
)
@ -2137,8 +2138,18 @@ def clone_input(x, *, dtype=None):
return result
@overload
def clone_inputs(
example_inputs: dict[str, Union[T, tuple[T, ...]]],
) -> dict[str, list[T]]: ...
@overload
def clone_inputs(example_inputs: Sequence[T]) -> list[T]: ...
def clone_inputs(example_inputs):
res: Union[dict[Any, Any], list[Any]]
res: Union[dict[str, Any], list[Any]]
if type(example_inputs) is dict:
res = dict(example_inputs)
for key, value in res.items():

View File

@ -6,7 +6,6 @@ import logging
import os
from typing import Any, IO, Literal, Optional, TYPE_CHECKING, Union
import torch._inductor.config
import torch.fx
from .standalone_compile import CompiledArtifact # noqa: TC001
@ -15,6 +14,7 @@ from .standalone_compile import CompiledArtifact # noqa: TC001
if TYPE_CHECKING:
from torch._inductor.utils import InputType
from torch.export import ExportedProgram
from torch.export.pt2_archive._package import AOTICompiledModel
from torch.export.pt2_archive._package_weights import Weights
from torch.types import FileLike
@ -223,7 +223,7 @@ def _aoti_compile_and_package_inner(
not_strict_accuracy = check_accuracy == "accuracy"
if not same_two_models(
gm,
compiled_model,
compiled_model, # type: ignore[arg-type]
args,
only_fwd=True,
require_fp64=not_strict_accuracy,
@ -238,7 +238,7 @@ def _aoti_compile_and_package_inner(
def aoti_load_package(
path: FileLike, run_single_threaded: bool = False, device_index: int = -1
) -> Any: # type: ignore[type-arg]
) -> AOTICompiledModel:
"""
Loads the model from the PT2 package.

View File

@ -630,10 +630,10 @@ class CppWrapperCpu(PythonWrapperCodegen):
), "Expect all constants to be Tensor"
for idx, constants_key in enumerate(V.graph.constants.keys()):
if V.graph.aot_mode:
# Weights are stored in constants_ and owned by RAIIAtenTensorHandle there.
# Weights are stored in constants_ and owned by ConstantHandle there.
# Don't call std::move here because it will cause constants_ to lose the ownership.
self.prefix.writeline(
f"""[[maybe_unused]] auto {constants_key} = constants_->at({idx});"""
f"""[[maybe_unused]] auto& {constants_key} = constants_->at({idx});"""
)
else:
# Append constants as inputs to the graph

View File

@ -53,6 +53,7 @@ from torch._functorch._aot_autograd.subclass_parametrization import (
)
from torch._functorch.aot_autograd import (
aot_export_module,
GraphOutputName,
make_boxed_func,
SerializableAOTDispatchCompiler,
)
@ -429,7 +430,7 @@ def _unlift_graph(
from torch.export._unlift import _unlift
outputs = list(gm.graph.nodes)[-1].args[0]
outputs: tuple[torch.fx.Node, ...] = tuple(gm.graph.output_node().args[0]) # type: ignore[arg-type]
mutated_outputs = []
buffer_mutations = graph_signature.buffers_to_mutate
user_input_mutations = graph_signature.user_inputs_to_mutate
@ -438,10 +439,11 @@ def _unlift_graph(
value: Optional[Union[FQN, GraphInputName]] = None
if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens):
if out.name in buffer_mutations:
value = buffer_mutations[out.name]
elif out.name in user_input_mutations:
value = user_input_mutations[out.name]
name = GraphOutputName(out.name)
if name in buffer_mutations:
value = buffer_mutations[name]
elif name in user_input_mutations:
value = user_input_mutations[name]
mutated_outputs.append(value)
@ -451,8 +453,6 @@ def _unlift_graph(
mutated_outputs,
pytree.LeafSpec(),
None,
state_dict,
{},
)
return unlifted_gm

View File

@ -105,7 +105,7 @@ def load_package(
run_single_threaded: bool = False,
num_runners: int = 1,
device_index: int = -1,
) -> AOTICompiledModel: # type: ignore[type-arg]
) -> AOTICompiledModel:
try:
pt2_contents = load_pt2(
path,

View File

@ -1,59 +1,38 @@
import builtins
import copy
import dataclasses
import inspect
import os
import sys
import typing
import warnings
import zipfile
from collections.abc import Iterator
from enum import auto, Enum
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from collections.abc import Mapping
from typing import Any, Callable, Optional, Union
from typing_extensions import deprecated
import torch
import torch.utils._pytree as pytree
from torch.fx._compatibility import compatibility
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager
from torch.types import FileLike
from torch.utils._pytree import (
FlattenFunc,
FromDumpableContextFn,
ToDumpableContextFn,
UnflattenFunc,
)
if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features,
# Do not import unconditionally, as they import sympy and importing sympy is very slow
from torch._ops import OpOverload
from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
__all__ = [
"AdditionalInputs",
"Constraint",
"Dim",
"ExportBackwardSignature",
"ExportGraphSignature",
"ExportedProgram",
"CustomDecompTable",
"default_decompositions",
"Dim",
"dims",
"draft_export",
"export_for_training",
"export",
"ExportBackwardSignature",
"ExportedProgram",
"ExportGraphSignature",
"FlatArgsAdapter",
"load",
"ModuleCallEntry",
"ModuleCallSignature",
"default_decompositions",
"dims",
"export",
"export_for_training",
"load",
"register_dataclass",
"save",
"ShapesCollection",
"unflatten",
"FlatArgsAdapter",
"UnflattenedModule",
"AdditionalInputs",
"draft_export",
]
# To make sure export specific custom ops are loaded
@ -82,9 +61,9 @@ PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
def export_for_training(
mod: torch.nn.Module,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
kwargs: Optional[Mapping[str, Any]] = None,
*,
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None,
strict: bool = False,
preserve_module_call_signature: tuple[str, ...] = (),
) -> ExportedProgram:
@ -181,9 +160,9 @@ def export_for_training(
def export(
mod: torch.nn.Module,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
kwargs: Optional[Mapping[str, Any]] = None,
*,
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None,
strict: bool = False,
preserve_module_call_signature: tuple[str, ...] = (),
) -> ExportedProgram:
@ -540,9 +519,9 @@ def load(
def draft_export(
mod: torch.nn.Module,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
kwargs: Optional[Mapping[str, Any]] = None,
*,
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None,
preserve_module_call_signature: tuple[str, ...] = (),
strict: bool = False,
) -> ExportedProgram:

View File

@ -4,13 +4,13 @@ import logging
import os
import re
import tempfile
from collections.abc import Mapping
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Callable, Optional, Union
import torch
import torch._logging._internal
import torch._logging.structured
import torch.utils._pytree as pytree
from torch._export.passes.insert_custom_op_guards import (
get_op_profiles,
@ -362,7 +362,7 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
def draft_export(
mod: torch.nn.Module,
args: tuple[Any, ...],
kwargs: Optional[dict[str, Any]] = None,
kwargs: Optional[Mapping[str, Any]] = None,
*,
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
preserve_module_call_signature: tuple[str, ...] = (),

View File

@ -1723,6 +1723,7 @@ def _export_to_aten_ir_make_fx(
gm.graph.eliminate_dead_code(_is_impure)
# create graph signature
assert out_spec.spec is not None, "out_spec.spec is None!"
input_names = _graph_input_names(gm)
output_names = _graph_output_names(gm)
sig = GraphSignature(
@ -1737,7 +1738,7 @@ def _export_to_aten_ir_make_fx(
buffers_to_mutate={},
user_inputs_to_mutate={},
in_spec=in_spec,
out_spec=out_spec, # type: ignore[arg-type]
out_spec=out_spec.spec,
backward_signature=None,
input_tokens=[],
output_tokens=[],

View File

@ -138,12 +138,7 @@ def _insert_copy_for_mutations(
Find all the buffers and inputs that were mutated and insert copy_
operators to reflect mutations.
"""
output_node = None
for node in gm.graph.nodes:
if node.op == "output":
output_node = node
break
assert output_node is not None
output_node = gm.graph.output_node()
outputs = pytree.tree_flatten(output_node.args)[0]
assert len(outputs) == len(mutated_outputs)
@ -169,13 +164,13 @@ def _insert_copy_for_mutations(
)
return_nodes_to_copy[return_node] = copy_node
output_args = [
output_args = tuple(
return_nodes_to_copy[node] if node in return_nodes_to_copy else node
for node in user_output_nodes
]
)
with gm.graph.inserting_before(output_node):
# Only return user outputs
new_output = gm.graph.output(tuple(output_args))
new_output = gm.graph.output(output_args)
output_node.replace_all_uses_with(new_output)
gm.graph.erase_node(output_node)
new_output.name = output_node.name
@ -199,19 +194,18 @@ def _get_codegen(
"""
if forward_arg_names:
names = forward_arg_names
elif (
in_spec.type == tuple
and in_spec.num_children == 2
and in_spec.children_specs[0].type == tuple
and in_spec.children_specs[1].type == dict
):
# if in_spec contains the args (tuple) and kwargs (dict)
names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)]
# add kwarg names
names.extend(in_spec.children_specs[1].context)
else:
if (
in_spec.type == tuple
and in_spec.num_children == 2
and in_spec.children_specs[0].type == tuple
and in_spec.children_specs[1].type == dict
):
# if in_spec contains the args (tuple) and kwargs (dict)
names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)]
# add kwarg names
names.extend(in_spec.children_specs[1].context)
else:
names = [f"arg_{i}" for i in range(in_spec.num_children)]
names = [f"arg_{i}" for i in range(in_spec.num_children)]
return _PyTreeCodeGen(
_PyTreeInfo(
@ -228,8 +222,6 @@ def _unlift(
mutated_outputs: Sequence[Optional[str]],
in_spec: pytree.TreeSpec,
out_spec: Optional[pytree.TreeSpec],
state_dict: dict[str, Any],
constants: dict[str, Any],
forward_arg_names: Optional[list[str]] = None,
):
"""
@ -427,7 +419,7 @@ def _create_stateful_graph_module(
return stateful_gm
def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Module:
def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.fx.GraphModule:
# TODO T206340015
if ep.verifiers[0].dialect != "TRAINING":
ep = _remove_effect_tokens(ep)
@ -482,14 +474,13 @@ def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Modu
)
]
assert ep.call_spec.in_spec is not None
new_gm = _unlift(
new_gm,
lifted_inputs,
mutated_outputs,
ep.call_spec.in_spec,
ep.call_spec.out_spec,
ep.state_dict,
ep.constants,
forward_arg_names=forward_arg_names,
)
unlift_gm = _create_stateful_graph_module(new_gm, ep.range_constraints, ep)

View File

@ -2,7 +2,6 @@ import copy
import dataclasses
import functools
import os
import tempfile
import types
import typing
import typing_extensions
@ -14,11 +13,15 @@ from torch.export.experimental._utils import _get_main_cpp_file, _get_make_file
from torch.export.exported_program import _decompose_exported_program
_InputT = typing_extensions.ParamSpec("_InputT")
_RetT = typing.TypeVar("_RetT")
__all__ = [] # type: ignore[var-annotated]
def _copy_graph_module_and_signature(
ep: torch.fx.GraphModule,
ep: torch.export.ExportedProgram,
) -> tuple[torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature]:
# copy.deepcopy lets the objects override __deepcopy__ methods with graph_copy() and node_copy(),
# and this can break placeholder names in some particular cases.
@ -36,7 +39,7 @@ def _copy_graph_module_and_signature(
for old_node, new_node in zip(old_phs, new_phs):
new_node.name = old_node.name
return gm, new_graph_signature # type: ignore[return-value]
return gm, new_graph_signature
def _remove_detach_pass(
@ -81,18 +84,27 @@ def _export_forward_backward(
return ep._update(gm, new_graph_signature)
@typing.no_type_check
def _sticky_export(forward_func, dynamic_shapes_callback=None):
def _sticky_export(
forward_func: typing.Callable[_InputT, _RetT],
dynamic_shapes_callback: typing.Optional[
typing.Callable[
_InputT,
typing.Union[
list[typing.Any], dict[str, typing.Any], tuple[typing.Any, ...]
],
]
] = None,
) -> typing.Callable[_InputT, _RetT]:
"""
Lazily export the model on first forward call.
Usage:
model.forward = _sticky_export(model.forward, dynamic_shapes_callback=callback)
"""
model = forward_func.__self__
original_forward = forward_func.__func__
model = forward_func.__self__ # type: ignore[attr-defined]
original_forward = forward_func.__func__ # type: ignore[attr-defined]
@functools.wraps(forward_func)
def wrapper(*args, **kwargs):
def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
# Unpatch forward to avoid recursion during export
model.forward = types.MethodType(original_forward, model)
@ -107,7 +119,7 @@ def _sticky_export(forward_func, dynamic_shapes_callback=None):
kwargs,
dynamic_shapes=dynamic_shapes_spec,
).module()
wrapper._exported_artifact = exported
wrapper._exported_artifact = exported # type: ignore[attr-defined]
finally:
# Restore the wrapper after export
model.forward = wrapper
@ -123,10 +135,6 @@ class _ExportMethod:
fallbacks: list[torch.export.ExportedProgram]
_InputT = typing_extensions.ParamSpec("_InputT")
_RetT = typing.TypeVar("_RetT")
class _ExportPackage:
"""
An export package is a collection of torch.export()-ed PyTorch models consisting of

View File

@ -7,10 +7,10 @@ import functools
import operator
import types
import warnings
from collections import defaultdict, namedtuple
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any, Callable, final, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, final, NamedTuple, Optional, TYPE_CHECKING, Union
from torch._guards import tracing, TracingContext
from torch._higher_order_ops.utils import autograd_not_implemented
@ -325,7 +325,7 @@ def default_decompositions() -> "CustomDecompTable":
def _decompose_and_get_gm_with_new_signature_constants(
ep,
ep: "ExportedProgram",
*,
cia_to_decomp: dict[torch._ops.OperatorBase, Callable],
python_decomp_table: dict[torch._ops.OperatorBase, Callable],
@ -384,9 +384,11 @@ def _decompose_and_get_gm_with_new_signature_constants(
# Fix the graph output signature to be tuple if scalar
out_spec = mod._out_spec
assert isinstance(mod.graph._codegen, _PyTreeCodeGen)
orig_arg_names = mod.graph._codegen.pytree_info.orig_args
# aot_export expect the return type to always be a tuple.
assert out_spec is not None
if out_spec.type not in (list, tuple):
out_spec = pytree.TreeSpec(tuple, None, [out_spec])
@ -610,7 +612,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}")
new_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
new_outputs = list(gm.graph.nodes)[-1].args[0]
new_outputs: tuple[torch.fx.Node, ...] = tuple(gm.graph.output_node().args[0]) # type: ignore[arg-type]
# rename the placeholders
assert len(new_placeholders) == len(old_placeholders)
@ -654,9 +656,9 @@ def _decompose_and_get_gm_with_new_signature_constants(
# update output specs
gm.recompile()
for i, name in enumerate(_graph_output_names(gm)):
if isinstance(new_outputs[i], torch.fx.Node):
new_outputs[i].name = name
for output, name in zip(new_outputs, _graph_output_names(gm)):
if name is not None:
output.name = name
# To match the output target with correct input for input mutations
# need to find the old to new placeholder map
@ -727,7 +729,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
for i, spec in enumerate(ep.graph_signature.input_specs)
if isinstance(spec.arg, TensorArgument)
}
for i, node in enumerate(new_outputs[len(output_specs) :]):
for node in new_outputs[len(output_specs) :]:
source = gradients[node.name]
spec = specs[source] # type: ignore[index]
if spec.kind == InputKind.PARAMETER:
@ -1208,7 +1210,9 @@ class ExportedProgram:
@property
@compatibility(is_backward_compatible=False)
def call_spec(self):
CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"])
class CallSpec(NamedTuple):
in_spec: Optional[pytree.TreeSpec]
out_spec: Optional[pytree.TreeSpec]
if len(self.module_call_graph) == 0:
return CallSpec(in_spec=None, out_spec=None)
@ -1364,7 +1368,7 @@ class ExportedProgram:
)
return string
def module(self) -> torch.nn.Module:
def module(self) -> torch.fx.GraphModule:
"""
Returns a self contained GraphModule with all the parameters/buffers inlined.
"""