pytorch/test/onnx/ops/test_ops.py
Justin Chu c5701d0ab5 [ONNX] Create fake implementations for onnx ops; fix boolean mask in attention (#165780)
Previously we relied on the concrete implementation to generate the fake implementation. This made the fake implementation overly complicated, and it broke in some cases involving dynamic shapes.

This PR updates ONNX op registration to instead take a dedicated fake implementation.
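
For context, a minimal sketch (not the code in this PR) of what registering a dedicated fake implementation looks like with `torch.library`; the op name `mylib::custom_op` is hypothetical:

```python
import torch


@torch.library.custom_op("mylib::custom_op", mutates_args=())
def custom_op(x: torch.Tensor) -> torch.Tensor:
    # Concrete (eager) implementation.
    return x.clone()


@custom_op.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Dedicated fake implementation: only describes the output's shape and
    # dtype, so it works under FakeTensor tracing and with dynamic shapes.
    return torch.empty_like(x)
```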

**Also fixed: when a boolean mask is supplied to torch SDPA, it was previously negated, which is incorrect.**
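
A minimal sketch of the intended boolean-mask polarity, assuming the ONNX Attention mask follows the same convention as `torch.nn.functional.scaled_dot_product_attention` (True = attend); `bool_mask_to_additive` is an illustrative helper, not an API added by this PR:

```python
import torch


def bool_mask_to_additive(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Boolean convention assumed here (matching torch SDPA): True means
    # "attend", False means "masked out". The additive mask therefore puts
    # -inf at the False positions; the mask itself must not be negated.
    return torch.zeros(mask.shape, dtype=dtype).masked_fill(~mask, float("-inf"))
```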

Fixes https://github.com/pytorch/pytorch/issues/164909. Also incorporates changes from https://github.com/pytorch/pytorch/pull/156635.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165780
Approved by: https://github.com/titaiwangms
2025-10-29 04:51:49 +00:00


# Owner(s): ["module: onnx"]
"""Test torch.onnx.ops."""

from __future__ import annotations

import onnx_ir.passes.common as common_passes
from onnxscript import ir

import torch
from torch.onnx._internal.exporter import _testing as onnx_testing
from torch.onnx.ops import _impl, _symbolic_impl
from torch.testing._internal import common_utils


class SchemaTest(common_utils.TestCase):
def test_symbolic_has_correct_schema(self):
torch.library.opcheck(
_symbolic_impl._symbolic,
([torch.tensor(1)], "CustomOp", 1),
dict(
shape=[
1,
],
attr_keys=["key"],
attr_types=["i"],
attr_pos=[(0, 1)],
attr_ints=[1],
attr_floats=[1.0],
attr_strs=["attr"],
metadata_props_keys=["meta_key"],
metadata_props_values=["meta_value"],
domain="custom_domain",
version=42,
),
)
# Empty inputs
torch.library.opcheck(
_symbolic_impl._symbolic,
([], "CustomOp", 1),
dict(
shape=[
1,
],
attr_keys=[],
attr_types=[],
attr_pos=[],
attr_ints=[],
attr_floats=[],
attr_strs=[],
metadata_props_keys=[],
metadata_props_values=[],
),
)
def test_symbolic_multi_out_has_correct_schema(self):
torch.library.opcheck(
_symbolic_impl._symbolic_multi_out,
([torch.tensor(1)], "CustomMultiOutOp", [1, 2, 10]),
dict(
shapes=[[1, 2], [42], []],
attr_keys=["key"],
attr_types=["i"],
attr_pos=[(0, 1)],
attr_ints=[1],
attr_floats=[1.0],
attr_strs=["attr"],
metadata_props_keys=["meta_key"],
metadata_props_values=["meta_value"],
domain="",
version=1,
),
)
# Empty inputs
torch.library.opcheck(
_symbolic_impl._symbolic_multi_out,
([], "CustomMultiOutOp", []),
dict(
shapes=[],
attr_keys=[],
attr_types=[],
attr_pos=[],
attr_ints=[],
attr_floats=[],
attr_strs=[],
metadata_props_keys=[],
metadata_props_values=[],
),
)
class SymbolicOpsTest(common_utils.TestCase):
def test_symbolic_accepts_valid_inputs(self):
output = torch.onnx.ops.symbolic(
"custom_domain::CustomOp",
(torch.tensor(1),),
dict(
int_key=1,
float_key=1.0,
str_key="attr",
bool_key=True,
list_int_key=[1, 2],
list_float_key=[1.0, 2.0],
list_str_key=["attr1", "attr2"],
list_bool_key=[True, False],
),
dtype=torch.float32,
shape=[1, 2, 3],
version=1,
metadata_props={"meta_key": "meta_value"},
)
self.assertEqual(output.shape, torch.Size([1, 2, 3]))
self.assertEqual(output.dtype, torch.float32)
self.assertEqual(output.device, torch.device("cpu"))
def test_symbolic_accepts_valid_inputs_empty_shape(self):
output = torch.onnx.ops.symbolic(
"custom_domain::CustomOp",
(torch.tensor(1),),
dtype=torch.float32,
shape=[],
)
self.assertEqual(output.shape, torch.Size([]))
def test_symbolic_accepts_valid_inputs_integer_types(self):
output = torch.onnx.ops.symbolic(
"custom_domain::CustomOp",
(torch.tensor(1),),
dtype=1, # 1 is float32 in ONNX
shape=[42],
)
self.assertEqual(output.dtype, torch.float32)
def test_symbolic_accepts_valid_inputs_int4_type(self):
output = torch.onnx.ops.symbolic(
"custom_domain::CustomOp",
(torch.tensor(1),),
dtype=22, # 22 is INT4 in ONNX
shape=[42],
)
# We use torch uint8 for int4
self.assertEqual(output.dtype, torch.uint8)
def test_symbolic_is_exportable(self):
class Model(torch.nn.Module):
def forward(self, x: torch.Tensor):
return torch.onnx.ops.symbolic(
"custom_domain::CustomOp",
(x, None),
dict(
int_key=1,
float_key=1.0,
str_key="attr",
bool_key=True,
list_int_key=[1, 2],
list_float_key=[1.0, 2.0],
list_str_key=["attr1", "attr2"],
list_bool_key=[True, False],
),
dtype=x.dtype,
shape=[1, 2, 3],
version=1,
metadata_props={"meta_key": "meta_value"},
)
onnx_program = torch.onnx.export(
Model(), (torch.tensor(1),), dynamo=True, verbose=False
)
assert onnx_program is not None
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "CustomOp")
self.assertEqual(node.domain, "custom_domain")
attributes = node.attributes
self.assertEqual(
attributes,
dict(
int_key=ir.AttrInt64("int_key", 1),
float_key=ir.AttrFloat32("float_key", 1.0),
str_key=ir.AttrString("str_key", "attr"),
bool_key=ir.AttrInt64("bool_key", 1),
list_int_key=ir.AttrInt64s("list_int_key", [1, 2]),
list_float_key=ir.AttrFloat32s("list_float_key", [1.0, 2.0]),
list_str_key=ir.AttrStrings("list_str_key", ["attr1", "attr2"]),
list_bool_key=ir.AttrInt64s("list_bool_key", [1, 0]),
),
)
self.assertEqual(node.metadata_props["meta_key"], "meta_value")
outputs = node.outputs
self.assertEqual(list(outputs[0].shape), [1, 2, 3])
self.assertEqual(outputs[0].dtype, ir.DataType.INT64)
def test_symbolic_preserves_dynamic_shapes(self):
class Model(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor):
return torch.onnx.ops.symbolic(
"custom_domain::CustomOp",
(x, y),
dtype=x.dtype,
shape=[*x.shape, *y.shape],
version=1,
)
onnx_program = torch.onnx.export(
Model(),
(torch.zeros(2, 3), torch.zeros(1, 2)),
dynamic_shapes=({0: "batch"}, {1: "something_else"}),
dynamo=True,
verbose=False,
)
assert onnx_program is not None
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "CustomOp")
self.assertEqual(node.domain, "custom_domain")
inputs = onnx_program.model.graph.inputs
self.assertEqual(str(inputs[0].shape[0]), "batch")
self.assertEqual(inputs[0].shape[1], 3)
self.assertEqual(inputs[1].shape[0], 1)
self.assertEqual(str(inputs[1].shape[1]), "something_else")
outputs = node.outputs
self.assertEqual(str(outputs[0].shape[0]), "batch")
self.assertEqual(outputs[0].shape[1], 3)
self.assertEqual(outputs[0].shape[2], 1)
self.assertEqual(str(outputs[0].shape[3]), "something_else")
self.assertEqual(outputs[0].dtype, ir.DataType.FLOAT)
def test_symbolic_multi_out_accepts_valid_inputs(self):
outputs = torch.onnx.ops.symbolic_multi_out(
"custom_domain::CustomMultiOutOp",
(torch.tensor(1),),
dict(
int_key=1,
float_key=1.0,
str_key="attr",
bool_key=True,
list_int_key=[1, 2],
list_float_key=[1.0, 2.0],
list_str_key=["attr1", "attr2"],
list_bool_key=[True, False],
),
dtypes=(
1, # 1 is float32 in ONNX
torch.int32,
torch.float8_e4m3fn,
),
shapes=([1, 2], [42], []),
version=1,
metadata_props={"meta_key": "meta_value"},
)
self.assertEqual(len(outputs), 3)
self.assertEqual(outputs[0].shape, torch.Size([1, 2]))
self.assertEqual(outputs[0].dtype, torch.float32)
self.assertEqual(outputs[1].shape, torch.Size([42]))
self.assertEqual(outputs[1].dtype, torch.int32)
self.assertEqual(outputs[2].shape, torch.Size([]))
self.assertEqual(outputs[2].dtype, torch.float8_e4m3fn)
self.assertEqual(outputs[0].device, torch.device("cpu"))
self.assertEqual(outputs[1].device, torch.device("cpu"))
self.assertEqual(outputs[2].device, torch.device("cpu"))
def test_symbolic_multi_out_accepts_valid_inputs_empty_shape(self):
outputs = torch.onnx.ops.symbolic_multi_out(
"custom_domain::CustomOp",
(torch.tensor(1),),
dtypes=(torch.float32,),
shapes=[[]],
)
self.assertEqual(outputs[0].shape, torch.Size([]))
def test_symbolic_multi_out_accepts_valid_inputs_integer_types(self):
outputs = torch.onnx.ops.symbolic_multi_out(
"custom_domain::CustomOp",
(torch.tensor(1),),
dtypes=(1,), # 1 is float32 in ONNX
shapes=[[42]],
)
self.assertEqual(outputs[0].dtype, torch.float32)
def test_symbolic_multi_out_accepts_valid_inputs_int4_type(self):
outputs = torch.onnx.ops.symbolic_multi_out(
"custom_domain::CustomOp",
(torch.tensor(1),),
dtypes=(22,), # 22 is INT4 in ONNX
shapes=[[42]],
)
# We use torch uint8 for int4
self.assertEqual(outputs[0].dtype, torch.uint8)
def test_symbolic_multi_out_is_exportable(self):
class Model(torch.nn.Module):
def forward(self, x: torch.Tensor):
return torch.onnx.ops.symbolic_multi_out(
"custom_domain::CustomOp",
(x, None),
dict(
int_key=1,
float_key=1.0,
str_key="attr",
bool_key=True,
list_int_key=[1, 2],
list_float_key=[1.0, 2.0],
list_str_key=["attr1", "attr2"],
list_bool_key=[True, False],
),
dtypes=(torch.float32, torch.int32, torch.float8_e4m3fn),
shapes=([1, 2], [42], []),
version=1,
metadata_props={"meta_key": "meta_value"},
)
onnx_program = torch.onnx.export(
Model(), (torch.tensor(1),), dynamo=True, verbose=False
)
assert onnx_program is not None
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "CustomOp")
self.assertEqual(node.domain, "custom_domain")
attributes = node.attributes
self.assertEqual(
attributes,
dict(
int_key=ir.AttrInt64("int_key", 1),
float_key=ir.AttrFloat32("float_key", 1.0),
str_key=ir.AttrString("str_key", "attr"),
bool_key=ir.AttrInt64("bool_key", 1),
list_int_key=ir.AttrInt64s("list_int_key", [1, 2]),
list_float_key=ir.AttrFloat32s("list_float_key", [1.0, 2.0]),
list_str_key=ir.AttrStrings("list_str_key", ["attr1", "attr2"]),
list_bool_key=ir.AttrInt64s("list_bool_key", [1, 0]),
),
)
self.assertEqual(node.metadata_props["meta_key"], "meta_value")
outputs = node.outputs
self.assertEqual(list(outputs[0].shape), [1, 2])
self.assertEqual(outputs[0].dtype, ir.DataType.FLOAT)
self.assertEqual(list(outputs[1].shape), [42])
self.assertEqual(outputs[1].dtype, ir.DataType.INT32)
self.assertEqual(list(outputs[2].shape), [])
self.assertEqual(outputs[2].dtype, ir.DataType.FLOAT8E4M3FN)
def test_symbolic_multi_out_preserves_dynamic_shapes(self):
class Model(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor):
return torch.onnx.ops.symbolic_multi_out(
"custom_domain::CustomOp",
(x, y),
dtypes=(x.dtype, 22), # 22 is INT4
shapes=[[*x.shape, *y.shape], [42]],
version=1,
)
onnx_program = torch.onnx.export(
Model(),
(torch.zeros(2, 3), torch.zeros(1, 2)),
dynamic_shapes=({0: "batch"}, {1: "something_else"}),
dynamo=True,
verbose=False,
)
assert onnx_program is not None
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "CustomOp")
self.assertEqual(node.domain, "custom_domain")
inputs = onnx_program.model.graph.inputs
self.assertEqual(str(inputs[0].shape[0]), "batch")
self.assertEqual(inputs[0].shape[1], 3)
self.assertEqual(inputs[1].shape[0], 1)
self.assertEqual(str(inputs[1].shape[1]), "something_else")
outputs = node.outputs
self.assertEqual(str(outputs[0].shape[0]), "batch")
self.assertEqual(outputs[0].shape[1], 3)
self.assertEqual(outputs[0].shape[2], 1)
self.assertEqual(str(outputs[0].shape[3]), "something_else")
self.assertEqual(outputs[0].dtype, ir.DataType.FLOAT)
self.assertEqual(list(outputs[1].shape), [42])
self.assertEqual(outputs[1].dtype, ir.DataType.INT4)
def test_symbolic_multi_out_raises_when_dtypes_and_shapes_differ(self):
with self.assertRaises(RuntimeError):
torch.onnx.ops.symbolic_multi_out(
"custom_domain::CustomMultiOutOp",
(torch.tensor(1),),
dict(
int_key=1,
float_key=1.0,
str_key="attr",
bool_key=True,
list_int_key=[1, 2],
list_float_key=[1.0, 2.0],
list_str_key=["attr1", "attr2"],
list_bool_key=[True, False],
),
dtypes=(torch.float32, torch.int32),
shapes=([1, 2], [42], []),
version=1,
metadata_props={"meta_key": "meta_value"},
)
with self.assertRaises(RuntimeError):
torch.onnx.ops.symbolic_multi_out(
"custom_domain::CustomMultiOutOp",
(torch.tensor(1),),
dict(
int_key=1,
float_key=1.0,
str_key="attr",
bool_key=True,
list_int_key=[1, 2],
list_float_key=[1.0, 2.0],
list_str_key=["attr1", "attr2"],
list_bool_key=[True, False],
),
dtypes=(torch.float32,),
shapes=([1, 2], [42]),
version=1,
metadata_props={"meta_key": "meta_value"},
)
@common_utils.instantiate_parametrized_tests
class NativeOnnxOpsTest(common_utils.TestCase):
def export(self, model, args=(), kwargs=None, **options) -> torch.onnx.ONNXProgram:
onnx_program = torch.onnx.export(
model,
args,
kwargs=kwargs,
dynamo=True,
fallback=False,
verbose=False,
**options,
)
assert onnx_program is not None
common_passes.CheckerPass()(onnx_program.model)
return onnx_program
def test_onnx_ops_can_be_decomposed_to_aten(self):
input_data = torch.rand(2, 3, 4, 8)
position_ids_data = torch.randint(0, 50, (2, 4)).long()
sin_cache_data = torch.rand(50, 4)
cos_cache_data = torch.rand(50, 4)
class Model(torch.nn.Module):
def forward(
self, input_data, cos_cache_data, sin_cache_data, position_ids_data
):
return torch.onnx.ops.rotary_embedding(
input_data,
cos_cache_data,
sin_cache_data,
position_ids_data,
interleaved=True,
)
model = Model()
ep = torch.export.export(
model,
(input_data, cos_cache_data, sin_cache_data, position_ids_data),
)
self.assertIn(
"onnx.RotaryEmbedding.opset23",
[str(node.target) for node in ep.graph.nodes],
)
# The program can be decomposed into aten ops so it is fully compatible with the PyTorch ecosystem
aten_decomped = ep.run_decompositions(torch.onnx.ops.aten_decompositions())
self.assertNotIn(
"onnx.RotaryEmbedding.opset23",
[str(node.target) for node in aten_decomped.graph.nodes],
)
torch.testing.assert_close(
aten_decomped.module()(
input_data, cos_cache_data, sin_cache_data, position_ids_data
),
model(input_data, cos_cache_data, sin_cache_data, position_ids_data),
)
def test_rotary_embedding_opcheck(self):
input_data = torch.rand(2, 3, 4, 8)
position_ids_data = torch.randint(0, 50, (2, 4)).long()
sin_cache_data = torch.rand(50, 4)
cos_cache_data = torch.rand(50, 4)
torch.library.opcheck(
_impl.rotary_embedding_23,
(input_data, cos_cache_data, sin_cache_data, position_ids_data),
)
def test_rotary_embedding(self):
input_data = torch.rand(2, 3, 4, 8)
position_ids_data = torch.randint(0, 50, (2, 4)).long()
sin_cache_data = torch.rand(50, 4)
cos_cache_data = torch.rand(50, 4)
# Eager mode is supported. Autograd is also supported so users can choose to use the op
# in development and production
result = torch.onnx.ops.rotary_embedding(
input_data, cos_cache_data, sin_cache_data, position_ids_data
)
self.assertEqual(result.shape, input_data.shape)
class Model(torch.nn.Module):
def forward(
self, input_data, cos_cache_data, sin_cache_data, position_ids_data
):
return torch.onnx.ops.rotary_embedding(
input_data,
cos_cache_data,
sin_cache_data,
position_ids_data,
interleaved=True,
)
model = Model()
# Dynamic shapes are supported
dynamic_shapes = {
"input_data": {0: torch.export.Dim.DYNAMIC},
"cos_cache_data": None,
"sin_cache_data": None,
"position_ids_data": {0: torch.export.Dim.DYNAMIC},
}
onnx_program = self.export(
model,
(input_data, cos_cache_data, sin_cache_data, position_ids_data),
dynamic_shapes=dynamic_shapes,
opset_version=23,
)
self.assertEqual(onnx_program.model.opset_imports[""], 23)
self.assertEqual("RotaryEmbedding", onnx_program.model.graph.node(0).op_type)
onnx_testing.assert_onnx_program(onnx_program)
def test_rotary_embedding_3d(self):
num_heads = 2
input_data = torch.rand(2, 3, 8)
sin_cache_data = torch.rand(2, 3, 2)
cos_cache_data = torch.rand(2, 3, 2)
class Model(torch.nn.Module):
def forward(self, input_data, cos_cache_data, sin_cache_data):
return torch.onnx.ops.rotary_embedding(
input_data,
cos_cache_data,
sin_cache_data,
num_heads=num_heads,
)
model = Model()
# Dynamic shapes are supported
dynamic_shapes = {
"input_data": {0: torch.export.Dim.DYNAMIC},
"cos_cache_data": {0: torch.export.Dim.DYNAMIC},
"sin_cache_data": {0: torch.export.Dim.DYNAMIC},
}
onnx_program = self.export(
model,
(input_data, cos_cache_data, sin_cache_data),
dynamic_shapes=dynamic_shapes,
opset_version=23,
)
self.assertEqual(onnx_program.model.opset_imports[""], 23)
self.assertEqual("RotaryEmbedding", onnx_program.model.graph.node(0).op_type)
onnx_testing.assert_onnx_program(onnx_program)
def test_attention_without_past_kv_caches(self):
"""Test basic attention functionality."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
# Test eager mode
torch.library.opcheck(_impl.attention_23, (Q, K, V))
output, present_key, present_value, qk_output = torch.onnx.ops.attention(
Q, K, V
)
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
self.assertEqual(present_key.shape, K.shape)
self.assertEqual(present_value.shape, V.shape)
self.assertEqual(
qk_output.shape, (batch_size, q_num_heads, q_seq_len, kv_seq_len)
)
def test_attention_3d_inputs(self):
"""Test attention with 3D inputs (requires num_heads parameters)."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_seq_len, q_num_heads * head_size)
K = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
V = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
torch.library.opcheck(
_impl.attention_23,
(Q, K, V),
dict(q_num_heads=q_num_heads, kv_num_heads=kv_num_heads),
)
output, present_key, present_value, qk_output = torch.onnx.ops.attention(
Q, K, V, q_num_heads=q_num_heads, kv_num_heads=kv_num_heads
)
# Output should be reshaped back to 3D
self.assertEqual(output.shape, (batch_size, q_seq_len, q_num_heads * head_size))
self.assertEqual(
present_key.shape, (batch_size, kv_num_heads, kv_seq_len, head_size)
)
self.assertEqual(
present_value.shape, (batch_size, kv_num_heads, kv_seq_len, head_size)
)
@common_utils.parametrize(
"name, kv_num_heads",
[
("group_query_attention", 4),
("multi_query_attention", 1),
],
)
def test_attention_kv_num_heads(self, name: str, kv_num_heads: int):
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads = 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
torch.library.opcheck(_impl.attention_23, (Q, K, V))
output, present_key, present_value, qk_output = torch.onnx.ops.attention(
Q, K, V
)
expected = torch.nn.functional.scaled_dot_product_attention(
Q, K, V, None, enable_gqa=True
)
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
self.assertEqual(present_key.shape, K.shape)
self.assertEqual(present_value.shape, V.shape)
torch.testing.assert_close(output, expected)
def test_attention_mqa(self):
"""Test Multi-Query Attention (MQA)."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 1 # MQA: kv_num_heads = 1
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
torch.library.opcheck(_impl.attention_23, (Q, K, V))
output, present_key, present_value, qk_output = torch.onnx.ops.attention(
Q, K, V
)
expected = torch.nn.functional.scaled_dot_product_attention(
Q, K, V, None, enable_gqa=True
)
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
torch.testing.assert_close(output, expected)
def test_attention_with_2d_mask(self):
"""Test attention with 2D attention mask (q_seq_len, kv_seq_len)."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
# Test with boolean mask
bool_mask = torch.randint(0, 2, (q_seq_len, kv_seq_len), dtype=torch.bool)
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=bool_mask))
output_bool, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=bool_mask)
# Test with float mask
float_mask = torch.randn(q_seq_len, kv_seq_len)
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=float_mask))
output_float, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=float_mask)
self.assertEqual(
output_bool.shape, (batch_size, q_num_heads, q_seq_len, head_size)
)
self.assertEqual(
output_float.shape, (batch_size, q_num_heads, q_seq_len, head_size)
)
def test_attention_with_4d_mask(self):
"""Test attention with 4D attention mask (batch_size, num_heads, q_seq_len, kv_seq_len)."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
# Test with boolean mask
bool_mask = torch.randint(
0, 2, (batch_size, q_num_heads, q_seq_len, kv_seq_len), dtype=torch.bool
)
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=bool_mask))
output_bool, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=bool_mask)
# Test with float mask
float_mask = torch.randn(batch_size, q_num_heads, q_seq_len, kv_seq_len)
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=float_mask))
output_float, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=float_mask)
self.assertEqual(
output_bool.shape, (batch_size, q_num_heads, q_seq_len, head_size)
)
self.assertEqual(
output_float.shape, (batch_size, q_num_heads, q_seq_len, head_size)
)
def test_attention_with_zero_float_mask(self):
"""Test attention with zero float mask."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
zero_mask = torch.zeros(q_seq_len, kv_seq_len)
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=zero_mask))
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=zero_mask)
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
def test_attention_with_causal_mask_pattern(self):
"""Test attention with lower triangular causal mask pattern."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 4 # Square for causal
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
# Create a lower triangular causal mask
causal_mask = torch.tril(torch.ones(q_seq_len, kv_seq_len, dtype=torch.bool))
torch.library.opcheck(
_impl.attention_23, (Q, K, V), dict(attn_mask=causal_mask)
)
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=causal_mask)
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
def test_attention_with_gqa_and_mask(self):
"""Test attention with GQA and different mask shapes."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 4 # GQA
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
# Test 2D mask with GQA
mask_2d = torch.randint(0, 2, (q_seq_len, kv_seq_len), dtype=torch.bool)
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=mask_2d))
output_2d, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask_2d)
# Test 4D mask with GQA (note: using q_num_heads for mask heads)
mask_4d = torch.randint(
0, 2, (batch_size, q_num_heads, q_seq_len, kv_seq_len), dtype=torch.bool
)
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(attn_mask=mask_4d))
output_4d, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask_4d)
self.assertEqual(
output_2d.shape, (batch_size, q_num_heads, q_seq_len, head_size)
)
self.assertEqual(
output_4d.shape, (batch_size, q_num_heads, q_seq_len, head_size)
)
def test_attention_causal(self):
"""Test causal attention."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 4 # Square for causal
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(is_causal=True))
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, is_causal=True)
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
def test_attention_with_past_kv(self):
"""Test attention with past key/value caches."""
batch_size, q_seq_len, kv_seq_len, past_seq_len = 2, 4, 6, 3
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
past_key = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
past_value = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
torch.library.opcheck(
_impl.attention_23,
(Q, K, V),
dict(past_key=past_key, past_value=past_value),
)
output, present_key, present_value, _ = torch.onnx.ops.attention(
Q, K, V, past_key=past_key, past_value=past_value
)
# Present key/value should include past + current
expected_total_seq_len = past_seq_len + kv_seq_len
self.assertEqual(
present_key.shape,
(batch_size, kv_num_heads, expected_total_seq_len, head_size),
)
self.assertEqual(
present_value.shape,
(batch_size, kv_num_heads, expected_total_seq_len, head_size),
)
def test_attention_with_softcap(self):
"""Test attention with softcap."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(softcap=30.0))
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, softcap=30.0)
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
def test_attention_qk_output_modes(self):
"""Test different QK matmul output modes."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
for mode in [0, 1, 2, 3]:
torch.library.opcheck(
_impl.attention_23,
(Q, K, V),
dict(qk_matmul_output_mode=mode),
)
output, _, _, qk_output = torch.onnx.ops.attention(
Q, K, V, qk_matmul_output_mode=mode
)
self.assertEqual(
output.shape, (batch_size, q_num_heads, q_seq_len, head_size)
)
self.assertEqual(
qk_output.shape, (batch_size, q_num_heads, q_seq_len, kv_seq_len)
)
def test_attention_custom_scale(self):
"""Test attention with custom scale factor."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
custom_scale = 0.25
torch.library.opcheck(_impl.attention_23, (Q, K, V), dict(scale=custom_scale))
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, scale=custom_scale)
self.assertEqual(output.shape, (batch_size, q_num_heads, q_seq_len, head_size))
def test_attention_export(self):
"""Test that attention can be exported to ONNX."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
class AttentionModel(torch.nn.Module):
def forward(self, Q, K, V):
output, _, _, _ = torch.onnx.ops.attention(Q, K, V)
return output
model = AttentionModel()
onnx_program = self.export(
model,
(Q, K, V),
opset_version=23,
)
self.assertEqual(onnx_program.model.opset_imports[""], 23)
self.assertEqual("Attention", onnx_program.model.graph.node(0).op_type)
onnx_testing.assert_onnx_program(onnx_program)
def test_attention_export_with_dynamic_shapes(self):
"""Test attention export with dynamic shapes."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
attn_mask = torch.randint(
0, 2, (batch_size, 1, q_seq_len, kv_seq_len), dtype=torch.bool
)
class AttentionModel(torch.nn.Module):
def forward(self, Q, K, V, attn_mask):
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=attn_mask)
return output
model = AttentionModel()
dynamic_shapes = {
"Q": {0: "batch", 2: "q_seq_len"},
"K": {0: "batch", 2: "kv_seq_len"},
"V": {0: "batch", 2: "kv_seq_len"},
"attn_mask": {0: "batch", 2: "q_seq_len", 3: "kv_seq_len"},
}
onnx_program = self.export(
model,
(Q, K, V, attn_mask),
dynamic_shapes=dynamic_shapes,
opset_version=23,
)
self.assertEqual(onnx_program.model.opset_imports[""], 23)
self.assertEqual("Attention", onnx_program.model.graph.node(0).op_type)
node = onnx_program.model.graph.node(0)
# Verify inputs
self.assertEqual(len(node.inputs), 4)
self.assertEqual(
node.inputs[0].shape, ["batch", q_num_heads, "q_seq_len", head_size]
)
self.assertEqual(
node.inputs[1].shape, ["batch", kv_num_heads, "kv_seq_len", head_size]
)
self.assertEqual(
node.inputs[2].shape, ["batch", kv_num_heads, "kv_seq_len", head_size]
)
# Verify default attributes (should be minimal)
self.assertEqual(len(node.attributes), 0)
onnx_testing.assert_onnx_program(onnx_program)
def test_attention_3d_export(self):
"""Test attention export with 3D inputs."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_seq_len, q_num_heads * head_size)
K = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
V = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
class AttentionModel(torch.nn.Module):
def forward(self, Q, K, V):
output, _, _, _ = torch.onnx.ops.attention(
Q, K, V, q_num_heads=q_num_heads, kv_num_heads=kv_num_heads
)
return output
model = AttentionModel()
onnx_program = self.export(
model,
(Q, K, V),
opset_version=23,
)
self.assertEqual(onnx_program.model.opset_imports[""], 23)
self.assertEqual("Attention", onnx_program.model.graph.node(0).op_type)
onnx_testing.assert_onnx_program(onnx_program)
def test_attention_decomposition(self):
"""Test that attention can be decomposed to aten ops."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
class AttentionModel(torch.nn.Module):
def forward(self, Q, K, V):
output, present_key, present_value, qk_output = (
torch.onnx.ops.attention(Q, K, V)
)
return output
model = AttentionModel()
ep = torch.export.export(model, (Q, K, V))
self.assertIn(
"onnx.Attention.opset23",
[str(node.target) for node in ep.graph.nodes],
)
# The program can be decomposed into aten ops
aten_decomped = ep.run_decompositions(torch.onnx.ops.aten_decompositions())
self.assertNotIn(
"onnx.Attention.opset23",
[str(node.target) for node in aten_decomped.graph.nodes],
)
# Results should match
torch.testing.assert_close(
aten_decomped.module()(Q, K, V),
model(Q, K, V),
)
def test_attention_export_with_past_key_value(self):
"""Test export with past_key, past_value to ensure the optional input order is correct."""
batch_size, q_seq_len, kv_seq_len, past_seq_len = 2, 4, 6, 3
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
past_key = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
past_value = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
class Model(torch.nn.Module):
def forward(self, Q, K, V, past_key, past_value):
output, present_key, present_value, _ = torch.onnx.ops.attention(
Q,
K,
V,
past_key=past_key,
attn_mask=None,
# Switched argument order
past_value=past_value,
)
return output, present_key, present_value
model = Model()
onnx_program = self.export(
model, (Q, K, V, past_key, past_value), opset_version=23
)
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "Attention")
# Verify all 6 inputs are present
self.assertEqual(
len(node.inputs), 6
) # Q, K, V, attn_mask, past_key, past_value
self.assertEqual(
node.inputs[0].shape, [batch_size, q_num_heads, q_seq_len, head_size]
)
self.assertEqual(
node.inputs[1].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
)
self.assertEqual(
node.inputs[2].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
)
self.assertIsNone(node.inputs[3])
self.assertEqual(
node.inputs[4].shape, [batch_size, kv_num_heads, past_seq_len, head_size]
)
self.assertEqual(
node.inputs[5].shape, [batch_size, kv_num_heads, past_seq_len, head_size]
)
onnx_testing.assert_onnx_program(onnx_program)
def test_attention_export_with_all_optional_inputs(self):
"""Test export with all optional inputs: mask, past_key, past_value."""
batch_size, q_seq_len, kv_seq_len, past_seq_len = 2, 4, 6, 3
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
attn_mask = torch.randint(
0, 2, (1, 1, q_seq_len, kv_seq_len + past_seq_len), dtype=torch.bool
)
past_key = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
past_value = torch.rand(batch_size, kv_num_heads, past_seq_len, head_size)
class FullAttentionModel(torch.nn.Module):
def forward(self, Q, K, V, attn_mask, past_key, past_value):
output, present_key, present_value, qk_matmul = (
torch.onnx.ops.attention(
Q,
K,
V,
attn_mask=attn_mask,
past_key=past_key,
past_value=past_value,
)
)
return output, present_key, present_value, qk_matmul
model = FullAttentionModel()
onnx_program = self.export(
model, (Q, K, V, attn_mask, past_key, past_value), opset_version=23
)
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "Attention")
# Verify all 6 inputs are present
self.assertEqual(
len(node.inputs), 6
) # Q, K, V, attn_mask, past_key, past_value
self.assertEqual(
node.inputs[0].shape, [batch_size, q_num_heads, q_seq_len, head_size]
)
self.assertEqual(
node.inputs[1].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
)
self.assertEqual(
node.inputs[2].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
)
self.assertEqual(
node.inputs[3].shape, [1, 1, q_seq_len, kv_seq_len + past_seq_len]
)
self.assertEqual(
node.inputs[4].shape, [batch_size, kv_num_heads, past_seq_len, head_size]
)
self.assertEqual(
node.inputs[5].shape, [batch_size, kv_num_heads, past_seq_len, head_size]
)
onnx_testing.assert_onnx_program(onnx_program)
def test_attention_export_3d_with_num_heads_attributes(self):
"""Test export with 3D inputs and explicit num_heads attributes."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 4 # GQA
head_size = 64
Q = torch.rand(batch_size, q_seq_len, q_num_heads * head_size)
K = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
V = torch.rand(batch_size, kv_seq_len, kv_num_heads * head_size)
class Attention3DModel(torch.nn.Module):
def forward(self, Q, K, V):
output, _, _, _ = torch.onnx.ops.attention(
Q, K, V, q_num_heads=q_num_heads, kv_num_heads=kv_num_heads
)
return output
model = Attention3DModel()
onnx_program = self.export(model, (Q, K, V), opset_version=23)
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "Attention")
# Verify 3D input shapes
self.assertEqual(
node.inputs[0].shape, [batch_size, q_seq_len, q_num_heads * head_size]
)
self.assertEqual(
node.inputs[1].shape, [batch_size, kv_seq_len, kv_num_heads * head_size]
)
self.assertEqual(
node.inputs[2].shape, [batch_size, kv_seq_len, kv_num_heads * head_size]
)
# Verify num_heads attributes are set
attrs = node.attributes
self.assertIn("q_num_heads", attrs)
self.assertIn("kv_num_heads", attrs)
self.assertEqual(attrs["q_num_heads"].value, q_num_heads)
self.assertEqual(attrs["kv_num_heads"].value, kv_num_heads)
onnx_testing.assert_onnx_program(onnx_program)
def test_attention_export_with_all_attributes(self):
"""Test export with all possible attributes set."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
class FullAttributesModel(torch.nn.Module):
def forward(self, Q, K, V):
output, _, _, _ = torch.onnx.ops.attention(
Q,
K,
V,
is_causal=True,
qk_matmul_output_mode=2,
scale=0.25,
softcap=30.0,
softmax_precision=1, # FLOAT
)
return output
model = FullAttributesModel()
onnx_program = self.export(model, (Q, K, V), opset_version=23)
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "Attention")
# Verify all attributes are set correctly
attrs = node.attributes
self.assertIn("is_causal", attrs)
self.assertIn("qk_matmul_output_mode", attrs)
self.assertIn("scale", attrs)
self.assertIn("softcap", attrs)
self.assertIn("softmax_precision", attrs)
self.assertEqual(attrs["is_causal"].value, 1) # True as int
self.assertEqual(attrs["qk_matmul_output_mode"].value, 2)
self.assertAlmostEqual(attrs["scale"].value, 0.25, places=6)
self.assertAlmostEqual(attrs["softcap"].value, 30.0, places=6)
self.assertEqual(attrs["softmax_precision"].value, 1)
onnx_testing.assert_onnx_program(onnx_program)
def test_attention_export_with_different_mask_shapes(self):
"""Test export with different attention mask shapes."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
# Test 2D mask
mask_2d = torch.randint(0, 2, (q_seq_len, kv_seq_len), dtype=torch.bool)
class Mask2DModel(torch.nn.Module):
def forward(self, Q, K, V, mask):
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask)
return output
model_2d = Mask2DModel()
onnx_program_2d = self.export(model_2d, (Q, K, V, mask_2d), opset_version=23)
node_2d = onnx_program_2d.model.graph.node(0)
self.assertEqual(node_2d.inputs[3].shape, [q_seq_len, kv_seq_len])
onnx_testing.assert_onnx_program(onnx_program_2d)
# Test 3D mask
mask_3d = torch.randint(
0, 2, (batch_size, 1, q_seq_len, kv_seq_len), dtype=torch.bool
)
class Mask3DModel(torch.nn.Module):
def forward(self, Q, K, V, mask):
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask)
return output
model_3d = Mask3DModel()
onnx_program_3d = self.export(model_3d, (Q, K, V, mask_3d), opset_version=23)
node_3d = onnx_program_3d.model.graph.node(0)
self.assertEqual(
node_3d.inputs[3].shape, [batch_size, 1, q_seq_len, kv_seq_len]
)
onnx_testing.assert_onnx_program(onnx_program_3d)
# Test 4D mask
mask_4d = torch.randint(
0, 2, (batch_size, q_num_heads, q_seq_len, kv_seq_len), dtype=torch.bool
)
class Mask4DModel(torch.nn.Module):
def forward(self, Q, K, V, mask):
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask)
return output
model_4d = Mask4DModel()
onnx_program_4d = self.export(model_4d, (Q, K, V, mask_4d), opset_version=23)
node_4d = onnx_program_4d.model.graph.node(0)
self.assertEqual(
node_4d.inputs[3].shape, [batch_size, q_num_heads, q_seq_len, kv_seq_len]
)
onnx_testing.assert_onnx_program(onnx_program_4d)
def test_attention_export_with_float_mask(self):
"""Test export with float attention mask."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
float_mask = torch.randn(q_seq_len, kv_seq_len)
class FloatMaskModel(torch.nn.Module):
def forward(self, Q, K, V, mask):
output, _, _, _ = torch.onnx.ops.attention(Q, K, V, attn_mask=mask)
return output
model = FloatMaskModel()
onnx_program = self.export(model, (Q, K, V, float_mask), opset_version=23)
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "Attention")
self.assertEqual(node.inputs[3].shape, [q_seq_len, kv_seq_len])
# Verify the mask input has float dtype in the ONNX model
self.assertEqual(node.inputs[3].dtype, ir.DataType.FLOAT)
onnx_testing.assert_onnx_program(onnx_program)
def test_attention_export_qk_output_modes(self):
"""Test export with different QK output modes."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
for mode in [0, 1, 2, 3]:
class QKOutputModel(torch.nn.Module):
def __init__(self, qk_mode):
super().__init__()
self.qk_mode = qk_mode
def forward(self, Q, K, V):
output, _, _, qk_output = torch.onnx.ops.attention(
Q, K, V, qk_matmul_output_mode=self.qk_mode
)
return output, qk_output
model = QKOutputModel(mode)
onnx_program = self.export(model, (Q, K, V), opset_version=23)
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "Attention")
# Verify qk_matmul_output_mode attribute
attrs = node.attributes
if mode != 0:
self.assertIn("qk_matmul_output_mode", attrs)
self.assertEqual(attrs["qk_matmul_output_mode"].value, mode)
# Verify 4 outputs (output, present_key, present_value, qk_output)
self.assertEqual(len(node.outputs), 4)
onnx_testing.assert_onnx_program(onnx_program)
def test_attention_export_mqa(self):
"""Test export with Multi-Query Attention (MQA)."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 1 # MQA
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
class MQAModel(torch.nn.Module):
def forward(self, Q, K, V):
output, _, _, _ = torch.onnx.ops.attention(Q, K, V)
return output
model = MQAModel()
onnx_program = self.export(model, (Q, K, V), opset_version=23)
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "Attention")
# Verify MQA tensor shapes
self.assertEqual(
node.inputs[0].shape, [batch_size, q_num_heads, q_seq_len, head_size]
)
self.assertEqual(
node.inputs[1].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
) # kv_num_heads = 1
self.assertEqual(
node.inputs[2].shape, [batch_size, kv_num_heads, kv_seq_len, head_size]
)
onnx_testing.assert_onnx_program(onnx_program)
@common_utils.parametrize(
"precision_enum, precision_name",
[
(1, "FLOAT"),
(10, "FLOAT16"),
(11, "DOUBLE"),
(16, "BFLOAT16"),
],
)
def test_attention_export_with_softmax_precision(
self, precision_enum, precision_name: str
):
"""Test export with different softmax precision values."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 8
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
class SoftmaxPrecisionModel(torch.nn.Module):
def __init__(self, precision):
super().__init__()
self.precision = precision
def forward(self, Q, K, V):
output, _, _, _ = torch.onnx.ops.attention(
Q, K, V, softmax_precision=self.precision
)
return output
model = SoftmaxPrecisionModel(precision_enum)
onnx_program = self.export(model, (Q, K, V), opset_version=23)
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "Attention")
# Verify softmax_precision attribute
attrs = node.attributes
self.assertIn("softmax_precision", attrs)
self.assertEqual(attrs["softmax_precision"].value, precision_enum)
onnx_testing.assert_onnx_program(onnx_program, atol=2e-3, rtol=6e-3)
def test_attention_export_gqa(self):
"""Test export and verify output tensor shapes."""
batch_size, q_seq_len, kv_seq_len = 2, 4, 6
q_num_heads, kv_num_heads = 8, 4 # GQA
head_size = 64
Q = torch.rand(batch_size, q_num_heads, q_seq_len, head_size)
K = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
V = torch.rand(batch_size, kv_num_heads, kv_seq_len, head_size)
class AttentionOutputsModel(torch.nn.Module):
def forward(self, Q, K, V):
result, _, _, _ = torch.onnx.ops.attention(Q, K, V)
return result
model = AttentionOutputsModel()
onnx_program = self.export(model, (Q, K, V), opset_version=23)
node = onnx_program.model.graph.node(0)
self.assertEqual(node.op_type, "Attention")
# Verify all 4 outputs have correct shapes
outputs = node.outputs
self.assertEqual(len(outputs), 4)
# output: (batch_size, q_num_heads, q_seq_len, head_size)
self.assertEqual(
outputs[0].shape, [batch_size, q_num_heads, q_seq_len, head_size]
)
onnx_testing.assert_onnx_program(onnx_program)
if __name__ == "__main__":
common_utils.run_tests()