[ONNX] Fix conversion of attention - 4D (#157130)
Fixes a wrong conversion to ONNX found while investigating #149662.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157130
Approved by: https://github.com/gramalingam, https://github.com/justinchuby, https://github.com/titaiwangms
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
This commit is contained in:
parent
d5d14ee823
commit
0105cd89ab
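In practical terms, the change affects how the dynamo-based exporter lowers 4D scaled_dot_product_attention at opset 23: the resulting Attention node no longer carries q_num_heads/kv_num_heads attributes, since the head count is implied by the 4D input shape. Below is a minimal sketch of the user-facing flow; the module, shapes, and flags are illustrative and assume a PyTorch build where torch.onnx.export(..., dynamo=True, opset_version=23) is available.

# Sketch only: export a 4D SDPA module at opset 23 and inspect the graph.
# The module and shapes are illustrative, not taken from the commit.
import torch


class SDPA(torch.nn.Module):
    def forward(self, query, key, value):
        # 4D layout: [batch, num_heads, seq_len, head_dim]
        return torch.nn.functional.scaled_dot_product_attention(query, key, value)


query = key = value = torch.rand(2, 4, 16, 8)
onnx_program = torch.onnx.export(
    SDPA(), (query, key, value), dynamo=True, opset_version=23
)
# After this fix the Attention node should not carry q_num_heads/kv_num_heads
# for 4D inputs; the head count comes from the input shapes.
print([node.op_type for node in onnx_program.model.graph])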
@@ -5,9 +5,13 @@ from __future__ import annotations
 import logging
 
+import onnx.reference as onnx_ref
+import onnxruntime
 import pytest
 import transformers
 from onnxscript import ir
+from packaging import version
 
 import torch
 from torch.onnx._internal.exporter import _testing as onnx_testing
@@ -15,6 +19,10 @@ from torch.testing._internal import common_utils
 from torch.utils import _pytree as torch_pytree
 
 
+def has_onnxruntime_opset_23() -> bool:
+    return version.parse(onnxruntime.__version__) >= version.parse("1.22")
+
+
 class _WithExport:
     def export(self, model, args=(), kwargs=None, **options) -> torch.onnx.ONNXProgram:
         onnx_program = torch.onnx.export(
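The helper added above drives the conditional verification in the tests below: opset 23 models can only be executed natively by a sufficiently new onnxruntime, so on older runtimes the tests skip the ORT check or use the ONNX reference evaluator instead. A minimal, runnable illustration of the gate (not part of the commit):

# Minimal illustration of the version gate used by the tests (sketch only).
import onnxruntime
from packaging import version


def has_onnxruntime_opset_23() -> bool:
    # The tests treat onnxruntime 1.22 as the first release able to run opset 23 models.
    return version.parse(onnxruntime.__version__) >= version.parse("1.22")


if has_onnxruntime_opset_23():
    print(f"onnxruntime {onnxruntime.__version__}: verify the model with ORT")
else:
    print(f"onnxruntime {onnxruntime.__version__}: use onnx.reference instead")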
@@ -736,11 +744,17 @@ class DynamoExporterNewOpsetsTest(common_utils.TestCase, _WithExport):
         query = torch.rand(32, 8, 128, 64, dtype=torch.float16)
         key = torch.rand(32, 8, 128, 64, dtype=torch.float16)
         value = torch.rand(32, 8, 128, 64, dtype=torch.float16)
+        expected = Model()(query, key, value)
 
         onnx_program = self.export(Model(), (query, key, value), opset_version=23)
         self.assertIn("Attention", [node.op_type for node in onnx_program.model.graph])
 
-    @pytest.mark.xfail(reason="Expected to fail until opset 23 is supported by ORT.")
+        ref = onnx_ref.ReferenceEvaluator(onnx_program.model_proto)
+        got = ref.run(
+            None, dict(query=query.numpy(), key=key.numpy(), value=value.numpy())
+        )[0]
+        torch.testing.assert_close(torch.from_numpy(got), expected, atol=1e-2, rtol=1)
+
     def test_graph_accuracy_attention_opset_23(self):
         class Model(torch.nn.Module):
             def forward(self, query, key, value):
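The test above verifies the exported graph with the ONNX reference evaluator, which works regardless of the installed onnxruntime version. When onnxruntime is new enough, the same numerical check could be run against a real ORT session. The sketch below is illustrative only: it reuses onnx_program, query, key, value, and expected from the test above and assumes onnxruntime >= 1.22 with a CPU build that can execute the fp16 graph.

# Sketch: the same numerical check, but through an onnxruntime session.
import onnxruntime
import torch

sess = onnxruntime.InferenceSession(
    onnx_program.model_proto.SerializeToString(),
    providers=["CPUExecutionProvider"],
)
(got,) = sess.run(
    None, {"query": query.numpy(), "key": key.numpy(), "value": value.numpy()}
)
torch.testing.assert_close(torch.from_numpy(got), expected, atol=1e-2, rtol=1)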
@@ -752,8 +766,13 @@ class DynamoExporterNewOpsetsTest(common_utils.TestCase, _WithExport):
         key = torch.rand(32, 8, 128, 64, dtype=torch.float16)
         value = torch.rand(32, 8, 128, 64, dtype=torch.float16)
 
-        onnx_program = self.export(Model(), (query, key, value), opset_version=23)
-        onnx_testing.assert_onnx_program(onnx_program, atol=1e-3, rtol=1)
+        onnx_program = self.export(
+            Model(), (query, key, value), opset_version=23, optimize=True
+        )
+        self.assertEqual(["Attention"], [n.op_type for n in onnx_program.model.graph])
+        # onnxruntime inlines any op defined as a function and without any implemented kernel
+        if has_onnxruntime_opset_23():
+            onnx_testing.assert_onnx_program(onnx_program, atol=1e-2, rtol=1)
 
 
 if __name__ == "__main__":
@@ -119,19 +119,19 @@ def aten_scaled_dot_product_attention_23(
             "SDPA (MHA) requires q_num_heads = kv_num_heads"
         )
 
-    # NOTE: There was extended discussion on whether the num_heads attributes (q_num_heads/kv_num_heads)
-    # should be set as ONNX attributes or inferred from the tensor shape. In ONNX, num_heads is needed
-    # for 3D attention inputs (shape: [B, S, N*H]), but not for 4D ([B, N, S, H]), which is the only
-    # input accepted by this exporter. Thus, the attribute is not strictly necessary here, but adding it
-    # may ease future optimization or conversion to 3D formats (e.g., GQA ops)
+    # NOTE: The num_heads attributes (q_num_heads/kv_num_heads) should not be specified for 4D inputs.
+    # They are not populated here because this information comes directly from the input shapes:
+    # `q_num_heads=query.shape[1]` and `kv_num_heads=key.shape[1]`.
+    # This dimension is usually static, but it could not be dynamic if it were also given as an attribute.
+    # The num_heads attributes are needed for 3D attention inputs (shape: [B, S, N*H]);
+    # the 4D shape is [B, N, S, H].
     Y, _, _, _ = op23.Attention(
         query,
         key,
         value,
         attn_mask=attn_mask,
         scale=scale,
-        q_num_heads=query.shape[-3],
-        kv_num_heads=key.shape[-3],
         is_causal=is_causal,
     )
     return Y
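The shape bookkeeping behind that NOTE can be made concrete: for 4D inputs [B, N, S, H] the head count is simply dimension 1, whereas a 3D layout [B, S, N*H] folds it into the last dimension and can only recover it if num_heads is supplied separately. A small, self-contained sketch with illustrative shapes (not from the commit):

# Sketch: why num_heads is implicit for 4D attention inputs but not for 3D.
import torch

B, N, S, H = 2, 8, 16, 64              # batch, heads, sequence, head size

q_4d = torch.rand(B, N, S, H)          # 4D layout [B, N, S, H]
num_heads_from_shape = q_4d.shape[1]   # what the exporter reads instead of an attribute
assert num_heads_from_shape == N

# Equivalent 3D layout [B, S, N*H]: the head count is folded into the last dim,
# so an ONNX Attention node needs q_num_heads/kv_num_heads as attributes there.
q_3d = q_4d.transpose(1, 2).reshape(B, S, N * H)
assert q_3d.shape == (B, S, N * H)

# Recovering the 4D view requires knowing N explicitly:
q_back = q_3d.reshape(B, S, N, H).transpose(1, 2)
assert torch.equal(q_back, q_4d)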
||||||
Loading…
Reference in New Issue
Block a user