Grab bag of (mostly) typing improvements (#158075)
Collects some scattershot improvements made while attempting to enable training for AOTInductor. Non-typing changes are:

1. Swapping a few custom searches for the output node in an FX graph for calling `graph.output_node()`.
2. Removing two unused parameters from `torch.export._unlift._unlift`.
3. Switching handles to constants in `cpp_wrapper_cpu` to use C++ references for memory efficiency.
4. Cleaning out unused, unexported imports from `torch/export/__init__.py`, and adding one missing export to `__all__`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158075
Approved by: https://github.com/Skylion007
This commit is contained in:
parent ad2dec1997
commit 22920c9138
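Change 1 above is mechanical but easy to picture. A minimal sketch of the before/after pattern on a toy FX graph (the module `M` is illustrative and not from the PR; `Graph.output_node()` is the method the PR switches to):

```python
import torch


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = torch.fx.symbolic_trace(M())

# Before: scan the whole graph for the single "output" node.
output_node = None
for node in gm.graph.nodes:
    if node.op == "output":
        output_node = node
        break
assert output_node is not None

# After: ask the graph directly.
output_node = gm.graph.output_node()
```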
@@ -22,7 +22,7 @@ import sys
 import time
 import weakref
 from contextlib import contextmanager
-from typing import Any, NamedTuple, TYPE_CHECKING
+from typing import Any, NamedTuple, Optional, overload, TYPE_CHECKING, TypeVar
 from unittest.mock import MagicMock
 
 import numpy as np
@@ -54,6 +54,7 @@ try:
     from torch._inductor.utils import fresh_cache
 except ImportError:
     from _dynamo.utils import clone_inputs, graph_break_reasons
     from _inductor.utils import fresh_cache
 
 import torch._functorch.config
+from torch._functorch.aot_autograd import set_model_name
@@ -75,7 +76,10 @@ except ImportError:
 
 
 if TYPE_CHECKING:
-    from collections.abc import Mapping
+    from collections.abc import Sequence
 
+_D = TypeVar("_D", bound=dict[str, Any])
+_T = TypeVar("_T")
+
 
 log = logging.getLogger(__name__)
@@ -766,7 +770,17 @@ def timed(
     return (time_total, result) if return_result else time_total
 
 
-def _normalize_bench_inputs(example_inputs) -> tuple[tuple[Any], Mapping[str, Any]]:
+@overload
+def _normalize_bench_inputs(example_inputs: _D) -> tuple[tuple[()], _D]: ...
+
+
+@overload
+def _normalize_bench_inputs(
+    example_inputs: Sequence[_T],
+) -> tuple[tuple[_T, ...], dict[str, Any]]: ...
+
+
+def _normalize_bench_inputs(example_inputs):
     # NOTE(bowbao): For huggingface benchmark, example_inputs are formatted as dictionary,
     # and consumed like `model(**example_inputs)`.
     # For other benchmarks, example_inputs are formatted as tuple and consumed
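As a reading aid, here is a self-contained sketch of the `typing.overload` pattern this hunk introduces; the function name and body are illustrative stand-ins for the benchmark helper, not the PR's exact code:

```python
from collections.abc import Sequence
from typing import Any, TypeVar, overload

_D = TypeVar("_D", bound=dict[str, Any])
_T = TypeVar("_T")


@overload
def normalize(example_inputs: _D) -> tuple[tuple[()], _D]: ...
@overload
def normalize(example_inputs: Sequence[_T]) -> tuple[tuple[_T, ...], dict[str, Any]]: ...
def normalize(example_inputs):
    # dict inputs are consumed as model(**kwargs); sequences as model(*args)
    if isinstance(example_inputs, dict):
        return (), example_inputs
    return tuple(example_inputs), {}


args, kwargs = normalize({"input_ids": 1})  # type checker sees ((), dict)
args, kwargs = normalize([1, 2, 3])         # type checker sees (tuple[int, ...], dict)
```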
@@ -1671,7 +1685,7 @@ class BenchmarkRunner:
         self.grad_scaler = DummyGradScaler()
         self.autocast = contextlib.nullcontext
         self.autocast_arg = {}
-        self.optimizer = None
+        self.optimizer: Optional[torch.optim.Optimizer] = None
         self._args = None
 
     def setup_amp(self, current_device=None):
@@ -16,12 +16,16 @@ from parameterized import parameterized_class
 
 import torch
 from torch._inductor.codecache import get_kernel_bin_format
-from torch._inductor.package import AOTICompiledModel, load_package, package_aoti
+from torch._inductor.package import load_package, package_aoti
 from torch._inductor.test_case import TestCase
 from torch._inductor.utils import fresh_cache
 from torch.export import Dim
 from torch.export.experimental import _ExportPackage
-from torch.export.pt2_archive._package import load_pt2, load_weights_to_pt2_contents
+from torch.export.pt2_archive._package import (
+    AOTICompiledModel,
+    load_pt2,
+    load_weights_to_pt2_contents,
+)
 from torch.testing._internal.common_cuda import _get_torch_cuda_version
 from torch.testing._internal.common_utils import (
     IS_FBCODE,
@@ -102,6 +102,7 @@ if typing.TYPE_CHECKING:
         Iterable,
         Iterator,
+        KeysView,
         Sequence,
         ValuesView,
     )
 
@@ -2137,8 +2138,18 @@ def clone_input(x, *, dtype=None):
         return result
 
 
+@overload
+def clone_inputs(
+    example_inputs: dict[str, Union[T, tuple[T, ...]]],
+) -> dict[str, list[T]]: ...
+
+
+@overload
+def clone_inputs(example_inputs: Sequence[T]) -> list[T]: ...
+
+
 def clone_inputs(example_inputs):
-    res: Union[dict[Any, Any], list[Any]]
+    res: Union[dict[str, Any], list[Any]]
     if type(example_inputs) is dict:
         res = dict(example_inputs)
         for key, value in res.items():
@@ -6,7 +6,6 @@ import logging
 import os
 from typing import Any, IO, Literal, Optional, TYPE_CHECKING, Union
 
 import torch._inductor.config
 import torch.fx
-
 from .standalone_compile import CompiledArtifact  # noqa: TC001
@@ -15,6 +14,7 @@ from .standalone_compile import CompiledArtifact  # noqa: TC001
 if TYPE_CHECKING:
     from torch._inductor.utils import InputType
     from torch.export import ExportedProgram
+    from torch.export.pt2_archive._package import AOTICompiledModel
     from torch.export.pt2_archive._package_weights import Weights
     from torch.types import FileLike
 
@@ -223,7 +223,7 @@ def _aoti_compile_and_package_inner(
     not_strict_accuracy = check_accuracy == "accuracy"
     if not same_two_models(
         gm,
-        compiled_model,
+        compiled_model,  # type: ignore[arg-type]
         args,
         only_fwd=True,
         require_fp64=not_strict_accuracy,
@@ -238,7 +238,7 @@ def _aoti_compile_and_package_inner(
 
 def aoti_load_package(
     path: FileLike, run_single_threaded: bool = False, device_index: int = -1
-) -> Any:  # type: ignore[type-arg]
+) -> AOTICompiledModel:
     """
     Loads the model from the PT2 package.
 
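With the concrete return type, downstream code gets attribute and call checking instead of `Any`. A hedged usage sketch (the package path is hypothetical, and the file must previously have been produced by `torch._inductor.aoti_compile_and_package`):

```python
import torch
from torch._inductor import aoti_load_package

# Type checkers now infer AOTICompiledModel for `compiled`, so calling it
# (and any misuse of its API) is checked rather than silently Any-typed.
compiled = aoti_load_package("model.pt2")
out = compiled(torch.randn(8, 16))
```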
@@ -630,10 +630,10 @@ class CppWrapperCpu(PythonWrapperCodegen):
         ), "Expect all constants to be Tensor"
         for idx, constants_key in enumerate(V.graph.constants.keys()):
             if V.graph.aot_mode:
-                # Weights are stored in constants_ and owned by RAIIAtenTensorHandle there.
+                # Weights are stored in constants_ and owned by ConstantHandle there.
                 # Don't call std::move here because it will cause constants_ to lose the ownership.
                 self.prefix.writeline(
-                    f"""[[maybe_unused]] auto {constants_key} = constants_->at({idx});"""
+                    f"""[[maybe_unused]] auto& {constants_key} = constants_->at({idx});"""
                 )
             else:
                 # Append constants as inputs to the graph
@@ -53,6 +53,7 @@ from torch._functorch._aot_autograd.subclass_parametrization import (
 )
 from torch._functorch.aot_autograd import (
     aot_export_module,
+    GraphOutputName,
     make_boxed_func,
     SerializableAOTDispatchCompiler,
 )
@@ -429,7 +430,7 @@ def _unlift_graph(
 
     from torch.export._unlift import _unlift
 
-    outputs = list(gm.graph.nodes)[-1].args[0]
+    outputs: tuple[torch.fx.Node, ...] = tuple(gm.graph.output_node().args[0])  # type: ignore[arg-type]
     mutated_outputs = []
     buffer_mutations = graph_signature.buffers_to_mutate
     user_input_mutations = graph_signature.user_inputs_to_mutate
@@ -438,10 +439,11 @@ def _unlift_graph(
         value: Optional[Union[FQN, GraphInputName]] = None
 
         if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens):
-            if out.name in buffer_mutations:
-                value = buffer_mutations[out.name]
-            elif out.name in user_input_mutations:
-                value = user_input_mutations[out.name]
+            name = GraphOutputName(out.name)
+            if name in buffer_mutations:
+                value = buffer_mutations[name]
+            elif name in user_input_mutations:
+                value = user_input_mutations[name]
 
         mutated_outputs.append(value)
 
@@ -451,8 +453,6 @@ def _unlift_graph(
         mutated_outputs,
-        pytree.LeafSpec(),
-        None,
         state_dict,
         {},
     )
     return unlifted_gm
 
@@ -105,7 +105,7 @@ def load_package(
     run_single_threaded: bool = False,
     num_runners: int = 1,
     device_index: int = -1,
-) -> AOTICompiledModel:  # type: ignore[type-arg]
+) -> AOTICompiledModel:
     try:
         pt2_contents = load_pt2(
             path,
@@ -1,59 +1,38 @@
-import builtins
-import copy
-import dataclasses
-import inspect
-import os
-import sys
-import typing
-import warnings
-import zipfile
-from collections.abc import Iterator
-from enum import auto, Enum
-from typing import Any, Callable, Optional, TYPE_CHECKING, Union
+from collections.abc import Mapping
+from typing import Any, Callable, Optional, Union
 from typing_extensions import deprecated
 
 import torch
 import torch.utils._pytree as pytree
 from torch.fx._compatibility import compatibility
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.fx.passes.infra.pass_manager import PassManager
 from torch.types import FileLike
-from torch.utils._pytree import (
-    FlattenFunc,
-    FromDumpableContextFn,
-    ToDumpableContextFn,
-    UnflattenFunc,
-)
 
 
-if TYPE_CHECKING:
-    # Import the following modules during type checking to enable code intelligence features,
-    # Do not import unconditionally, as they import sympy and importing sympy is very slow
-    from torch._ops import OpOverload
-    from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint
-
 
 __all__ = [
+    "AdditionalInputs",
     "Constraint",
-    "Dim",
-    "ExportBackwardSignature",
-    "ExportGraphSignature",
-    "ExportedProgram",
+    "CustomDecompTable",
+    "default_decompositions",
+    "Dim",
+    "dims",
+    "draft_export",
+    "export_for_training",
+    "export",
+    "ExportBackwardSignature",
+    "ExportedProgram",
+    "ExportGraphSignature",
+    "FlatArgsAdapter",
+    "load",
     "ModuleCallEntry",
     "ModuleCallSignature",
-    "default_decompositions",
-    "dims",
-    "export",
-    "export_for_training",
-    "load",
     "register_dataclass",
     "save",
     "ShapesCollection",
     "unflatten",
-    "FlatArgsAdapter",
     "UnflattenedModule",
-    "AdditionalInputs",
-    "draft_export",
 ]
 
 # To make sure export specific custom ops are loaded
@@ -82,9 +61,9 @@ PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]]
 def export_for_training(
     mod: torch.nn.Module,
     args: tuple[Any, ...],
-    kwargs: Optional[dict[str, Any]] = None,
+    kwargs: Optional[Mapping[str, Any]] = None,
     *,
-    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None,
     strict: bool = False,
     preserve_module_call_signature: tuple[str, ...] = (),
 ) -> ExportedProgram:
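The widening from `dict` to `Mapping` here (and in the `export` and `draft_export` hunks below) is caller-facing: read-only mappings become acceptable arguments. A small sketch of the difference, with an illustrative function name rather than the real `export` signature:

```python
from collections.abc import Mapping
from types import MappingProxyType
from typing import Any, Optional


def export_like(kwargs: Optional[Mapping[str, Any]] = None) -> dict[str, Any]:
    # the callee only reads kwargs, so Mapping is the honest requirement
    return dict(kwargs or {})


export_like({"attention_mask": None})                    # plain dict: fine
export_like(MappingProxyType({"attention_mask": None}))  # read-only view: now also fine
```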
@@ -181,9 +160,9 @@ def export_for_training(
 def export(
     mod: torch.nn.Module,
     args: tuple[Any, ...],
-    kwargs: Optional[dict[str, Any]] = None,
+    kwargs: Optional[Mapping[str, Any]] = None,
     *,
-    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None,
     strict: bool = False,
     preserve_module_call_signature: tuple[str, ...] = (),
 ) -> ExportedProgram:
@@ -540,9 +519,9 @@ def load(
 def draft_export(
     mod: torch.nn.Module,
     args: tuple[Any, ...],
-    kwargs: Optional[dict[str, Any]] = None,
+    kwargs: Optional[Mapping[str, Any]] = None,
     *,
-    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
+    dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None,
     preserve_module_call_signature: tuple[str, ...] = (),
     strict: bool = False,
 ) -> ExportedProgram:
@@ -4,13 +4,13 @@ import logging
 import os
 import re
 import tempfile
+from collections.abc import Mapping
 from dataclasses import dataclass
 from enum import IntEnum
 from typing import Any, Callable, Optional, Union
 
 import torch
 import torch._logging._internal
 import torch._logging.structured
 import torch.utils._pytree as pytree
 from torch._export.passes.insert_custom_op_guards import (
     get_op_profiles,
@@ -362,7 +362,7 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
 def draft_export(
     mod: torch.nn.Module,
     args: tuple[Any, ...],
-    kwargs: Optional[dict[str, Any]] = None,
+    kwargs: Optional[Mapping[str, Any]] = None,
     *,
     dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
     preserve_module_call_signature: tuple[str, ...] = (),
@@ -1723,6 +1723,7 @@ def _export_to_aten_ir_make_fx(
     gm.graph.eliminate_dead_code(_is_impure)
 
     # create graph signature
+    assert out_spec.spec is not None, "out_spec.spec is None!"
     input_names = _graph_input_names(gm)
     output_names = _graph_output_names(gm)
     sig = GraphSignature(
@@ -1737,7 +1738,7 @@ def _export_to_aten_ir_make_fx(
         buffers_to_mutate={},
         user_inputs_to_mutate={},
         in_spec=in_spec,
-        out_spec=out_spec,  # type: ignore[arg-type]
+        out_spec=out_spec.spec,
         backward_signature=None,
         input_tokens=[],
         output_tokens=[],
@@ -138,12 +138,7 @@ def _insert_copy_for_mutations(
     Find the all the buffers and inputs that were mutated and insert copy_
     operators to reflect mutations.
     """
-    output_node = None
-    for node in gm.graph.nodes:
-        if node.op == "output":
-            output_node = node
-            break
-    assert output_node is not None
+    output_node = gm.graph.output_node()
     outputs = pytree.tree_flatten(output_node.args)[0]
     assert len(outputs) == len(mutated_outputs)
 
@@ -169,13 +164,13 @@ def _insert_copy_for_mutations(
             )
             return_nodes_to_copy[return_node] = copy_node
 
-    output_args = [
+    output_args = tuple(
         return_nodes_to_copy[node] if node in return_nodes_to_copy else node
         for node in user_output_nodes
-    ]
+    )
     with gm.graph.inserting_before(output_node):
         # Only return user outputs
-        new_output = gm.graph.output(tuple(output_args))
+        new_output = gm.graph.output(output_args)
         output_node.replace_all_uses_with(new_output)
         gm.graph.erase_node(output_node)
         new_output.name = output_node.name
@@ -199,19 +194,18 @@ def _get_codegen(
     """
     if forward_arg_names:
        names = forward_arg_names
+    elif (
+        in_spec.type == tuple
+        and in_spec.num_children == 2
+        and in_spec.children_specs[0].type == tuple
+        and in_spec.children_specs[1].type == dict
+    ):
+        # if in_spec contains the args (tuple) and kwargs (dict)
+        names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)]
+        # add kwarg names
+        names.extend(in_spec.children_specs[1].context)
     else:
-        if (
-            in_spec.type == tuple
-            and in_spec.num_children == 2
-            and in_spec.children_specs[0].type == tuple
-            and in_spec.children_specs[1].type == dict
-        ):
-            # if in_spec contains the args (tuple) and kwargs (dict)
-            names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)]
-            # add kwarg names
-            names.extend(in_spec.children_specs[1].context)
-        else:
-            names = [f"arg_{i}" for i in range(in_spec.num_children)]
+        names = [f"arg_{i}" for i in range(in_spec.num_children)]
 
     return _PyTreeCodeGen(
         _PyTreeInfo(
@@ -228,8 +222,6 @@ def _unlift(
     mutated_outputs: Sequence[Optional[str]],
-    in_spec: pytree.TreeSpec,
-    out_spec: Optional[pytree.TreeSpec],
     state_dict: dict[str, Any],
     constants: dict[str, Any],
     forward_arg_names: Optional[list[str]] = None,
 ):
     """
@@ -427,7 +419,7 @@ def _create_stateful_graph_module(
     return stateful_gm
 
 
-def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Module:
+def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.fx.GraphModule:
     # TODO T206340015
     if ep.verifiers[0].dialect != "TRAINING":
         ep = _remove_effect_tokens(ep)
@@ -482,14 +474,13 @@ def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Module:
         )
     ]
 
+    assert ep.call_spec.in_spec is not None
     new_gm = _unlift(
         new_gm,
         lifted_inputs,
         mutated_outputs,
-        ep.call_spec.in_spec,
-        ep.call_spec.out_spec,
         ep.state_dict,
         ep.constants,
         forward_arg_names=forward_arg_names,
     )
     unlift_gm = _create_stateful_graph_module(new_gm, ep.range_constraints, ep)
@@ -2,7 +2,6 @@ import copy
 import dataclasses
 import functools
-import os
 import tempfile
 import types
 import typing
 import typing_extensions
@@ -14,11 +13,15 @@ from torch.export.experimental._utils import _get_main_cpp_file, _get_make_file
 from torch.export.exported_program import _decompose_exported_program
 
 
+_InputT = typing_extensions.ParamSpec("_InputT")
+_RetT = typing.TypeVar("_RetT")
+
 
+__all__ = []  # type: ignore[var-annotated]
 
 
 def _copy_graph_module_and_signature(
-    ep: torch.fx.GraphModule,
+    ep: torch.export.ExportedProgram,
 ) -> tuple[torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature]:
     # copy.deepcopy lets the objects override __deepcopy__ methods with graph_copy() and node_copy(),
     # and this can break placeholder names in some particular cases.
@@ -36,7 +39,7 @@ def _copy_graph_module_and_signature(
     for old_node, new_node in zip(old_phs, new_phs):
         new_node.name = old_node.name
 
-    return gm, new_graph_signature  # type: ignore[return-value]
+    return gm, new_graph_signature
 
 
 def _remove_detach_pass(
@@ -81,18 +84,27 @@ def _export_forward_backward(
     return ep._update(gm, new_graph_signature)
 
 
-@typing.no_type_check
-def _sticky_export(forward_func, dynamic_shapes_callback=None):
+def _sticky_export(
+    forward_func: typing.Callable[_InputT, _RetT],
+    dynamic_shapes_callback: typing.Optional[
+        typing.Callable[
+            _InputT,
+            typing.Union[
+                list[typing.Any], dict[str, typing.Any], tuple[typing.Any, ...]
+            ],
+        ]
+    ] = None,
+) -> typing.Callable[_InputT, _RetT]:
     """
     Lazily export the model on first forward call.
     Usage:
     model.forward = _sticky_export(model.forward, dynamic_shapes_callback=callback)
     """
-    model = forward_func.__self__
-    original_forward = forward_func.__func__
+    model = forward_func.__self__  # type: ignore[attr-defined]
+    original_forward = forward_func.__func__  # type: ignore[attr-defined]
 
     @functools.wraps(forward_func)
-    def wrapper(*args, **kwargs):
+    def wrapper(*args: _InputT.args, **kwargs: _InputT.kwargs) -> _RetT:
         # Unpatch forward to avoid recursion during export
         model.forward = types.MethodType(original_forward, model)
 
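The `_sticky_export` rewrite relies on `ParamSpec` so the returned wrapper keeps the wrapped forward's exact parameter types. A standalone sketch of that pattern, with an illustrative decorator rather than the PR's function:

```python
import functools
from typing import Callable, TypeVar

from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_R = TypeVar("_R")


def log_calls(func: Callable[_P, _R]) -> Callable[_P, _R]:
    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        print(f"calling {func.__name__}")
        return func(*args, **kwargs)

    return wrapper


@log_calls
def add(x: int, y: int) -> int:
    return x + y


add(1, 2)       # OK; a type checker still sees (int, int) -> int
# add("1", 2)   # rejected by a type checker, unlike with an untyped wrapper
```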
@@ -107,7 +119,7 @@ def _sticky_export(forward_func, dynamic_shapes_callback=None):
                 kwargs,
                 dynamic_shapes=dynamic_shapes_spec,
             ).module()
-            wrapper._exported_artifact = exported
+            wrapper._exported_artifact = exported  # type: ignore[attr-defined]
         finally:
             # Restore the wrapper after export
             model.forward = wrapper
@@ -123,10 +135,6 @@ class _ExportMethod:
     fallbacks: list[torch.export.ExportedProgram]
 
 
-_InputT = typing_extensions.ParamSpec("_InputT")
-_RetT = typing.TypeVar("_RetT")
-
-
 class _ExportPackage:
     """
     An export package is a collection of torch.export()-ed PyTorch models consisting of
@@ -7,10 +7,10 @@ import functools
 import operator
 import types
 import warnings
-from collections import defaultdict, namedtuple
+from collections import defaultdict
 from collections.abc import Iterator
 from contextlib import contextmanager
-from typing import Any, Callable, final, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, final, NamedTuple, Optional, TYPE_CHECKING, Union
 
 from torch._guards import tracing, TracingContext
 from torch._higher_order_ops.utils import autograd_not_implemented
@ -325,7 +325,7 @@ def default_decompositions() -> "CustomDecompTable":
|
|||
|
||||
|
||||
def _decompose_and_get_gm_with_new_signature_constants(
|
||||
ep,
|
||||
ep: "ExportedProgram",
|
||||
*,
|
||||
cia_to_decomp: dict[torch._ops.OperatorBase, Callable],
|
||||
python_decomp_table: dict[torch._ops.OperatorBase, Callable],
|
||||
|
|
@@ -384,9 +384,11 @@ def _decompose_and_get_gm_with_new_signature_constants(
         # Fix the graph output signature to be tuple if scalar
         out_spec = mod._out_spec
 
+        assert isinstance(mod.graph._codegen, _PyTreeCodeGen)
         orig_arg_names = mod.graph._codegen.pytree_info.orig_args
 
         # aot_export expect the return type to always be a tuple.
+        assert out_spec is not None
         if out_spec.type not in (list, tuple):
             out_spec = pytree.TreeSpec(tuple, None, [out_spec])
 
@@ -610,7 +612,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
             raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}")
 
     new_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
-    new_outputs = list(gm.graph.nodes)[-1].args[0]
+    new_outputs: tuple[torch.fx.Node, ...] = tuple(gm.graph.output_node().args[0])  # type: ignore[arg-type]
 
     # rename the placeholders
     assert len(new_placeholders) == len(old_placeholders)
@@ -654,9 +656,9 @@ def _decompose_and_get_gm_with_new_signature_constants(
 
     # update output specs
     gm.recompile()
-    for i, name in enumerate(_graph_output_names(gm)):
-        if isinstance(new_outputs[i], torch.fx.Node):
-            new_outputs[i].name = name
+    for output, name in zip(new_outputs, _graph_output_names(gm)):
+        if name is not None:
+            output.name = name
 
     # To match the output target with correct input for input mutations
     # need to find the old to new placeholder map
|
@ -727,7 +729,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
|
|||
for i, spec in enumerate(ep.graph_signature.input_specs)
|
||||
if isinstance(spec.arg, TensorArgument)
|
||||
}
|
||||
for i, node in enumerate(new_outputs[len(output_specs) :]):
|
||||
for node in new_outputs[len(output_specs) :]:
|
||||
source = gradients[node.name]
|
||||
spec = specs[source] # type: ignore[index]
|
||||
if spec.kind == InputKind.PARAMETER:
|
||||
|
|
@@ -1208,7 +1210,9 @@ class ExportedProgram:
     @property
     @compatibility(is_backward_compatible=False)
     def call_spec(self):
-        CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"])
+        class CallSpec(NamedTuple):
+            in_spec: Optional[pytree.TreeSpec]
+            out_spec: Optional[pytree.TreeSpec]
 
         if len(self.module_call_graph) == 0:
             return CallSpec(in_spec=None, out_spec=None)
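Replacing `collections.namedtuple` with a class-based `typing.NamedTuple` gives `call_spec` typed fields. A minimal sketch of the gain (`object` stands in for `pytree.TreeSpec` to keep it self-contained):

```python
from typing import NamedTuple, Optional


class CallSpec(NamedTuple):
    in_spec: Optional[object]   # stand-in for pytree.TreeSpec
    out_spec: Optional[object]


spec = CallSpec(in_spec=None, out_spec=None)
print(spec.in_spec)  # field access is typed; namedtuple fields were effectively Any
```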
@@ -1364,7 +1368,7 @@ class ExportedProgram:
         )
         return string
 
-    def module(self) -> torch.nn.Module:
+    def module(self) -> torch.fx.GraphModule:
         """
         Returns a self contained GraphModule with all the parameters/buffers inlined.
         """