[cuDNN] Allow cudnn attention or flash attention in test_export.py regex (#154458)
Analogous to #153272

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154458
Approved by: https://github.com/drisspg
This commit is contained in:
parent dc0f09a478
commit 818f76a745
@@ -62,7 +62,6 @@ from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    SM90OrLater,
    xfailIfDistributedNotSupported,
)
from torch.testing._internal.common_utils import (
@@ -14406,17 +14405,21 @@ def forward(self, q, k, v):
    _scaled_dot_product_flash_attention = torch.ops.aten._scaled_dot_product_flash_attention.default(q, k, v, 0.0, True, scale = 0.125); q = k = v = None
    getitem = _scaled_dot_product_flash_attention[0]; _scaled_dot_product_flash_attention = None
    return (getitem,)"""
        # TODO(eqy): this needs to stay in sync with default SDPA priority order
        if (False and SM90OrLater) and not torch.version.hip:
            try:
                self.assertExpectedInline(
                    ep.graph_module.code.strip(),
                    code_str,
                )
            except AssertionError:
                code_str = """\
def forward(self, q, k, v):
    _scaled_dot_product_cudnn_attention = torch.ops.aten._scaled_dot_product_cudnn_attention.default(q, k, v, None, False, 0.0, True); q = k = v = None
    getitem = _scaled_dot_product_cudnn_attention[0]; _scaled_dot_product_cudnn_attention = None
    return (getitem,)"""
                self.assertExpectedInline(
                    ep.graph_module.code.strip(),
                    code_str,
                )
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            code_str,
        )

    def test_int_list_output(self):
        class M(torch.nn.Module):
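The change boils down to an "accept either SDPA backend" assertion: the test first compares the exported graph against the flash-attention lowering and, on AssertionError, retries against the cuDNN lowering. Below is a minimal, self-contained sketch of that pattern; assert_sdpa_graph, FLASH_OP, and CUDNN_OP are hypothetical stand-ins, and a simple substring check replaces the full-graph comparison that self.assertExpectedInline performs in the real test.

# Minimal sketch of the "either backend passes" pattern from the diff above.
# assert_sdpa_graph, FLASH_OP, and CUDNN_OP are hypothetical names; the real
# test compares the full graph code string via self.assertExpectedInline.

FLASH_OP = "torch.ops.aten._scaled_dot_product_flash_attention.default"
CUDNN_OP = "torch.ops.aten._scaled_dot_product_cudnn_attention.default"

def assert_sdpa_graph(graph_code: str) -> None:
    """Pass if SDPA was lowered to either the flash or the cuDNN kernel."""
    try:
        # First expectation: the flash-attention lowering.
        assert FLASH_OP in graph_code, graph_code
    except AssertionError:
        # Fallback expectation: the cuDNN lowering, mirroring the
        # try/except AssertionError structure in the test.
        assert CUDNN_OP in graph_code, graph_code

# Both lowerings are accepted, so the test no longer depends on which
# backend the SDPA priority order happens to pick on a given machine.
assert_sdpa_graph("getitem = " + FLASH_OP + "(q, k, v)")
assert_sdpa_graph("getitem = " + CUDNN_OP + "(q, k, v)")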