mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Remove special handling of torchvision.ops imports in onnx export (#141569)
Fixes #141568 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141569 Approved by: https://github.com/titaiwangms Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: Ti-Tai Wang <titaiwang@microsoft.com>
This commit is contained in:
parent
6d204cb5ed
commit
f4187050fe
|
|
@ -3,6 +3,8 @@
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torchvision
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.onnx._internal.exporter import _testing as onnx_testing
|
from torch.onnx._internal.exporter import _testing as onnx_testing
|
||||||
from torch.testing._internal import common_utils
|
from torch.testing._internal import common_utils
|
||||||
|
|
@ -125,6 +127,23 @@ class DynamoExporterTest(common_utils.TestCase):
|
||||||
onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([0, 0]),))
|
onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([0, 0]),))
|
||||||
onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([43, 43]),))
|
onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([43, 43]),))
|
||||||
|
|
||||||
|
def test_onnx_export_torchvision_ops(self):
|
||||||
|
class VisionModel(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, *x):
|
||||||
|
out = torchvision.ops.nms(x[0], x[1], x[2])
|
||||||
|
return out
|
||||||
|
|
||||||
|
args = (
|
||||||
|
torch.tensor([[0, 0, 1, 1], [0.5, 0.5, 1, 1]], dtype=torch.float),
|
||||||
|
torch.tensor([0.1, 0.2]),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
onnx_program = torch.onnx.export(VisionModel(), args, dynamo=True)
|
||||||
|
onnx_testing.assert_onnx_program(onnx_program)
|
||||||
|
|
||||||
# TODO(justinchuby): Test multi-output HOPs
|
# TODO(justinchuby): Test multi-output HOPs
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ https://github.com/pytorch/pytorch/blob/6aa5bb1a76dee8112f1a9e7c194c790b5cdc6462
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import importlib.util
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import operator
|
import operator
|
||||||
|
|
@ -61,18 +62,9 @@ def _get_overload(qualified_name: str) -> torch._ops.OpOverload | None:
|
||||||
if namespace == "math":
|
if namespace == "math":
|
||||||
return getattr(math, op_name)
|
return getattr(math, op_name)
|
||||||
if namespace == "torchvision":
|
if namespace == "torchvision":
|
||||||
try:
|
if importlib.util.find_spec("torchvision") is None:
|
||||||
import torchvision.ops # type: ignore[import-untyped]
|
|
||||||
except ImportError:
|
|
||||||
logger.warning("torchvision is not installed. Skipping %s", qualified_name)
|
logger.warning("torchvision is not installed. Skipping %s", qualified_name)
|
||||||
return None
|
return None
|
||||||
try:
|
|
||||||
return getattr(torchvision.ops, op_name)
|
|
||||||
except AttributeError:
|
|
||||||
logger.warning("Failed to find torchvision op '%s'", qualified_name)
|
|
||||||
return None
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Failed to find torchvision op '%s'", qualified_name)
|
|
||||||
try:
|
try:
|
||||||
op_packet = getattr(getattr(torch.ops, namespace), op_name)
|
op_packet = getattr(getattr(torch.ops, namespace), op_name)
|
||||||
if maybe_overload:
|
if maybe_overload:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user