[ONNX] Remove special handling of torchvision.ops imports in onnx export (#141569)

Fixes #141568

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141569
Approved by: https://github.com/titaiwangms

Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
Co-authored-by: Ti-Tai Wang <titaiwang@microsoft.com>
This commit is contained in:
Bludator 2024-11-28 18:05:38 +00:00 committed by PyTorch MergeBot
parent 6d204cb5ed
commit f4187050fe
2 changed files with 21 additions and 10 deletions

View File

@ -3,6 +3,8 @@
from __future__ import annotations from __future__ import annotations
import torchvision
import torch import torch
from torch.onnx._internal.exporter import _testing as onnx_testing from torch.onnx._internal.exporter import _testing as onnx_testing
from torch.testing._internal import common_utils from torch.testing._internal import common_utils
@ -125,6 +127,23 @@ class DynamoExporterTest(common_utils.TestCase):
onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([0, 0]),)) onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([0, 0]),))
onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([43, 43]),)) onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([43, 43]),))
def test_onnx_export_torchvision_ops(self):
class VisionModel(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, *x):
out = torchvision.ops.nms(x[0], x[1], x[2])
return out
args = (
torch.tensor([[0, 0, 1, 1], [0.5, 0.5, 1, 1]], dtype=torch.float),
torch.tensor([0.1, 0.2]),
0,
)
onnx_program = torch.onnx.export(VisionModel(), args, dynamo=True)
onnx_testing.assert_onnx_program(onnx_program)
# TODO(justinchuby): Test multi-output HOPs # TODO(justinchuby): Test multi-output HOPs

View File

@ -13,6 +13,7 @@ https://github.com/pytorch/pytorch/blob/6aa5bb1a76dee8112f1a9e7c194c790b5cdc6462
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
import importlib.util
import logging import logging
import math import math
import operator import operator
@ -61,18 +62,9 @@ def _get_overload(qualified_name: str) -> torch._ops.OpOverload | None:
if namespace == "math": if namespace == "math":
return getattr(math, op_name) return getattr(math, op_name)
if namespace == "torchvision": if namespace == "torchvision":
try: if importlib.util.find_spec("torchvision") is None:
import torchvision.ops # type: ignore[import-untyped]
except ImportError:
logger.warning("torchvision is not installed. Skipping %s", qualified_name) logger.warning("torchvision is not installed. Skipping %s", qualified_name)
return None return None
try:
return getattr(torchvision.ops, op_name)
except AttributeError:
logger.warning("Failed to find torchvision op '%s'", qualified_name)
return None
except Exception:
logger.exception("Failed to find torchvision op '%s'", qualified_name)
try: try:
op_packet = getattr(getattr(torch.ops, namespace), op_name) op_packet = getattr(getattr(torch.ops, namespace), op_name)
if maybe_overload: if maybe_overload: