[TS2E] Remove reference to torch.onnx internals (#132186)
Instead, this PR moves the code into the converter to avoid the dependency. Feel free to refactor it afterward.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132186
Approved by: https://github.com/angelayi
This commit is contained in:
parent d7d6190493
commit 0d88dd0f77
@@ -2,12 +2,14 @@
import builtins
import logging
import operator
import typing
import warnings
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union

import torch
import torch.export._trace
from torch import _C
from torch.export.exported_program import ExportedProgram
from torch.export.graph_signature import (
    ConstantArgument,
@@ -18,12 +20,101 @@ from torch.export.graph_signature import (
    TensorArgument,
)
from torch.fx import subgraph_rewriter
from torch.onnx.utils import _create_jit_graph


log = logging.getLogger(__name__)


def _get_param_count_list(method_graph, args_params):
    param_count_list = []
    for input_, arg_params_ in zip(method_graph.inputs(), args_params):
        if "PackedParams" in str(input_.type()):
            in_vars, _ = torch.jit._flatten(arg_params_)
            param_count_list.append(len(in_vars))
        else:
            param_count_list.append(arg_params_ is not None)

    return param_count_list


def _trace_and_get_graph_from_model(model, args):
    # A basic sanity check: make sure the state_dict keys are the same
    # before and after running the model. Fail fast!
    orig_state_dict_keys = torch.jit._unique_state_dict(model).keys()

    # Disable Autocast cache because it replaces kernel's weight and bias
    # by (undesired) constants.
    # No perf impact for when there are reused weights since https://github.com/pytorch/pytorch/pull/85665
    prev_autocast_cache_enabled = torch.is_autocast_cache_enabled()
    torch.set_autocast_cache_enabled(False)
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
        model,
        args,
        strict=False,
        _force_outplace=False,
        _return_inputs_states=True,
    )
    torch.set_autocast_cache_enabled(prev_autocast_cache_enabled)

    if orig_state_dict_keys != torch.jit._unique_state_dict(model).keys():
        raise RuntimeError(
            "state_dict changed after running the tracer; "
            "something weird is happening in your model!"
        )

    return trace_graph, torch_out


def _create_jit_graph(
    model: Union[torch.nn.Module, torch.jit.ScriptFunction], args: Sequence[Any]
) -> Tuple[torch.Graph, List["_C.IValue"], Any, Optional[torch.ScriptModule]]:
    if isinstance(model, (torch.jit.ScriptFunction, torch.jit.ScriptModule)):
        flattened_args = tuple(torch.jit._flatten(tuple(args))[0])
        torch_out = None

        if isinstance(model, torch.jit.ScriptModule):
            try:
                graph = model.forward.graph  # type: ignore[attr-defined]
            except AttributeError as e:
                raise RuntimeError("'forward' method must be a script method") from e
            _C._jit_pass_onnx_function_substitution(graph)
            freezed_module = _C._freeze_module(
                typing.cast(_C.ScriptModule, model._c), preserveParameters=True
            )
            module, params = _C._jit_onnx_list_model_parameters(freezed_module)
            method_graph = module._get_method("forward").graph
            args_params = tuple(args) + tuple(params)
            param_count_list = _get_param_count_list(method_graph, args_params)
            in_vars, _ = torch.jit._flatten(args_params)
            graph = _C._propagate_and_assign_input_shapes(
                method_graph, tuple(in_vars), param_count_list, False, False
            )
            return graph, params, torch_out, module

        # torch.jit.ScriptFunction
        params = []
        graph = model.graph
        _C._jit_pass_onnx_function_substitution(graph)
        param_count_list = _get_param_count_list(graph, args)
        graph = _C._propagate_and_assign_input_shapes(
            graph, flattened_args, param_count_list, False, False
        )
        return graph, params, torch_out, None

    graph, torch_out = _trace_and_get_graph_from_model(model, args)
    _C._jit_pass_onnx_lint(graph)
    state_dict = torch.jit._unique_state_dict(model)
    params = list(state_dict.values())
    graph_inputs = list(graph.inputs())
    user_input_num = len(graph_inputs) - len(state_dict)
    param_names = list(state_dict.keys())
    for i, inp in enumerate(graph_inputs):
        if i >= user_input_num:
            inp.setDebugName(param_names[i - user_input_num])
    _C._jit_pass_onnx_function_substitution(graph)
    return graph, params, torch_out, None


def inplace_optimize_sym_size_div(gm: torch.fx.GraphModule):
    def pattern(im, dim, scale):
        sym_size_int = torch.ops.aten.sym_size.int(im, dim)
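For context, a minimal sketch (not part of the PR) of how the relocated _create_jit_graph helper could be exercised. The toy module, the example inputs, and the assumption that the three helpers above are in scope (e.g. defined in the same converter file) are illustrative only.

# Minimal usage sketch, assuming _create_jit_graph and its helpers above are in scope.
import torch


class Add(torch.nn.Module):  # hypothetical toy module for illustration
    def forward(self, x, y):
        return x + y


example_args = (torch.randn(2, 3), torch.randn(2, 3))

# ScriptModule path: the module is frozen, its parameters are listed, and
# input shapes are propagated into the TorchScript graph.
scripted = torch.jit.script(Add().eval())
graph, params, torch_out, frozen = _create_jit_graph(scripted, example_args)
print(graph)        # TorchScript IR with shape information
print(len(params))  # flattened parameters of the frozen module (0 for Add)
print(torch_out)    # None on the scripted path

# Plain nn.Module path: the model is traced via _trace_and_get_graph_from_model,
# and parameter inputs in the graph are renamed after the state_dict keys.
graph, params, torch_out, frozen = _create_jit_graph(Add(), example_args)
print(torch_out)    # eager output captured while tracing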