[cuDNN] Allow cudnn attention or flash attention in test_export.py regex (#154458)

Analogous to #153272
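
The test previously pinned the expected export graph to the flash-attention op, with a cuDNN variant sitting behind a dead if (False and SM90OrLater) gate; it now asserts the flash-attention expectation first and falls back to the cuDNN-attention expectation on AssertionError. A minimal, self-contained sketch of that fallback pattern follows (the SDPAExportCheck class, the assert_matches_either helper, and the literal strings are illustrative only, not part of this PR):

# Sketch of the "accept either SDPA backend" check; the real test compares
# ep.graph_module.code.strip() against expected graph dumps via
# assertExpectedInline.
import unittest


class SDPAExportCheck(unittest.TestCase):
    def assert_matches_either(self, actual, primary, fallback):
        # Try the preferred expectation first; if the platform selected a
        # different SDPA backend, re-check against the alternative.
        try:
            self.assertEqual(actual, primary)
        except AssertionError:
            self.assertEqual(actual, fallback)

    def test_accepts_flash_or_cudnn(self):
        actual = "aten._scaled_dot_product_cudnn_attention"
        self.assert_matches_either(
            actual,
            "aten._scaled_dot_product_flash_attention",
            "aten._scaled_dot_product_cudnn_attention",
        )


if __name__ == "__main__":
    unittest.main()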

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154458
Approved by: https://github.com/drisspg
eqy 2025-05-29 23:51:05 +00:00 committed by PyTorch MergeBot
parent dc0f09a478
commit 818f76a745

@@ -62,7 +62,6 @@ from torch.fx.experimental.symbolic_shapes import ShapeEnv
 from torch.testing import FileCheck
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
-    SM90OrLater,
     xfailIfDistributedNotSupported,
 )
 from torch.testing._internal.common_utils import (
@@ -14406,17 +14405,21 @@ def forward(self, q, k, v):
     _scaled_dot_product_flash_attention = torch.ops.aten._scaled_dot_product_flash_attention.default(q, k, v, 0.0, True, scale = 0.125); q = k = v = None
     getitem = _scaled_dot_product_flash_attention[0]; _scaled_dot_product_flash_attention = None
     return (getitem,)"""
-        # TODO(eqy): this needs to stay in sync with default SDPA priority order
-        if (False and SM90OrLater) and not torch.version.hip:
+        try:
+            self.assertExpectedInline(
+                ep.graph_module.code.strip(),
+                code_str,
+            )
+        except AssertionError:
             code_str = """\
 def forward(self, q, k, v):
     _scaled_dot_product_cudnn_attention = torch.ops.aten._scaled_dot_product_cudnn_attention.default(q, k, v, None, False, 0.0, True); q = k = v = None
     getitem = _scaled_dot_product_cudnn_attention[0]; _scaled_dot_product_cudnn_attention = None
     return (getitem,)"""
-        self.assertExpectedInline(
-            ep.graph_module.code.strip(),
-            code_str,
-        )
+            self.assertExpectedInline(
+                ep.graph_module.code.strip(),
+                code_str,
+            )
 
     def test_int_list_output(self):
         class M(torch.nn.Module):