[ONNX][dynamo_export] Decomposition skips using custom operator (#117314)
A context manager that disables the decomposition of certain ops during dynamo tracing. The approach is to temporarily hijack the operator callable with a PT2 custom operator. The custom operator is not decomposed and shows up as a single node to be exported to ONNX. For the time being, the decomposition of these ops is otherwise unavoidable:

https://github.com/pytorch/pytorch/issues/116684
https://github.com/pytorch/pytorch/issues/115883

This workaround will no longer be required once those issues are resolved.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117314
Approved by: https://github.com/justinchuby, https://github.com/malfet
parent 92d718aed1
commit 6d9432c44c
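The mechanism behind this change is the standard PT2 custom-operator registration flow, with the original callable kept around as the eager implementation. A minimal sketch of the pattern, using illustrative names ("my_op" and original_callable are not part of this PR):

import torch

def original_callable(x):
    # Stand-in for the aten callable that would otherwise be decomposed.
    return x + 1

# Define a custom op; dynamo traces calls to it as a single node.
torch.library.define("onnx_export::my_op", "(Tensor self) -> Tensor")
# Eager behavior simply forwards to the original callable.
torch.library.impl("onnx_export::my_op", "default", original_callable)
# Meta kernel so fake-tensor tracing can infer output shapes without running the op.
torch.library.impl_abstract("onnx_export::my_op", lambda self: torch.empty_like(self))

print(torch.ops.onnx_export.my_op(torch.randn(2)))  # behaves like original_callable

The decomposition_skip module added below applies this pattern to torch._C._nn.upsample_bilinear2d.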
@@ -32,7 +32,7 @@ pip_install coloredlogs packaging
 retry pip_install -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ --no-cache-dir --no-input ort-nightly==1.17.0.dev20231005006
 
 pip_install -i https://test.pypi.org/simple/ onnx==1.15.0rc2
-pip_install onnxscript==0.1.0.dev20231222 --no-deps
+pip_install onnxscript==0.1.0.dev20240117 --no-deps
 
 # Cache the transformers model to be used later by ONNX tests. We need to run the transformers
 # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
test/onnx/test_fx_to_onnx_decomp_skip.py (new file, 36 lines)

@@ -0,0 +1,36 @@
# Owner(s): ["module: onnx"]
from __future__ import annotations

import onnx
import onnx.inliner
import pytorch_test_common

import torch
from torch.testing._internal import common_utils


def assert_op_in_onnx_model(model: onnx.ModelProto, op_type: str):
    inlined = onnx.inliner.inline_local_functions(model)
    for node in inlined.graph.node:
        if node.op_type == op_type:
            return
    raise AssertionError(f"Op {op_type} not found in model")


class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
    def test_upsample_bilinear2d(self):
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.upsample = torch.nn.Upsample(scale_factor=2, mode="bilinear")

            def forward(self, x):
                return self.upsample(x)

        onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2))
        # If decomposition is skipped, the model will contain a Resize op
        # instead of a fine-grained subgraph.
        assert_op_in_onnx_model(onnx_program.model_proto, "Resize")


if __name__ == "__main__":
    common_utils.run_tests()
@@ -294,6 +294,7 @@ class TestPublicBindings(TestCase):
         "torch._inductor.codegen.cuda.cuda_kernel",
         "torch.onnx._internal.fx._pass",
         "torch.onnx._internal.fx.analysis",
+        "torch.onnx._internal.fx.decomposition_skip",
         "torch.onnx._internal.fx.diagnostics",
         "torch.onnx._internal.fx.fx_onnx_interpreter",
         "torch.onnx._internal.fx.fx_symbolic_graph_extractor",
@@ -1170,7 +1170,13 @@ class Exporter:
         self._assert_fake_tensor_mode()
 
     def export(self) -> ONNXProgram:
-        with self.options.diagnostic_context:
+        # TODO: Defer `import onnxscript` out of `import torch` path
+        # https://github.com/pytorch/pytorch/issues/103764
+        from torch.onnx._internal.fx import decomposition_skip
+
+        with self.options.diagnostic_context, decomposition_skip.enable_decomposition_skips(
+            self.options
+        ):
             graph_module = self.options.fx_tracer.generate_fx(
                 self.options, self.model, self.model_args, self.model_kwargs
             )
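With this change the skip is active for every dynamo_export call; no new public API is involved. A minimal sketch of the user-visible effect (this mirrors the new test added above):

import torch

model = torch.nn.Upsample(scale_factor=2, mode="bilinear")
onnx_program = torch.onnx.dynamo_export(model, torch.randn(1, 1, 2, 2))
# The exported model contains a single Resize node instead of the decomposed subgraph.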
torch/onnx/_internal/fx/decomposition_skip.py (new file, 139 lines)

@@ -0,0 +1,139 @@
"""A context manager that disables the decomposition of certain ops during dynamo tracing.

The approach is to temporarily hijack the operator callable with a PT2 custom operator.
The custom operator will not be decomposed and will show up as a single node to be exported to ONNX.

For the time being, the decomposition of these ops is otherwise unavoidable.

https://github.com/pytorch/pytorch/issues/116684
https://github.com/pytorch/pytorch/issues/115883

This solution will no longer be required once the issues are resolved.
"""
from __future__ import annotations

import abc
import contextlib

from typing import Callable, Sequence, Type

from onnxscript.function_libs.torch_lib.ops import (  # type: ignore[import-not-found]
    nn as torchlib_nn,
)

import torch
from torch._decomp import decompositions

_NEW_OP_NAMESPACE: str = "onnx_export"
"""The namespace for the custom operator."""


class DecompSkip:
    op_callable: Callable
    """The original operator callable to skip decomposition."""
    onnxscript_function: Callable
    """The ONNXScript function to be registered for exporting the custom operator."""

    new_op_name: str
    """The name for the custom operator."""
    new_op_schema: str
    """The schema for the custom operator. This should match the signature of the original operator."""

    @classmethod
    @abc.abstractmethod
    def register(cls, export_options: torch.onnx.ExportOptions):
        """Registers the custom operator and overrides the original operator.

        It should do the following steps in order:

        1. Register the custom operator.
        2. Override the original operator with the replacement callable.
        3. Register the ONNXScript function for exporting the custom operator.
        """
        ...

    @classmethod
    @abc.abstractmethod
    def unregister(cls):
        """Restores the original operator callable."""
        ...

    @classmethod
    @abc.abstractmethod
    def abstract(cls, *args, **kwargs):
        """An abstract impl (meta kernel) for the operator."""
        ...

    @classmethod
    def register_custom_op(cls):
        """Registers the custom operator."""
        new_op_qualname = f"{_NEW_OP_NAMESPACE}::{cls.new_op_name}"
        torch.library.define(new_op_qualname, cls.new_op_schema)
        torch.library.impl(new_op_qualname, "default", cls.replacement)
        torch.library.impl_abstract(new_op_qualname, cls.abstract)

    @classmethod
    def replacement(cls, *args, **kwargs):
        """A replacement callable for the operator to be hijacked.

        This has the same signature and eager behavior as the original operator.
        """
        return cls.op_callable(*args, **kwargs)


class UpsampleBilinear2DDecompSkip(DecompSkip):
    op_callable = torch._C._nn.upsample_bilinear2d  # type: ignore[attr-defined]
    onnxscript_function = torchlib_nn.aten_upsample_bilinear2d_vec  # type: ignore[attr-defined]
    new_op_name = "upsample_bilinear2d"
    new_op_schema = "(Tensor self, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)"

    @classmethod
    def register(cls, export_options: torch.onnx.ExportOptions):
        if not hasattr(torch.ops, _NEW_OP_NAMESPACE) or not hasattr(
            torch.ops.onnx_export, cls.new_op_name
        ):
            cls.register_custom_op()
        torch._C._nn.upsample_bilinear2d = torch.ops.onnx_export.upsample_bilinear2d  # type: ignore[attr-defined]
        if export_options.onnx_registry is None:
            export_options.onnx_registry = torch.onnx.OnnxRegistry()
        registry = export_options.onnx_registry
        registry.register_op(
            function=cls.onnxscript_function,
            namespace=_NEW_OP_NAMESPACE,
            op_name=cls.new_op_name,
        )

    @classmethod
    def unregister(cls):
        torch._C._nn.upsample_bilinear2d = cls.op_callable  # type: ignore[attr-defined]

    @classmethod
    def abstract(cls, input, output_size, align_corners, scale_factors):
        osize = decompositions.upsample_compute_output_size(
            input.size(), output_size, scale_factors
        )
        return torch.empty(osize, dtype=input.dtype, device=input.device)


_DEFAULT_SKIP_LIST = [
    UpsampleBilinear2DDecompSkip,
]


@contextlib.contextmanager
def enable_decomposition_skips(
    export_options: torch.onnx.ExportOptions,
    skips: Sequence[Type[DecompSkip]] = _DEFAULT_SKIP_LIST,
):
    """A context manager that enables the decomposition skips.

    The original operator callables that are otherwise decomposed are replaced with custom operators.
    The ONNXScript functions for exporting the custom operators are added to the ONNX registry inside export_options.
    """
    try:
        for skip in skips:
            skip.register(export_options)
        yield
    finally:
        for skip in skips:
            skip.unregister()
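For reference, a minimal sketch of how the exporter drives this module (mirroring the Exporter.export hunk above; the ellipsis stands in for the actual tracing work):

import torch
from torch.onnx._internal.fx import decomposition_skip

export_options = torch.onnx.ExportOptions()
with decomposition_skip.enable_decomposition_skips(export_options):
    # torch._C._nn.upsample_bilinear2d now resolves to torch.ops.onnx_export.upsample_bilinear2d,
    # and export_options.onnx_registry can export the custom op via the registered ONNXScript function.
    ...
# On exit, the original callable is restored even if tracing raised.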