[TS2E] Remove reference to torch.onnx internals (#132186)

Instead, this PR moves the code to the converter to avoid dependence. Feel free to refactor it afterward.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132186
Approved by: https://github.com/angelayi
This commit is contained in:
Justin Chu 2024-08-01 15:08:02 +00:00 committed by PyTorch MergeBot
parent d7d6190493
commit 0d88dd0f77

View File

@ -2,12 +2,14 @@
import builtins
import logging
import operator
import typing
import warnings
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
import torch
import torch.export._trace
from torch import _C
from torch.export.exported_program import ExportedProgram
from torch.export.graph_signature import (
ConstantArgument,
@ -18,12 +20,101 @@ from torch.export.graph_signature import (
TensorArgument,
)
from torch.fx import subgraph_rewriter
from torch.onnx.utils import _create_jit_graph
log = logging.getLogger(__name__)
def _get_param_count_list(method_graph, args_params):
param_count_list = []
for input_, arg_params_ in zip(method_graph.inputs(), args_params):
if "PackedParams" in str(input_.type()):
in_vars, _ = torch.jit._flatten(arg_params_)
param_count_list.append(len(in_vars))
else:
param_count_list.append(arg_params_ is not None)
return param_count_list
def _trace_and_get_graph_from_model(model, args):
# A basic sanity check: make sure the state_dict keys are the same
# before and after running the model. Fail fast!
orig_state_dict_keys = torch.jit._unique_state_dict(model).keys()
# Disable Autocast cache because it replaces kernel's weight and bias
# by (undesired) constants.
# No perf impact for when there are reused weights since https://github.com/pytorch/pytorch/pull/85665
prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
torch.set_autocast_cache_enabled(False)
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
model,
args,
strict=False,
_force_outplace=False,
_return_inputs_states=True,
)
torch.set_autocast_cache_enabled(prev_autocast_cache_enabled)
if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys():
raise RuntimeError(
"state_dict changed after running the tracer; "
"something weird is happening in your model!"
)
return trace_graph, torch_out
def _create_jit_graph(
model: Union[torch.nn.Module, torch.jit.ScriptFunction], args: Sequence[Any]
) -> Tuple[torch.Graph, List["_C.IValue"], Any, Optional[torch.ScriptModule]]:
if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)):
flattened_args = tuple(torch.jit._flatten(tuple(args))[0])
torch_out = None
if isinstance(model, torch.jit.ScriptModule):
try:
graph = model.forward.graph # type: ignore[attr-defined]
except AttributeError as e:
raise RuntimeError("'forward' method must be a script method") from e
_C._jit_pass_onnx_function_substitution(graph)
freezed_module = _C._freeze_module(
typing.cast(_C.ScriptModule, model._c), preserveParameters=True
)
module, params = _C._jit_onnx_list_model_parameters(freezed_module)
method_graph = module._get_method("forward").graph
args_params = tuple(args) + tuple(params)
param_count_list = _get_param_count_list(method_graph, args_params)
in_vars, _ = torch.jit._flatten(args_params)
graph = _C._propagate_and_assign_input_shapes(
method_graph, tuple(in_vars), param_count_list, False, False
)
return graph, params, torch_out, module
# torch.jit.ScriptFunction
params = []
graph = model.graph
_C._jit_pass_onnx_function_substitution(graph)
param_count_list = _get_param_count_list(graph, args)
graph = _C._propagate_and_assign_input_shapes(
graph, flattened_args, param_count_list, False, False
)
return graph, params, torch_out, None
graph, torch_out = _trace_and_get_graph_from_model(model, args)
_C._jit_pass_onnx_lint(graph)
state_dict = torch.jit._unique_state_dict(model)
params = list(state_dict.values())
graph_inputs = list(graph.inputs())
user_input_num = len(graph_inputs) - len(state_dict)
param_names = list(state_dict.keys())
for i, inp in enumerate(graph_inputs):
if i >= user_input_num:
inp.setDebugName(param_names[i - user_input_num])
_C._jit_pass_onnx_function_substitution(graph)
return graph, params, torch_out, None
def inplace_optimize_sym_size_div(gm: torch.fx.GraphModule):
def pattern(im, dim, scale):
sym_size_int = torch.ops.aten.sym_size.int(im, dim)