mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Remove special handling of torchvision.ops imports in onnx export (#141569)
Fixes #141568 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141569 Approved by: https://github.com/titaiwangms Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: Ti-Tai Wang <titaiwang@microsoft.com>
This commit is contained in:
parent
6d204cb5ed
commit
f4187050fe
|
|
@ -3,6 +3,8 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import torchvision
|
||||
|
||||
import torch
|
||||
from torch.onnx._internal.exporter import _testing as onnx_testing
|
||||
from torch.testing._internal import common_utils
|
||||
|
|
@ -125,6 +127,23 @@ class DynamoExporterTest(common_utils.TestCase):
|
|||
onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([0, 0]),))
|
||||
onnx_testing.assert_onnx_program(onnx_program, args=(torch.tensor([43, 43]),))
|
||||
|
||||
def test_onnx_export_torchvision_ops(self):
|
||||
class VisionModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, *x):
|
||||
out = torchvision.ops.nms(x[0], x[1], x[2])
|
||||
return out
|
||||
|
||||
args = (
|
||||
torch.tensor([[0, 0, 1, 1], [0.5, 0.5, 1, 1]], dtype=torch.float),
|
||||
torch.tensor([0.1, 0.2]),
|
||||
0,
|
||||
)
|
||||
onnx_program = torch.onnx.export(VisionModel(), args, dynamo=True)
|
||||
onnx_testing.assert_onnx_program(onnx_program)
|
||||
|
||||
# TODO(justinchuby): Test multi-output HOPs
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ https://github.com/pytorch/pytorch/blob/6aa5bb1a76dee8112f1a9e7c194c790b5cdc6462
|
|||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import importlib.util
|
||||
import logging
|
||||
import math
|
||||
import operator
|
||||
|
|
@ -61,18 +62,9 @@ def _get_overload(qualified_name: str) -> torch._ops.OpOverload | None:
|
|||
if namespace == "math":
|
||||
return getattr(math, op_name)
|
||||
if namespace == "torchvision":
|
||||
try:
|
||||
import torchvision.ops # type: ignore[import-untyped]
|
||||
except ImportError:
|
||||
if importlib.util.find_spec("torchvision") is None:
|
||||
logger.warning("torchvision is not installed. Skipping %s", qualified_name)
|
||||
return None
|
||||
try:
|
||||
return getattr(torchvision.ops, op_name)
|
||||
except AttributeError:
|
||||
logger.warning("Failed to find torchvision op '%s'", qualified_name)
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("Failed to find torchvision op '%s'", qualified_name)
|
||||
try:
|
||||
op_packet = getattr(getattr(torch.ops, namespace), op_name)
|
||||
if maybe_overload:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user