[ONNX][dynamo_export] Decomposition skips using custom operator (#117314)

A context manager that disables the decomposition of certain ops during dynamo tracing.

The approach is to temporarily hijack the operator callable with PT2 custom operator.
The custom operator will not be decomposed and will show up as a single node to be exported to ONNX.

For the time being the decomposition of these ops is otherwise unavoidable.

https://github.com/pytorch/pytorch/issues/116684
https://github.com/pytorch/pytorch/issues/115883

This solution will no longer be required once the issue is resolved.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117314
Approved by: https://github.com/justinchuby, https://github.com/malfet
This commit is contained in:
BowenBao 2024-01-17 09:47:01 -08:00 committed by PyTorch MergeBot
parent 92d718aed1
commit 6d9432c44c
5 changed files with 184 additions and 2 deletions

View File

@ -32,7 +32,7 @@ pip_install coloredlogs packaging
retry pip_install -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ --no-cache-dir --no-input ort-nightly==1.17.0.dev20231005006
pip_install -i https://test.pypi.org/simple/ onnx==1.15.0rc2
pip_install onnxscript==0.1.0.dev20231222 --no-deps
pip_install onnxscript==0.1.0.dev20240117 --no-deps
# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/

View File

@ -0,0 +1,36 @@
# Owner(s): ["module: onnx"]
from __future__ import annotations
import onnx
import onnx.inliner
import pytorch_test_common
import torch
from torch.testing._internal import common_utils
def assert_op_in_onnx_model(model: onnx.ModelProto, op_type: str):
inlined = onnx.inliner.inline_local_functions(model)
for node in inlined.graph.node:
if node.op_type == op_type:
return
raise AssertionError(f"Op {op_type} not found in model")
class TestDynamoExportDecompSkip(pytorch_test_common.ExportTestCase):
def test_upsample_bilinear2d(self):
class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.upsample = torch.nn.Upsample(scale_factor=2, mode="bilinear")
def forward(self, x):
return self.upsample(x)
onnx_program = torch.onnx.dynamo_export(TestModel(), torch.randn(1, 1, 2, 2))
# If decomposition is skipped, the model will contain a Resize op instead of fine grained subgraph.
assert_op_in_onnx_model(onnx_program.model_proto, "Resize")
if __name__ == "__main__":
common_utils.run_tests()

View File

@ -294,6 +294,7 @@ class TestPublicBindings(TestCase):
"torch._inductor.codegen.cuda.cuda_kernel",
"torch.onnx._internal.fx._pass",
"torch.onnx._internal.fx.analysis",
"torch.onnx._internal.fx.decomposition_skip",
"torch.onnx._internal.fx.diagnostics",
"torch.onnx._internal.fx.fx_onnx_interpreter",
"torch.onnx._internal.fx.fx_symbolic_graph_extractor",

View File

@ -1170,7 +1170,13 @@ class Exporter:
self._assert_fake_tensor_mode()
def export(self) -> ONNXProgram:
with self.options.diagnostic_context:
# TODO: Defer `import onnxscript` out of `import torch` path
# https://github.com/pytorch/pytorch/issues/103764
from torch.onnx._internal.fx import decomposition_skip
with self.options.diagnostic_context, decomposition_skip.enable_decomposition_skips(
self.options
):
graph_module = self.options.fx_tracer.generate_fx(
self.options, self.model, self.model_args, self.model_kwargs
)

View File

@ -0,0 +1,139 @@
"""A context manager that disables the decomposition of certain ops during dynamo tracing.
The approach is to temporarily hijack the operator callable with PT2 custom operator.
The custom operator will not be decomposed and will show up as a single node to be exported to ONNX.
For the time being the decomposition of these ops is otherwise unavoidable.
https://github.com/pytorch/pytorch/issues/116684
https://github.com/pytorch/pytorch/issues/115883
This solution will no longer be required once the issue is resolved.
"""
from __future__ import annotations
import abc
import contextlib
from typing import Callable, Sequence, Type
from onnxscript.function_libs.torch_lib.ops import ( # type: ignore[import-not-found]
nn as torchlib_nn,
)
import torch
from torch._decomp import decompositions
_NEW_OP_NAMESPACE: str = "onnx_export"
"""The namespace for the custom operator."""
class DecompSkip:
op_callable: Callable
"""The original operator callable to skip decomposition."""
onnxscript_function: Callable
"""The ONNXScript function to be registered for exporting the custom operator."""
new_op_name: str
"""The name for the custom operator."""
new_op_schema: str
"""The schema for the custom operator. This should match with the signature of the original operator."""
@classmethod
@abc.abstractmethod
def register(cls, export_options: torch.onnx.ExportOptions):
"""Registers the custom operator and overrides the original operator.
It should do the following steps in order:
1. Register the custom operator.
2. Override the original operator with the replacement callable.
3. Register the ONNXScript function for exporting the custom operator.
"""
...
@classmethod
@abc.abstractmethod
def unregister(cls):
"""Restores the original operator callable."""
...
@classmethod
@abc.abstractmethod
def abstract(cls, *args, **kwargs):
"""An abstract impl (meta kernel) for the operator."""
...
@classmethod
def register_custom_op(cls):
"""Registers the custom operator."""
new_op_qualname = f"{_NEW_OP_NAMESPACE}::{cls.new_op_name}"
torch.library.define(new_op_qualname, cls.new_op_schema)
torch.library.impl(new_op_qualname, "default", cls.replacement)
torch.library.impl_abstract(new_op_qualname, cls.abstract)
@classmethod
def replacement(cls, *args, **kwargs):
"""A replacement callable for the operator to be hijacked.
This has the same signature and eager behavior as the original operator.
"""
return cls.op_callable(*args, **kwargs)
class UpsampleBilinear2DDecompSkip(DecompSkip):
op_callable = torch._C._nn.upsample_bilinear2d # type: ignore[attr-defined]
onnxscript_function = torchlib_nn.aten_upsample_bilinear2d_vec # type: ignore[attr-defined]
new_op_name = "upsample_bilinear2d"
new_op_schema = "(Tensor self, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)"
@classmethod
def register(cls, export_options: torch.onnx.ExportOptions):
if not hasattr(torch.ops, _NEW_OP_NAMESPACE) or not hasattr(
torch.ops.onnx_export, cls.new_op_name
):
cls.register_custom_op()
torch._C._nn.upsample_bilinear2d = torch.ops.onnx_export.upsample_bilinear2d # type: ignore[attr-defined]
if export_options.onnx_registry is None:
export_options.onnx_registry = torch.onnx.OnnxRegistry()
registry = export_options.onnx_registry
registry.register_op(
function=cls.onnxscript_function,
namespace=_NEW_OP_NAMESPACE,
op_name=cls.new_op_name,
)
@classmethod
def unregister(cls):
torch._C._nn.upsample_bilinear2d = cls.op_callable # type: ignore[attr-defined]
@classmethod
def abstract(cls, input, output_size, align_corners, scale_factors):
osize = decompositions.upsample_compute_output_size(
input.size(), output_size, scale_factors
)
return torch.empty(osize, dtype=input.dtype, device=input.device)
_DEFAULT_SKIP_LIST = [
UpsampleBilinear2DDecompSkip,
]
@contextlib.contextmanager
def enable_decomposition_skips(
export_options: torch.onnx.ExportOptions,
skips: Sequence[Type[DecompSkip]] = _DEFAULT_SKIP_LIST,
):
"""A context manager that enables the decomposition skips.
The original operator callables that are otherwise decomposed are replaced with custom operators.
The ONNXScript functions for exporting the custom operators are added to the ONNX registry inside export_options.
"""
try:
for skip in skips:
skip.register(export_options)
yield
finally:
for skip in skips:
skip.unregister()