mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ONNX] Create onnx_symbolic (#148905)
In the old exporter we allow users to define a symbolic() method to bypass JIT tracing for a block of logic. We can allow users to do similar things by creating symbolic ops at export. This PR implements `torch.onnx.ops.symbolic` and `torch.onnx.ops.symbolic_multi_out` to allow users to create onnx nodes symbolically with pt2 & fx. The custom pytorch ops were designed such that the attributes are encoded to be part of a valid fx op. Users provide shape and dtype for the meta function to produce the currect fake tensor during export. An example is  Pull Request resolved: https://github.com/pytorch/pytorch/pull/148905 Approved by: https://github.com/titaiwangms
This commit is contained in:
parent
d80a70b58a
commit
010963032c
|
|
@ -88,6 +88,7 @@ also be interested in reading our `development wiki <https://github.com/pytorch/
|
||||||
:hidden:
|
:hidden:
|
||||||
|
|
||||||
onnx_dynamo
|
onnx_dynamo
|
||||||
|
onnx_ops
|
||||||
onnx_verification
|
onnx_verification
|
||||||
onnx_dynamo_onnxruntime_backend
|
onnx_dynamo_onnxruntime_backend
|
||||||
onnx_torchscript
|
onnx_torchscript
|
||||||
|
|
|
||||||
11
docs/source/onnx_ops.rst
Normal file
11
docs/source/onnx_ops.rst
Normal file
|
|
@ -0,0 +1,11 @@
|
||||||
|
torch.onnx.ops
|
||||||
|
==============
|
||||||
|
|
||||||
|
.. automodule:: torch.onnx.ops
|
||||||
|
|
||||||
|
Operators
|
||||||
|
---------
|
||||||
|
|
||||||
|
.. autofunction:: torch.onnx.ops.symbolic
|
||||||
|
|
||||||
|
.. autofunction:: torch.onnx.ops.symbolic_multi_out
|
||||||
418
test/onnx/ops/test_ops.py
Normal file
418
test/onnx/ops/test_ops.py
Normal file
|
|
@ -0,0 +1,418 @@
|
||||||
|
# Owner(s): ["module: onnx"]
|
||||||
|
"""Test torch.onnx.ops."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from onnxscript import ir
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.onnx.ops import _symbolic_impl
|
||||||
|
from torch.testing._internal import common_utils
|
||||||
|
|
||||||
|
|
||||||
|
class SchemaTest(common_utils.TestCase):
|
||||||
|
def test_symbolic_has_correct_schema(self):
|
||||||
|
torch.library.opcheck(
|
||||||
|
_symbolic_impl._symbolic,
|
||||||
|
([torch.tensor(1)], "CustomOp", 1),
|
||||||
|
dict(
|
||||||
|
shape=[
|
||||||
|
1,
|
||||||
|
],
|
||||||
|
attr_keys=["key"],
|
||||||
|
attr_types=["i"],
|
||||||
|
attr_pos=[(0, 1)],
|
||||||
|
attr_ints=[1],
|
||||||
|
attr_floats=[1.0],
|
||||||
|
attr_strs=["attr"],
|
||||||
|
metadata_props_keys=["meta_key"],
|
||||||
|
metadata_props_values=["meta_value"],
|
||||||
|
domain="custom_domain",
|
||||||
|
version=42,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Empty inputs
|
||||||
|
torch.library.opcheck(
|
||||||
|
_symbolic_impl._symbolic,
|
||||||
|
([], "CustomOp", 1),
|
||||||
|
dict(
|
||||||
|
shape=[
|
||||||
|
1,
|
||||||
|
],
|
||||||
|
attr_keys=[],
|
||||||
|
attr_types=[],
|
||||||
|
attr_pos=[],
|
||||||
|
attr_ints=[],
|
||||||
|
attr_floats=[],
|
||||||
|
attr_strs=[],
|
||||||
|
metadata_props_keys=[],
|
||||||
|
metadata_props_values=[],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_symbolic_multi_out_has_correct_schema(self):
|
||||||
|
torch.library.opcheck(
|
||||||
|
_symbolic_impl._symbolic_multi_out,
|
||||||
|
([torch.tensor(1)], "CustomMultiOutOp", [1, 2, 10]),
|
||||||
|
dict(
|
||||||
|
shapes=[[1, 2], [42], []],
|
||||||
|
attr_keys=["key"],
|
||||||
|
attr_types=["i"],
|
||||||
|
attr_pos=[(0, 1)],
|
||||||
|
attr_ints=[1],
|
||||||
|
attr_floats=[1.0],
|
||||||
|
attr_strs=["attr"],
|
||||||
|
metadata_props_keys=["meta_key"],
|
||||||
|
metadata_props_values=["meta_value"],
|
||||||
|
domain="",
|
||||||
|
version=1,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Empty inputs
|
||||||
|
torch.library.opcheck(
|
||||||
|
_symbolic_impl._symbolic_multi_out,
|
||||||
|
([], "CustomMultiOutOp", []),
|
||||||
|
dict(
|
||||||
|
shapes=[],
|
||||||
|
attr_keys=[],
|
||||||
|
attr_types=[],
|
||||||
|
attr_pos=[],
|
||||||
|
attr_ints=[],
|
||||||
|
attr_floats=[],
|
||||||
|
attr_strs=[],
|
||||||
|
metadata_props_keys=[],
|
||||||
|
metadata_props_values=[],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SymbolicOpsTest(common_utils.TestCase):
|
||||||
|
def test_symbolic_accepts_valid_inputs(self):
|
||||||
|
output = torch.onnx.ops.symbolic(
|
||||||
|
"custom_domain::CustomOp",
|
||||||
|
(torch.tensor(1),),
|
||||||
|
dict(
|
||||||
|
int_key=1,
|
||||||
|
float_key=1.0,
|
||||||
|
str_key="attr",
|
||||||
|
bool_key=True,
|
||||||
|
list_int_key=[1, 2],
|
||||||
|
list_float_key=[1.0, 2.0],
|
||||||
|
list_str_key=["attr1", "attr2"],
|
||||||
|
list_bool_key=[True, False],
|
||||||
|
),
|
||||||
|
dtype=torch.float32,
|
||||||
|
shape=[1, 2, 3],
|
||||||
|
version=1,
|
||||||
|
metadata_props={"meta_key": "meta_value"},
|
||||||
|
)
|
||||||
|
self.assertEqual(output.shape, torch.Size([1, 2, 3]))
|
||||||
|
self.assertEqual(output.dtype, torch.float32)
|
||||||
|
self.assertEqual(output.device, torch.device("cpu"))
|
||||||
|
|
||||||
|
def test_symbolic_accepts_valid_inputs_empty_shape(self):
|
||||||
|
output = torch.onnx.ops.symbolic(
|
||||||
|
"custom_domain::CustomOp",
|
||||||
|
(torch.tensor(1),),
|
||||||
|
dtype=torch.float32,
|
||||||
|
shape=[],
|
||||||
|
)
|
||||||
|
self.assertEqual(output.shape, torch.Size([]))
|
||||||
|
|
||||||
|
def test_symbolic_accepts_valid_inputs_integer_types(self):
|
||||||
|
output = torch.onnx.ops.symbolic(
|
||||||
|
"custom_domain::CustomOp",
|
||||||
|
(torch.tensor(1),),
|
||||||
|
dtype=1, # 1 is float32 in ONNX
|
||||||
|
shape=[42],
|
||||||
|
)
|
||||||
|
self.assertEqual(output.dtype, torch.float32)
|
||||||
|
|
||||||
|
def test_symbolic_accepts_valid_inputs_int4_type(self):
|
||||||
|
output = torch.onnx.ops.symbolic(
|
||||||
|
"custom_domain::CustomOp",
|
||||||
|
(torch.tensor(1),),
|
||||||
|
dtype=22, # 22 is INT4 in ONNX
|
||||||
|
shape=[42],
|
||||||
|
)
|
||||||
|
# We use torch uint8 for int4
|
||||||
|
self.assertEqual(output.dtype, torch.uint8)
|
||||||
|
|
||||||
|
def test_symbolic_is_exportable(self):
|
||||||
|
class Model(torch.nn.Module):
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
return torch.onnx.ops.symbolic(
|
||||||
|
"custom_domain::CustomOp",
|
||||||
|
(x,),
|
||||||
|
dict(
|
||||||
|
int_key=1,
|
||||||
|
float_key=1.0,
|
||||||
|
str_key="attr",
|
||||||
|
bool_key=True,
|
||||||
|
list_int_key=[1, 2],
|
||||||
|
list_float_key=[1.0, 2.0],
|
||||||
|
list_str_key=["attr1", "attr2"],
|
||||||
|
list_bool_key=[True, False],
|
||||||
|
),
|
||||||
|
dtype=x.dtype,
|
||||||
|
shape=[1, 2, 3],
|
||||||
|
version=1,
|
||||||
|
metadata_props={"meta_key": "meta_value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
onnx_program = torch.onnx.export(
|
||||||
|
Model(), (torch.tensor(1),), dynamo=True, verbose=False
|
||||||
|
)
|
||||||
|
assert onnx_program is not None
|
||||||
|
node = onnx_program.model.graph.node(0)
|
||||||
|
self.assertEqual(node.op_type, "CustomOp")
|
||||||
|
self.assertEqual(node.domain, "custom_domain")
|
||||||
|
attributes = node.attributes
|
||||||
|
self.assertEqual(
|
||||||
|
attributes,
|
||||||
|
dict(
|
||||||
|
int_key=ir.AttrInt64("int_key", 1),
|
||||||
|
float_key=ir.AttrFloat32("float_key", 1.0),
|
||||||
|
str_key=ir.AttrString("str_key", "attr"),
|
||||||
|
bool_key=ir.AttrInt64("bool_key", 1),
|
||||||
|
list_int_key=ir.AttrInt64s("list_int_key", [1, 2]),
|
||||||
|
list_float_key=ir.AttrFloat32s("list_float_key", [1.0, 2.0]),
|
||||||
|
list_str_key=ir.AttrStrings("list_str_key", ["attr1", "attr2"]),
|
||||||
|
list_bool_key=ir.AttrInt64s("list_bool_key", [1, 0]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertEqual(node.metadata_props["meta_key"], "meta_value")
|
||||||
|
outputs = node.outputs
|
||||||
|
self.assertEqual(list(outputs[0].shape), [1, 2, 3])
|
||||||
|
self.assertEqual(outputs[0].dtype, ir.DataType.INT64)
|
||||||
|
|
||||||
|
def test_symbolic_preserves_dynamic_shapes(self):
|
||||||
|
class Model(torch.nn.Module):
|
||||||
|
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
||||||
|
return torch.onnx.ops.symbolic(
|
||||||
|
"custom_domain::CustomOp",
|
||||||
|
(x, y),
|
||||||
|
dtype=x.dtype,
|
||||||
|
shape=[*x.shape, *y.shape],
|
||||||
|
version=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
onnx_program = torch.onnx.export(
|
||||||
|
Model(),
|
||||||
|
(torch.zeros(2, 3), torch.zeros(1, 2)),
|
||||||
|
dynamic_shapes=({0: "batch"}, {1: "something_else"}),
|
||||||
|
dynamo=True,
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
assert onnx_program is not None
|
||||||
|
node = onnx_program.model.graph.node(0)
|
||||||
|
self.assertEqual(node.op_type, "CustomOp")
|
||||||
|
self.assertEqual(node.domain, "custom_domain")
|
||||||
|
inputs = onnx_program.model.graph.inputs
|
||||||
|
self.assertEqual(str(inputs[0].shape[0]), "batch")
|
||||||
|
self.assertEqual(inputs[0].shape[1], 3)
|
||||||
|
self.assertEqual(inputs[1].shape[0], 1)
|
||||||
|
self.assertEqual(str(inputs[1].shape[1]), "something_else")
|
||||||
|
outputs = node.outputs
|
||||||
|
self.assertEqual(str(outputs[0].shape[0]), "batch")
|
||||||
|
self.assertEqual(outputs[0].shape[1], 3)
|
||||||
|
self.assertEqual(outputs[0].shape[2], 1)
|
||||||
|
self.assertEqual(str(outputs[0].shape[3]), "something_else")
|
||||||
|
self.assertEqual(outputs[0].dtype, ir.DataType.FLOAT)
|
||||||
|
|
||||||
|
def test_symbolic_multi_out_accepts_valid_inputs(self):
|
||||||
|
outputs = torch.onnx.ops.symbolic_multi_out(
|
||||||
|
"custom_domain::CustomMultiOutOp",
|
||||||
|
(torch.tensor(1),),
|
||||||
|
dict(
|
||||||
|
int_key=1,
|
||||||
|
float_key=1.0,
|
||||||
|
str_key="attr",
|
||||||
|
bool_key=True,
|
||||||
|
list_int_key=[1, 2],
|
||||||
|
list_float_key=[1.0, 2.0],
|
||||||
|
list_str_key=["attr1", "attr2"],
|
||||||
|
list_bool_key=[True, False],
|
||||||
|
),
|
||||||
|
dtypes=(
|
||||||
|
1, # 1 is float32 in ONNX
|
||||||
|
torch.int32,
|
||||||
|
torch.float8_e4m3fn,
|
||||||
|
),
|
||||||
|
shapes=([1, 2], [42], []),
|
||||||
|
version=1,
|
||||||
|
metadata_props={"meta_key": "meta_value"},
|
||||||
|
)
|
||||||
|
self.assertEqual(len(outputs), 3)
|
||||||
|
self.assertEqual(outputs[0].shape, torch.Size([1, 2]))
|
||||||
|
self.assertEqual(outputs[0].dtype, torch.float32)
|
||||||
|
self.assertEqual(outputs[1].shape, torch.Size([42]))
|
||||||
|
self.assertEqual(outputs[1].dtype, torch.int32)
|
||||||
|
self.assertEqual(outputs[2].shape, torch.Size([]))
|
||||||
|
self.assertEqual(outputs[2].dtype, torch.float8_e4m3fn)
|
||||||
|
self.assertEqual(outputs[0].device, torch.device("cpu"))
|
||||||
|
self.assertEqual(outputs[1].device, torch.device("cpu"))
|
||||||
|
self.assertEqual(outputs[2].device, torch.device("cpu"))
|
||||||
|
|
||||||
|
def test_symbolic_multi_out_accepts_valid_inputs_empty_shape(self):
|
||||||
|
outputs = torch.onnx.ops.symbolic_multi_out(
|
||||||
|
"custom_domain::CustomOp",
|
||||||
|
(torch.tensor(1),),
|
||||||
|
dtypes=(torch.float32,),
|
||||||
|
shapes=[[]],
|
||||||
|
)
|
||||||
|
self.assertEqual(outputs[0].shape, torch.Size([]))
|
||||||
|
|
||||||
|
def test_symbolic_multi_out_accepts_valid_inputs_integer_types(self):
|
||||||
|
outputs = torch.onnx.ops.symbolic_multi_out(
|
||||||
|
"custom_domain::CustomOp",
|
||||||
|
(torch.tensor(1),),
|
||||||
|
dtypes=(1,), # 1 is float32 in ONNX
|
||||||
|
shapes=[[42]],
|
||||||
|
)
|
||||||
|
self.assertEqual(outputs[0].dtype, torch.float32)
|
||||||
|
|
||||||
|
def test_symbolic_multi_out_accepts_valid_inputs_int4_type(self):
|
||||||
|
outputs = torch.onnx.ops.symbolic_multi_out(
|
||||||
|
"custom_domain::CustomOp",
|
||||||
|
(torch.tensor(1),),
|
||||||
|
dtypes=(22,), # 22 is INT4 in ONNX
|
||||||
|
shapes=[[42]],
|
||||||
|
)
|
||||||
|
# We use torch uint8 for int4
|
||||||
|
self.assertEqual(outputs[0].dtype, torch.uint8)
|
||||||
|
|
||||||
|
def test_symbolic_multi_out_is_exportable(self):
|
||||||
|
class Model(torch.nn.Module):
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
return torch.onnx.ops.symbolic_multi_out(
|
||||||
|
"custom_domain::CustomOp",
|
||||||
|
(x,),
|
||||||
|
dict(
|
||||||
|
int_key=1,
|
||||||
|
float_key=1.0,
|
||||||
|
str_key="attr",
|
||||||
|
bool_key=True,
|
||||||
|
list_int_key=[1, 2],
|
||||||
|
list_float_key=[1.0, 2.0],
|
||||||
|
list_str_key=["attr1", "attr2"],
|
||||||
|
list_bool_key=[True, False],
|
||||||
|
),
|
||||||
|
dtypes=(torch.float32, torch.int32, torch.float8_e4m3fn),
|
||||||
|
shapes=([1, 2], [42], []),
|
||||||
|
version=1,
|
||||||
|
metadata_props={"meta_key": "meta_value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
onnx_program = torch.onnx.export(
|
||||||
|
Model(), (torch.tensor(1),), dynamo=True, verbose=False
|
||||||
|
)
|
||||||
|
assert onnx_program is not None
|
||||||
|
node = onnx_program.model.graph.node(0)
|
||||||
|
self.assertEqual(node.op_type, "CustomOp")
|
||||||
|
self.assertEqual(node.domain, "custom_domain")
|
||||||
|
attributes = node.attributes
|
||||||
|
self.assertEqual(
|
||||||
|
attributes,
|
||||||
|
dict(
|
||||||
|
int_key=ir.AttrInt64("int_key", 1),
|
||||||
|
float_key=ir.AttrFloat32("float_key", 1.0),
|
||||||
|
str_key=ir.AttrString("str_key", "attr"),
|
||||||
|
bool_key=ir.AttrInt64("bool_key", 1),
|
||||||
|
list_int_key=ir.AttrInt64s("list_int_key", [1, 2]),
|
||||||
|
list_float_key=ir.AttrFloat32s("list_float_key", [1.0, 2.0]),
|
||||||
|
list_str_key=ir.AttrStrings("list_str_key", ["attr1", "attr2"]),
|
||||||
|
list_bool_key=ir.AttrInt64s("list_bool_key", [1, 0]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertEqual(node.metadata_props["meta_key"], "meta_value")
|
||||||
|
outputs = node.outputs
|
||||||
|
self.assertEqual(list(outputs[0].shape), [1, 2])
|
||||||
|
self.assertEqual(outputs[0].dtype, ir.DataType.FLOAT)
|
||||||
|
self.assertEqual(list(outputs[1].shape), [42])
|
||||||
|
self.assertEqual(outputs[1].dtype, ir.DataType.INT32)
|
||||||
|
self.assertEqual(list(outputs[2].shape), [])
|
||||||
|
self.assertEqual(outputs[2].dtype, ir.DataType.FLOAT8E4M3FN)
|
||||||
|
|
||||||
|
def test_symbolic_multi_out_preserves_dynamic_shapes(self):
|
||||||
|
class Model(torch.nn.Module):
|
||||||
|
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
||||||
|
return torch.onnx.ops.symbolic_multi_out(
|
||||||
|
"custom_domain::CustomOp",
|
||||||
|
(x, y),
|
||||||
|
dtypes=(x.dtype, 22), # 22 is INT4
|
||||||
|
shapes=[[*x.shape, *y.shape], [42]],
|
||||||
|
version=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
onnx_program = torch.onnx.export(
|
||||||
|
Model(),
|
||||||
|
(torch.zeros(2, 3), torch.zeros(1, 2)),
|
||||||
|
dynamic_shapes=({0: "batch"}, {1: "something_else"}),
|
||||||
|
dynamo=True,
|
||||||
|
verbose=False,
|
||||||
|
)
|
||||||
|
assert onnx_program is not None
|
||||||
|
node = onnx_program.model.graph.node(0)
|
||||||
|
self.assertEqual(node.op_type, "CustomOp")
|
||||||
|
self.assertEqual(node.domain, "custom_domain")
|
||||||
|
inputs = onnx_program.model.graph.inputs
|
||||||
|
self.assertEqual(str(inputs[0].shape[0]), "batch")
|
||||||
|
self.assertEqual(inputs[0].shape[1], 3)
|
||||||
|
self.assertEqual(inputs[1].shape[0], 1)
|
||||||
|
self.assertEqual(str(inputs[1].shape[1]), "something_else")
|
||||||
|
outputs = node.outputs
|
||||||
|
self.assertEqual(str(outputs[0].shape[0]), "batch")
|
||||||
|
self.assertEqual(outputs[0].shape[1], 3)
|
||||||
|
self.assertEqual(outputs[0].shape[2], 1)
|
||||||
|
self.assertEqual(str(outputs[0].shape[3]), "something_else")
|
||||||
|
self.assertEqual(outputs[0].dtype, ir.DataType.FLOAT)
|
||||||
|
self.assertEqual(list(outputs[1].shape), [42])
|
||||||
|
self.assertEqual(outputs[1].dtype, ir.DataType.INT4)
|
||||||
|
|
||||||
|
def test_symbolic_multi_out_raises_when_dtypes_and_shapes_differ(self):
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
torch.onnx.ops.symbolic_multi_out(
|
||||||
|
"custom_domain::CustomMultiOutOp",
|
||||||
|
(torch.tensor(1),),
|
||||||
|
dict(
|
||||||
|
int_key=1,
|
||||||
|
float_key=1.0,
|
||||||
|
str_key="attr",
|
||||||
|
bool_key=True,
|
||||||
|
list_int_key=[1, 2],
|
||||||
|
list_float_key=[1.0, 2.0],
|
||||||
|
list_str_key=["attr1", "attr2"],
|
||||||
|
list_bool_key=[True, False],
|
||||||
|
),
|
||||||
|
dtypes=(torch.float32, torch.int32),
|
||||||
|
shapes=([1, 2], [42], []),
|
||||||
|
version=1,
|
||||||
|
metadata_props={"meta_key": "meta_value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
torch.onnx.ops.symbolic_multi_out(
|
||||||
|
"custom_domain::CustomMultiOutOp",
|
||||||
|
(torch.tensor(1),),
|
||||||
|
dict(
|
||||||
|
int_key=1,
|
||||||
|
float_key=1.0,
|
||||||
|
str_key="attr",
|
||||||
|
bool_key=True,
|
||||||
|
list_int_key=[1, 2],
|
||||||
|
list_float_key=[1.0, 2.0],
|
||||||
|
list_str_key=["attr1", "attr2"],
|
||||||
|
list_bool_key=[True, False],
|
||||||
|
),
|
||||||
|
dtypes=(torch.float32,),
|
||||||
|
shapes=([1, 2], [42]),
|
||||||
|
version=1,
|
||||||
|
metadata_props={"meta_key": "meta_value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
common_utils.run_tests()
|
||||||
|
|
@ -4,9 +4,10 @@ from __future__ import annotations
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Modules
|
# Modules
|
||||||
|
"errors",
|
||||||
|
"ops",
|
||||||
"symbolic_helper",
|
"symbolic_helper",
|
||||||
"utils",
|
"utils",
|
||||||
"errors",
|
|
||||||
# All opsets
|
# All opsets
|
||||||
"symbolic_caffe2",
|
"symbolic_caffe2",
|
||||||
"symbolic_opset7",
|
"symbolic_opset7",
|
||||||
|
|
@ -52,7 +53,6 @@ from typing import Any, Callable, TYPE_CHECKING
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import _C
|
|
||||||
from torch._C import _onnx as _C_onnx
|
from torch._C import _onnx as _C_onnx
|
||||||
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
|
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
|
||||||
|
|
||||||
|
|
@ -77,6 +77,7 @@ from .utils import (
|
||||||
|
|
||||||
from . import ( # usort: skip. Keep the order instead of sorting lexicographically
|
from . import ( # usort: skip. Keep the order instead of sorting lexicographically
|
||||||
errors,
|
errors,
|
||||||
|
ops,
|
||||||
symbolic_caffe2,
|
symbolic_caffe2,
|
||||||
symbolic_helper,
|
symbolic_helper,
|
||||||
symbolic_opset7,
|
symbolic_opset7,
|
||||||
|
|
|
||||||
|
|
@ -200,27 +200,33 @@ def _set_shape_type(
|
||||||
| tuple[torch.Tensor],
|
| tuple[torch.Tensor],
|
||||||
complex_to_float: bool,
|
complex_to_float: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
# TODO: Consider using meta["tensor_meta"] for this? Would it be faster?
|
|
||||||
if isinstance(meta_val, tuple):
|
if isinstance(meta_val, tuple):
|
||||||
logger.warning("Setting shape and type of tensors is not supported yet")
|
logger.warning("Setting shape and type of tensors is not supported yet")
|
||||||
if isinstance(meta_val, torch.Tensor):
|
if isinstance(meta_val, torch.Tensor):
|
||||||
# FIXME: Consider shape for complex values
|
|
||||||
dims = []
|
dims = []
|
||||||
for dim in meta_val.shape:
|
for dim in meta_val.shape:
|
||||||
if isinstance(dim, int):
|
if isinstance(dim, int):
|
||||||
dims.append(dim)
|
dims.append(dim)
|
||||||
else:
|
else:
|
||||||
dims.append(str(dim.node))
|
dims.append(str(dim.node))
|
||||||
value.dtype = _torch_dtype_to_onnx_dtype(meta_val.dtype)
|
|
||||||
if complex_to_float:
|
# If the dtype is set already (e.g. by the onnx_symbolic ops),
|
||||||
if meta_val.dtype == torch.complex64:
|
# we don't need to set it again.
|
||||||
value.dtype = ir.DataType.FLOAT
|
#
|
||||||
# Add 2 as the last dimension if the tensor is complex to hold the real/imag parts
|
# When a user specifies complex in onnx_symbolic, we consider that to
|
||||||
dims.append(2)
|
# be the intention even though non of the ONNX ops deals with complex values.
|
||||||
elif meta_val.dtype == torch.complex128:
|
# In this case, we don't change the dtype or the shape of the tensor.
|
||||||
value.dtype = ir.DataType.DOUBLE
|
if value.dtype is None:
|
||||||
# Add 2 as the last dimension if the tensor is complex to hold the real/imag parts
|
value.dtype = _torch_dtype_to_onnx_dtype(meta_val.dtype)
|
||||||
dims.append(2)
|
if complex_to_float:
|
||||||
|
if meta_val.dtype == torch.complex64:
|
||||||
|
value.dtype = ir.DataType.FLOAT
|
||||||
|
# Add 2 as the last dimension if the tensor is complex to hold the real/imag parts
|
||||||
|
dims.append(2)
|
||||||
|
elif meta_val.dtype == torch.complex128:
|
||||||
|
value.dtype = ir.DataType.DOUBLE
|
||||||
|
# Add 2 as the last dimension if the tensor is complex to hold the real/imag parts
|
||||||
|
dims.append(2)
|
||||||
|
|
||||||
value.shape = ir.Shape(dims)
|
value.shape = ir.Shape(dims)
|
||||||
elif isinstance(meta_val, (int, torch.SymInt)):
|
elif isinstance(meta_val, (int, torch.SymInt)):
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["core", "hop"]
|
__all__ = ["core", "hop", "symbolic"]
|
||||||
|
|
||||||
from torch.onnx._internal.exporter._torchlib.ops import core, hop
|
from torch.onnx._internal.exporter._torchlib.ops import core, hop, symbolic
|
||||||
|
|
|
||||||
149
torch/onnx/_internal/exporter/_torchlib/ops/symbolic.py
Normal file
149
torch/onnx/_internal/exporter/_torchlib/ops/symbolic.py
Normal file
|
|
@ -0,0 +1,149 @@
|
||||||
|
"""Implementation for higher-order operators."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from onnxscript.ir import convenience as ir_convenience
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.onnx._internal._lazy_import import onnxscript_ir as ir
|
||||||
|
from torch.onnx._internal.exporter import _core
|
||||||
|
from torch.onnx._internal.exporter._torchlib._torchlib_registry import onnx_impl
|
||||||
|
from torch.onnx.ops import _symbolic_impl
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
|
||||||
|
def _call_symbolic_op(
|
||||||
|
op_type: str,
|
||||||
|
domain: str,
|
||||||
|
args: Sequence[ir.Value | None],
|
||||||
|
kwargs: dict[str, int | float | str | bool | list[int] | list[float] | list[str]],
|
||||||
|
dtypes: Sequence[int],
|
||||||
|
version: int | None,
|
||||||
|
metadata_props: dict[str, str] | None,
|
||||||
|
) -> Sequence[ir.Value]:
|
||||||
|
"""Call an operator with the given arguments and keyword arguments.
|
||||||
|
|
||||||
|
Arguments are always inputs, while keyword arguments are attributes.
|
||||||
|
"""
|
||||||
|
# This is a wrapper around the IR node creation that hooks into the _builder.OpRecorder
|
||||||
|
# tracer so that all nodes created are recorded the same way as if we were to use
|
||||||
|
# onnxscript ops directly.
|
||||||
|
|
||||||
|
assert _core.current_tracer is not None
|
||||||
|
tracer = _core.current_tracer
|
||||||
|
|
||||||
|
inputs = list(args)
|
||||||
|
|
||||||
|
# If final inputs are None, strip them from the node inputs
|
||||||
|
for input in reversed(inputs):
|
||||||
|
if input is not None:
|
||||||
|
break
|
||||||
|
inputs.pop()
|
||||||
|
|
||||||
|
# Construct and filter out None attributes
|
||||||
|
attributes = [
|
||||||
|
attr
|
||||||
|
for attr in ir_convenience.convert_attributes(kwargs) # type: ignore[arg-type]
|
||||||
|
if attr.value is not None # type: ignore[union-attr]
|
||||||
|
]
|
||||||
|
tracer.nodes.append(
|
||||||
|
node := ir.Node(
|
||||||
|
domain,
|
||||||
|
op_type,
|
||||||
|
inputs=inputs,
|
||||||
|
attributes=attributes,
|
||||||
|
num_outputs=len(dtypes),
|
||||||
|
version=version,
|
||||||
|
metadata_props=metadata_props,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Set the dtypes for the outputs. We set them here because the graph builder
|
||||||
|
# Uses PyTorch types which are sometimes inaccurate when they are ONNX only
|
||||||
|
# types like float4e2m1.
|
||||||
|
for value, dtype in zip(node.outputs, dtypes):
|
||||||
|
value.dtype = ir.DataType(dtype)
|
||||||
|
# The shape is set by the graph builder. We don't need to set it here.
|
||||||
|
return node.outputs
|
||||||
|
|
||||||
|
|
||||||
|
@onnx_impl(torch.ops.onnx_symbolic._symbolic.default, no_compile=True)
|
||||||
|
def onnx_symbolic_symbolic(
|
||||||
|
inputs: Sequence[ir.Value | None],
|
||||||
|
op_type: str,
|
||||||
|
onnx_dtype: int,
|
||||||
|
*,
|
||||||
|
shape: Sequence[int | ir.Value],
|
||||||
|
attr_keys: Sequence[str],
|
||||||
|
attr_types: Sequence[str],
|
||||||
|
attr_pos: Sequence[tuple[int, int]],
|
||||||
|
attr_ints: Sequence[int],
|
||||||
|
attr_floats: Sequence[float],
|
||||||
|
attr_strs: Sequence[str],
|
||||||
|
metadata_props_keys: Sequence[str] = (),
|
||||||
|
metadata_props_values: Sequence[str] = (),
|
||||||
|
domain: str = "",
|
||||||
|
version: int | None = None,
|
||||||
|
) -> ir.Value:
|
||||||
|
del shape # Unused. The shapes are set by the graph builder
|
||||||
|
encoded = _symbolic_impl.EncodedAttrs(
|
||||||
|
attr_keys=list(attr_keys),
|
||||||
|
attr_types=list(attr_types),
|
||||||
|
attr_pos=list(attr_pos),
|
||||||
|
attr_ints=list(attr_ints),
|
||||||
|
attr_floats=list(attr_floats),
|
||||||
|
attr_strs=list(attr_strs),
|
||||||
|
)
|
||||||
|
attrs = encoded.to_dict()
|
||||||
|
return _call_symbolic_op(
|
||||||
|
op_type,
|
||||||
|
domain,
|
||||||
|
inputs,
|
||||||
|
attrs,
|
||||||
|
dtypes=[onnx_dtype],
|
||||||
|
version=version,
|
||||||
|
metadata_props=dict(zip(metadata_props_keys, metadata_props_values)),
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
|
||||||
|
@onnx_impl(torch.ops.onnx_symbolic._symbolic_multi_out.default, no_compile=True)
|
||||||
|
def onnx_symbolic_symbolic_multi_out(
|
||||||
|
inputs: Sequence[ir.Value | None],
|
||||||
|
op_type: str,
|
||||||
|
onnx_dtypes: Sequence[int],
|
||||||
|
*,
|
||||||
|
shapes: Sequence[Sequence[int | ir.Value]],
|
||||||
|
attr_keys: Sequence[str],
|
||||||
|
attr_types: Sequence[str],
|
||||||
|
attr_pos: Sequence[tuple[int, int]],
|
||||||
|
attr_ints: Sequence[int],
|
||||||
|
attr_floats: Sequence[float],
|
||||||
|
attr_strs: Sequence[str],
|
||||||
|
metadata_props_keys: Sequence[str] = (),
|
||||||
|
metadata_props_values: Sequence[str] = (),
|
||||||
|
domain: str = "",
|
||||||
|
version: int | None = None,
|
||||||
|
) -> Sequence[ir.Value]:
|
||||||
|
del shapes # Unused. The shapes are set by the graph builder
|
||||||
|
encoded = _symbolic_impl.EncodedAttrs(
|
||||||
|
attr_keys=list(attr_keys),
|
||||||
|
attr_types=list(attr_types),
|
||||||
|
attr_pos=list(attr_pos),
|
||||||
|
attr_ints=list(attr_ints),
|
||||||
|
attr_floats=list(attr_floats),
|
||||||
|
attr_strs=list(attr_strs),
|
||||||
|
)
|
||||||
|
attrs = encoded.to_dict()
|
||||||
|
return _call_symbolic_op(
|
||||||
|
op_type,
|
||||||
|
domain,
|
||||||
|
inputs,
|
||||||
|
attrs,
|
||||||
|
dtypes=onnx_dtypes,
|
||||||
|
version=version,
|
||||||
|
metadata_props=dict(zip(metadata_props_keys, metadata_props_values)),
|
||||||
|
)
|
||||||
243
torch/onnx/ops/__init__.py
Normal file
243
torch/onnx/ops/__init__.py
Normal file
|
|
@ -0,0 +1,243 @@
|
||||||
|
"""ONNX operators as native torch.fx operators.
|
||||||
|
|
||||||
|
This module provides a set of functions to create ONNX operators in the FX graph
|
||||||
|
which are exportable to ONNX.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.onnx.ops import _symbolic_impl
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
|
||||||
|
# https://github.com/onnx/onnx/blob/f542e1f06699ea7e1db5f62af53355b64338c723/onnx/onnx.proto#L597
|
||||||
|
_TORCH_DTYPE_TO_ONNX_DTYPE = {
|
||||||
|
torch.float32: 1, # FLOAT
|
||||||
|
torch.uint8: 2, # UINT8
|
||||||
|
torch.int8: 3, # INT8
|
||||||
|
torch.uint16: 4, # UINT16
|
||||||
|
torch.int16: 5, # INT16
|
||||||
|
torch.int32: 6, # INT32
|
||||||
|
torch.int64: 7, # INT64
|
||||||
|
str: 8, # STRING
|
||||||
|
torch.bool: 9, # BOOL
|
||||||
|
torch.float16: 10, # FLOAT16
|
||||||
|
torch.double: 11, # DOUBLE
|
||||||
|
torch.uint32: 12, # UINT32
|
||||||
|
torch.uint64: 13, # UINT64
|
||||||
|
torch.complex64: 14, # COMPLEX64
|
||||||
|
torch.complex128: 15, # COMPLEX128
|
||||||
|
torch.bfloat16: 16, # BFLOAT16
|
||||||
|
torch.float8_e4m3fn: 17, # FLOAT8E4M3FN
|
||||||
|
torch.float8_e4m3fnuz: 18, # FLOAT8E4M3FNUZ
|
||||||
|
torch.float8_e5m2: 19, # FLOAT8E5M2
|
||||||
|
torch.float8_e5m2fnuz: 20, # FLOAT8E5M2FNUZ
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_domain_op_type(domain_op: str) -> tuple[str, str]:
|
||||||
|
splitted = domain_op.split("::", 1)
|
||||||
|
if len(splitted) == 1:
|
||||||
|
domain = ""
|
||||||
|
op_type = splitted[0]
|
||||||
|
else:
|
||||||
|
domain = splitted[0]
|
||||||
|
op_type = splitted[1]
|
||||||
|
return domain, op_type
|
||||||
|
|
||||||
|
|
||||||
|
def symbolic(
|
||||||
|
domain_op: str,
|
||||||
|
/,
|
||||||
|
inputs: Sequence[torch.Tensor],
|
||||||
|
attrs: dict[
|
||||||
|
str,
|
||||||
|
int
|
||||||
|
| float
|
||||||
|
| str
|
||||||
|
| bool
|
||||||
|
| Sequence[int]
|
||||||
|
| Sequence[float]
|
||||||
|
| Sequence[str]
|
||||||
|
| Sequence[bool],
|
||||||
|
]
|
||||||
|
| None = None,
|
||||||
|
*,
|
||||||
|
dtype: torch.dtype | int,
|
||||||
|
shape: Sequence[int | torch.SymInt],
|
||||||
|
version: int | None = None,
|
||||||
|
metadata_props: dict[str, str] | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Create a symbolic FX operator to represent an arbitrary ONNX operator.
|
||||||
|
|
||||||
|
This function is used to create a symbolic operator with a single output.
|
||||||
|
To create an operator with multiple outputs, use :func:`symbolic_multi_out`.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
class CustomOp(torch.nn.Module):
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
return torch.onnx.ops.symbolic(
|
||||||
|
"custom_domain::CustomOp",
|
||||||
|
(x,),
|
||||||
|
dict(attr_key="attr_value"),
|
||||||
|
dtype=x.dtype,
|
||||||
|
shape=x.shape,
|
||||||
|
version=1,
|
||||||
|
)
|
||||||
|
# This will create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain.
|
||||||
|
# The output tensor will have the specified dtype and shape.
|
||||||
|
|
||||||
|
|
||||||
|
# You may then export this model to ONNX using torch.onnx.export.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
domain_op: The domain and operator name, separated by "::". For example,
|
||||||
|
"custom_domain::CustomOp".
|
||||||
|
inputs: The input tensors to the operator.
|
||||||
|
attrs: The attributes of the operator. The keys are attribute names and
|
||||||
|
the values are attribute values. Valid attribute types are int, float,
|
||||||
|
str, bool, and lists of int, float, str, and bool. Tensor attributes
|
||||||
|
are unsupported.
|
||||||
|
dtype: The data type of the output tensor.This can be either a torch.dtype
|
||||||
|
or an integer representing the ONNX data type.
|
||||||
|
shape: The shape of the output tensor. This can be a list of integers or
|
||||||
|
SymInt values.
|
||||||
|
version: The version of the opset used for the operator.
|
||||||
|
metadata_props: Metadata properties for the ONNX node.
|
||||||
|
This is a dictionary of str-str pairs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The output tensor of the operator.
|
||||||
|
"""
|
||||||
|
if not isinstance(dtype, int):
|
||||||
|
torch._check(
|
||||||
|
dtype in _TORCH_DTYPE_TO_ONNX_DTYPE, lambda: f"Unsupported dtype: {dtype}"
|
||||||
|
)
|
||||||
|
dtype = _TORCH_DTYPE_TO_ONNX_DTYPE[dtype]
|
||||||
|
domain, op_type = _parse_domain_op_type(domain_op)
|
||||||
|
if attrs is None:
|
||||||
|
attrs = {}
|
||||||
|
encoded_attrs = _symbolic_impl.EncodedAttrs.from_dict(attrs)
|
||||||
|
# TODO: Parse domain
|
||||||
|
return _symbolic_impl._symbolic(
|
||||||
|
inputs,
|
||||||
|
op_type,
|
||||||
|
dtype,
|
||||||
|
shape=shape,
|
||||||
|
attr_keys=encoded_attrs.attr_keys,
|
||||||
|
attr_types=encoded_attrs.attr_types,
|
||||||
|
attr_pos=encoded_attrs.attr_pos,
|
||||||
|
attr_ints=encoded_attrs.attr_ints,
|
||||||
|
attr_floats=encoded_attrs.attr_floats,
|
||||||
|
attr_strs=encoded_attrs.attr_strs,
|
||||||
|
metadata_props_keys=metadata_props.keys() if metadata_props else [],
|
||||||
|
metadata_props_values=metadata_props.values() if metadata_props else [],
|
||||||
|
domain=domain,
|
||||||
|
version=version,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def symbolic_multi_out(
|
||||||
|
domain_op: str,
|
||||||
|
/,
|
||||||
|
inputs: Sequence[torch.Tensor],
|
||||||
|
attrs: dict[
|
||||||
|
str,
|
||||||
|
int
|
||||||
|
| float
|
||||||
|
| str
|
||||||
|
| bool
|
||||||
|
| Sequence[int]
|
||||||
|
| Sequence[float]
|
||||||
|
| Sequence[str]
|
||||||
|
| Sequence[bool],
|
||||||
|
]
|
||||||
|
| None = None,
|
||||||
|
*,
|
||||||
|
dtypes: Sequence[torch.dtype | int],
|
||||||
|
shapes: Sequence[Sequence[int | torch.SymInt]],
|
||||||
|
version: int | None = None,
|
||||||
|
metadata_props: dict[str, str] | None = None,
|
||||||
|
) -> Sequence[torch.Tensor]:
|
||||||
|
"""Create a symbolic FX operator to represent an arbitrary ONNX operator with multiple outputs.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
class CustomOp(torch.nn.Module):
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
return torch.onnx.ops.symbolic(
|
||||||
|
"custom_domain::CustomOp",
|
||||||
|
(x,),
|
||||||
|
dict(attr_key="attr_value"),
|
||||||
|
dtypes=(x.dtype, torch.float32),
|
||||||
|
shapes=(x.shape, [1, 2, 3]),
|
||||||
|
version=1,
|
||||||
|
)
|
||||||
|
# This will create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain.
|
||||||
|
# The output tensor will have the specified dtype and shape.
|
||||||
|
|
||||||
|
|
||||||
|
# You may then export this model to ONNX using torch.onnx.export.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
domain_op: The domain and operator name, separated by "::". For example,
|
||||||
|
"custom_domain::CustomOp".
|
||||||
|
inputs: The input tensors to the operator.
|
||||||
|
attrs: The attributes of the operator. The keys are attribute names and
|
||||||
|
the values are attribute values. Valid attribute types are int, float,
|
||||||
|
str, bool, and lists of int, float, str, and bool. Tensor attributes
|
||||||
|
are unsupported.
|
||||||
|
dtypes: The data types of the output tensors. This can be a list of
|
||||||
|
torch.dtype or integers representing the ONNX data types. The length
|
||||||
|
of this list must be the number of outputs.
|
||||||
|
shapes: The shapes of the output tensors. This can be a list of lists of
|
||||||
|
integers or SymInt values. The length of this list must be the number of outputs.
|
||||||
|
version: The version of the opset used for the operator.
|
||||||
|
metadata_props: Metadata properties for the ONNX node.
|
||||||
|
This is a dictionary of str-str pairs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of output tensors of the operator.
|
||||||
|
"""
|
||||||
|
torch._check(
|
||||||
|
len(shapes) == len(dtypes),
|
||||||
|
lambda: f"Number of shapes ({len(shapes)}) must match number of dtypes ({len(dtypes)})",
|
||||||
|
)
|
||||||
|
onnx_dtypes = []
|
||||||
|
for dtype in dtypes:
|
||||||
|
if not isinstance(dtype, int):
|
||||||
|
torch._check(
|
||||||
|
dtype in _TORCH_DTYPE_TO_ONNX_DTYPE,
|
||||||
|
lambda: f"Unsupported dtype: {dtype}",
|
||||||
|
)
|
||||||
|
onnx_dtypes.append(_TORCH_DTYPE_TO_ONNX_DTYPE[dtype])
|
||||||
|
else:
|
||||||
|
onnx_dtypes.append(dtype)
|
||||||
|
domain, op_type = _parse_domain_op_type(domain_op)
|
||||||
|
if attrs is None:
|
||||||
|
attrs = {}
|
||||||
|
encoded_attrs = _symbolic_impl.EncodedAttrs.from_dict(attrs)
|
||||||
|
# Use the size of dtypes to determine the number of outputs
|
||||||
|
return _symbolic_impl._symbolic_multi_out(
|
||||||
|
inputs,
|
||||||
|
op_type,
|
||||||
|
onnx_dtypes,
|
||||||
|
shapes=shapes,
|
||||||
|
attr_keys=encoded_attrs.attr_keys,
|
||||||
|
attr_types=encoded_attrs.attr_types,
|
||||||
|
attr_pos=encoded_attrs.attr_pos,
|
||||||
|
attr_ints=encoded_attrs.attr_ints,
|
||||||
|
attr_floats=encoded_attrs.attr_floats,
|
||||||
|
attr_strs=encoded_attrs.attr_strs,
|
||||||
|
metadata_props_keys=metadata_props.keys() if metadata_props else [],
|
||||||
|
metadata_props_values=metadata_props.values() if metadata_props else [],
|
||||||
|
domain=domain,
|
||||||
|
version=version,
|
||||||
|
)
|
||||||
330
torch/onnx/ops/_symbolic_impl.py
Normal file
330
torch/onnx/ops/_symbolic_impl.py
Normal file
|
|
@ -0,0 +1,330 @@
|
||||||
|
"""Implementation of symbolic FX ops to represent arbitrary ONNX ops.
|
||||||
|
|
||||||
|
This module provides a way to create symbolic FX operators that can represent
|
||||||
|
arbitrary ONNX operators.
|
||||||
|
|
||||||
|
The operators are called "symbolic" because they don't do any actual computation
|
||||||
|
but instead serve as placeholders in the computation graph.
|
||||||
|
|
||||||
|
Each implementation contains two parts: A "real" implementation that produce all
|
||||||
|
zeros based on the input shape and dtype, and a "fake" implementation that does more
|
||||||
|
or less the same thing but is required by the `torch.library.custom_op` interface.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
_ONNX_DTYPE_TO_TORCH_DTYPE: dict[int, torch.dtype] = {
|
||||||
|
1: torch.float32, # FLOAT
|
||||||
|
2: torch.uint8, # UINT8
|
||||||
|
3: torch.int8, # INT8
|
||||||
|
4: torch.uint16, # UINT16
|
||||||
|
5: torch.int16, # INT16
|
||||||
|
6: torch.int32, # INT32
|
||||||
|
7: torch.int64, # INT64
|
||||||
|
9: torch.bool, # BOOL
|
||||||
|
10: torch.float16, # FLOAT16
|
||||||
|
11: torch.double, # DOUBLE
|
||||||
|
12: torch.uint32, # UINT32
|
||||||
|
13: torch.uint64, # UINT64
|
||||||
|
14: torch.complex64, # COMPLEX64
|
||||||
|
15: torch.complex128, # COMPLEX128
|
||||||
|
16: torch.bfloat16, # BFLOAT16
|
||||||
|
17: torch.float8_e4m3fn, # FLOAT8E4M3FN
|
||||||
|
18: torch.float8_e4m3fnuz, # FLOAT8E4M3FNUZ
|
||||||
|
19: torch.float8_e5m2, # FLOAT8E5M2
|
||||||
|
20: torch.float8_e5m2fnuz, # FLOAT8E5M2FNUZ
|
||||||
|
21: torch.uint8, # UINT4
|
||||||
|
22: torch.uint8, # INT4
|
||||||
|
23: torch.uint8, # FLOAT4E2M1
|
||||||
|
}
|
||||||
|
|
||||||
|
_INT_TYPE = "i"
|
||||||
|
_FLOAT_TYPE = "f"
|
||||||
|
_STRING_TYPE = "s"
|
||||||
|
_INT_SEQ_TYPE = "is"
|
||||||
|
_FLOAT_SEQ_TYPE = "fs"
|
||||||
|
_STRING_SEQ_TYPE = "ss"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class EncodedAttrs:
|
||||||
|
"""Class to encode attributes from dictionary into lists of FX compatible attributes.
|
||||||
|
|
||||||
|
Since FX does not support dictionaries, we need to encode the attributes into
|
||||||
|
lists. This class provides a way to encode and decode the attributes.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
attr_keys: List of attribute keys.
|
||||||
|
attr_types: List of attribute types. Values can be "i" (int), "f" (float),
|
||||||
|
"s" (string), "is" (int sequence), "fs" (float sequence), or "ss" (string sequence).
|
||||||
|
attr_pos: List of tuples representing the start and end positions of each
|
||||||
|
attribute in the corresponding list.
|
||||||
|
attr_ints: List of integer attributes.
|
||||||
|
attr_floats: List of float attributes.
|
||||||
|
attr_strs: List of string attributes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
attr_keys: list[str]
|
||||||
|
attr_types: list[str]
|
||||||
|
attr_pos: list[tuple[int, int]]
|
||||||
|
attr_ints: list[int]
|
||||||
|
attr_floats: list[float]
|
||||||
|
attr_strs: list[str]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(
|
||||||
|
cls,
|
||||||
|
attrs: dict[
|
||||||
|
str,
|
||||||
|
Union[
|
||||||
|
int,
|
||||||
|
float,
|
||||||
|
str,
|
||||||
|
bool,
|
||||||
|
Sequence[int],
|
||||||
|
Sequence[float],
|
||||||
|
Sequence[str],
|
||||||
|
Sequence[bool],
|
||||||
|
],
|
||||||
|
],
|
||||||
|
) -> "EncodedAttrs":
|
||||||
|
encoded = cls(
|
||||||
|
attr_keys=[],
|
||||||
|
attr_types=[],
|
||||||
|
attr_pos=[],
|
||||||
|
attr_ints=[],
|
||||||
|
attr_floats=[],
|
||||||
|
attr_strs=[],
|
||||||
|
)
|
||||||
|
for i, (k, v) in enumerate(attrs.items()):
|
||||||
|
encoded.attr_keys.append(k)
|
||||||
|
if isinstance(v, int):
|
||||||
|
start_pos = len(encoded.attr_ints)
|
||||||
|
encoded.attr_ints.append(v)
|
||||||
|
encoded.attr_pos.append((start_pos, start_pos + 1))
|
||||||
|
encoded.attr_types.append(_INT_TYPE)
|
||||||
|
elif isinstance(v, float):
|
||||||
|
start_pos = len(encoded.attr_floats)
|
||||||
|
encoded.attr_floats.append(v)
|
||||||
|
encoded.attr_pos.append((start_pos, start_pos + 1))
|
||||||
|
encoded.attr_types.append(_FLOAT_TYPE)
|
||||||
|
elif isinstance(v, str):
|
||||||
|
start_pos = len(encoded.attr_strs)
|
||||||
|
encoded.attr_strs.append(v)
|
||||||
|
encoded.attr_pos.append((start_pos, start_pos + 1))
|
||||||
|
encoded.attr_types.append(_STRING_TYPE)
|
||||||
|
elif isinstance(v, Sequence):
|
||||||
|
if len(v) == 0:
|
||||||
|
raise ValueError(f"Empty sequence for attribute {k}")
|
||||||
|
if any(isinstance(elem, float) for elem in v):
|
||||||
|
start_pos = len(encoded.attr_floats)
|
||||||
|
encoded.attr_floats.extend([float(elem) for elem in v])
|
||||||
|
encoded.attr_pos.append((start_pos, start_pos + len(v)))
|
||||||
|
encoded.attr_types.append(_FLOAT_SEQ_TYPE)
|
||||||
|
elif isinstance(v[0], int):
|
||||||
|
start_pos = len(encoded.attr_ints)
|
||||||
|
encoded.attr_ints.extend([int(elem) for elem in v])
|
||||||
|
encoded.attr_pos.append((start_pos, start_pos + len(v)))
|
||||||
|
encoded.attr_types.append(_INT_SEQ_TYPE)
|
||||||
|
elif isinstance(v[0], str):
|
||||||
|
start_pos = len(encoded.attr_strs)
|
||||||
|
encoded.attr_strs.extend([str(elem) for elem in v])
|
||||||
|
encoded.attr_pos.append((start_pos, start_pos + len(v)))
|
||||||
|
encoded.attr_types.append(_STRING_SEQ_TYPE)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported sequence type for attribute {k}")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported attribute type for {k}: {type(v)}")
|
||||||
|
assert len(encoded.attr_keys) == len(encoded.attr_types), (
|
||||||
|
f"Mismatch between number of attribute keys and types: {len(encoded.attr_keys)} != {len(encoded.attr_types)}"
|
||||||
|
)
|
||||||
|
assert len(encoded.attr_keys) == len(encoded.attr_pos), (
|
||||||
|
f"Mismatch between number of attribute keys and positions: {len(encoded.attr_keys)} != {len(encoded.attr_pos)}"
|
||||||
|
)
|
||||||
|
return encoded
|
||||||
|
|
||||||
|
def to_dict(
|
||||||
|
self,
|
||||||
|
) -> dict[
|
||||||
|
str,
|
||||||
|
Union[
|
||||||
|
int,
|
||||||
|
float,
|
||||||
|
str,
|
||||||
|
list[int],
|
||||||
|
list[float],
|
||||||
|
list[str],
|
||||||
|
],
|
||||||
|
]:
|
||||||
|
"""Convert the encoded attributes back to a dictionary for creating an ONNX node."""
|
||||||
|
attrs: dict[
|
||||||
|
str,
|
||||||
|
Union[
|
||||||
|
int,
|
||||||
|
float,
|
||||||
|
str,
|
||||||
|
list[int],
|
||||||
|
list[float],
|
||||||
|
list[str],
|
||||||
|
],
|
||||||
|
] = {}
|
||||||
|
for i, key in enumerate(self.attr_keys):
|
||||||
|
attr_type = self.attr_types[i]
|
||||||
|
if attr_type == _INT_TYPE:
|
||||||
|
attrs[key] = self.attr_ints[self.attr_pos[i][0]]
|
||||||
|
elif attr_type == _FLOAT_TYPE:
|
||||||
|
attrs[key] = self.attr_floats[self.attr_pos[i][0]]
|
||||||
|
elif attr_type == _STRING_TYPE:
|
||||||
|
attrs[key] = self.attr_strs[self.attr_pos[i][0]]
|
||||||
|
elif attr_type == _FLOAT_SEQ_TYPE:
|
||||||
|
attrs[key] = self.attr_floats[self.attr_pos[i][0] : self.attr_pos[i][1]]
|
||||||
|
elif attr_type == _INT_SEQ_TYPE:
|
||||||
|
attrs[key] = self.attr_ints[self.attr_pos[i][0] : self.attr_pos[i][1]]
|
||||||
|
elif attr_type == _STRING_SEQ_TYPE:
|
||||||
|
attrs[key] = self.attr_strs[self.attr_pos[i][0] : self.attr_pos[i][1]]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported attribute type: {attr_type}")
|
||||||
|
return attrs
|
||||||
|
|
||||||
|
|
||||||
|
@torch.library.custom_op(
|
||||||
|
"onnx_symbolic::_symbolic",
|
||||||
|
mutates_args=(),
|
||||||
|
schema=(
|
||||||
|
"(Tensor?[] inputs, str op_type, int onnx_dtype, *,"
|
||||||
|
" SymInt[] shape, str[] attr_keys, str[] attr_types, int[][] attr_pos,"
|
||||||
|
" int[] attr_ints, float[] attr_floats, str[] attr_strs, str[] metadata_props_keys,"
|
||||||
|
" str[] metadata_props_values, str domain='', int? version=None"
|
||||||
|
") -> Tensor"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def _symbolic(
|
||||||
|
inputs: Sequence[Optional[torch.Tensor]],
|
||||||
|
op_type: str,
|
||||||
|
onnx_dtype: int,
|
||||||
|
*,
|
||||||
|
shape: Sequence[Union[int, torch.SymInt]],
|
||||||
|
attr_keys: Sequence[str],
|
||||||
|
attr_types: Sequence[str],
|
||||||
|
attr_pos: Sequence[tuple[int, int]],
|
||||||
|
attr_ints: Sequence[int],
|
||||||
|
attr_floats: Sequence[float],
|
||||||
|
attr_strs: Sequence[str],
|
||||||
|
metadata_props_keys: Sequence[str] = (),
|
||||||
|
metadata_props_values: Sequence[str] = (),
|
||||||
|
domain: str = "",
|
||||||
|
version: Optional[int] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
torch._check(
|
||||||
|
onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE,
|
||||||
|
lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
|
||||||
|
)
|
||||||
|
return torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype])
|
||||||
|
|
||||||
|
|
||||||
|
@_symbolic.register_fake
|
||||||
|
def _(
|
||||||
|
inputs: Sequence[torch.Tensor],
|
||||||
|
op_type: str,
|
||||||
|
onnx_dtype: int,
|
||||||
|
*,
|
||||||
|
shape: Sequence[Union[int, torch.SymInt]],
|
||||||
|
attr_keys: Sequence[str],
|
||||||
|
attr_types: Sequence[str],
|
||||||
|
attr_pos: Sequence[tuple[int, int]],
|
||||||
|
attr_ints: Sequence[int],
|
||||||
|
attr_floats: Sequence[float],
|
||||||
|
attr_strs: Sequence[str],
|
||||||
|
metadata_props_keys: Sequence[str] = (),
|
||||||
|
metadata_props_values: Sequence[str] = (),
|
||||||
|
domain: str = "",
|
||||||
|
version: Optional[int] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
torch._check(
|
||||||
|
onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE,
|
||||||
|
lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
|
||||||
|
)
|
||||||
|
# NOTE(justinchuby): Use zeros instead of torch.empty because I haven't figured
|
||||||
|
# out how it can handle empty shapes
|
||||||
|
return torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype])
|
||||||
|
|
||||||
|
|
||||||
|
@torch.library.custom_op(
|
||||||
|
"onnx_symbolic::_symbolic_multi_out",
|
||||||
|
mutates_args=(),
|
||||||
|
schema=(
|
||||||
|
"(Tensor?[] inputs, str op_type, int[] onnx_dtypes, *,"
|
||||||
|
" SymInt[][] shapes, str[] attr_keys, str[] attr_types, int[][] attr_pos,"
|
||||||
|
" int[] attr_ints, float[] attr_floats, str[] attr_strs, str[] metadata_props_keys,"
|
||||||
|
" str[] metadata_props_values, str domain='', int? version=None"
|
||||||
|
") -> Tensor[]"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def _symbolic_multi_out(
|
||||||
|
inputs: Sequence[Optional[torch.Tensor]],
|
||||||
|
op_type: str,
|
||||||
|
onnx_dtypes: Sequence[int],
|
||||||
|
*,
|
||||||
|
shapes: Sequence[Sequence[Union[int, torch.SymInt]]],
|
||||||
|
attr_keys: Sequence[str],
|
||||||
|
attr_types: Sequence[str],
|
||||||
|
attr_pos: Sequence[tuple[int, int]],
|
||||||
|
attr_ints: Sequence[int],
|
||||||
|
attr_floats: Sequence[float],
|
||||||
|
attr_strs: Sequence[str],
|
||||||
|
metadata_props_keys: Sequence[str] = (),
|
||||||
|
metadata_props_values: Sequence[str] = (),
|
||||||
|
domain: str = "",
|
||||||
|
version: Optional[int] = None,
|
||||||
|
) -> list[torch.Tensor]:
|
||||||
|
outputs = []
|
||||||
|
torch._check(
|
||||||
|
len(shapes) == len(onnx_dtypes),
|
||||||
|
lambda: f"Number of shapes ({len(shapes)}) must match number of ONNX dtypes ({len(onnx_dtypes)})",
|
||||||
|
)
|
||||||
|
for shape, onnx_dtype in zip(shapes, onnx_dtypes):
|
||||||
|
torch._check(
|
||||||
|
onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE,
|
||||||
|
lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
|
||||||
|
)
|
||||||
|
outputs.append(torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]))
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
@_symbolic_multi_out.register_fake
|
||||||
|
def _(
|
||||||
|
inputs: Sequence[torch.Tensor],
|
||||||
|
op_type: str,
|
||||||
|
onnx_dtypes: Sequence[int],
|
||||||
|
*,
|
||||||
|
shapes: Sequence[Sequence[Union[int, torch.SymInt]]],
|
||||||
|
attr_keys: Sequence[str],
|
||||||
|
attr_types: Sequence[str],
|
||||||
|
attr_pos: Sequence[tuple[int, int]],
|
||||||
|
attr_ints: Sequence[int],
|
||||||
|
attr_floats: Sequence[float],
|
||||||
|
attr_strs: Sequence[str],
|
||||||
|
metadata_props_keys: Sequence[str] = (),
|
||||||
|
metadata_props_values: Sequence[str] = (),
|
||||||
|
domain: str = "",
|
||||||
|
version: Optional[int] = None,
|
||||||
|
) -> list[torch.Tensor]:
|
||||||
|
outputs = []
|
||||||
|
torch._check(
|
||||||
|
len(shapes) == len(onnx_dtypes),
|
||||||
|
lambda: f"Number of shapes ({len(shapes)}) must match number of ONNX dtypes ({len(onnx_dtypes)})",
|
||||||
|
)
|
||||||
|
for shape, onnx_dtype in zip(shapes, onnx_dtypes):
|
||||||
|
torch._check(
|
||||||
|
onnx_dtype in _ONNX_DTYPE_TO_TORCH_DTYPE,
|
||||||
|
lambda: f"{onnx_dtype} is invalid as an ONNX data type. Valid values are {list(_ONNX_DTYPE_TO_TORCH_DTYPE.keys())}",
|
||||||
|
)
|
||||||
|
# NOTE(justinchuby): Use zeros instead of torch.empty because I haven't figured
|
||||||
|
# out how it can handle empty shapes
|
||||||
|
outputs.append(torch.zeros(shape, dtype=_ONNX_DTYPE_TO_TORCH_DTYPE[onnx_dtype]))
|
||||||
|
return outputs
|
||||||
Loading…
Reference in New Issue
Block a user