Mark FlexAttentionBackward as cacheable (#165996)

This probably should have been marked cacheable a long time ago; there's no reason it shouldn't be.

Test Plan:
The regional inductor tests for test_flex_attention are now serializable (the serialize=True variants no longer skip).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165996
Approved by: https://github.com/oulgen, https://github.com/zou3519, https://github.com/drisspg
Author: James Wu
Date: 2025-10-25 20:30:43 -07:00
Committed by: PyTorch MergeBot
Commit: e4c01011c2 (parent a60d9e1f6d)

5 changed files with 11 additions and 18 deletions


@@ -245,8 +245,7 @@ due to:
 Traceback (most recent call last):
   File "test_logging.py", line N, in throw
     raise AssertionError
-torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
-LoweringException: AssertionError:
+torch._inductor.exc.InductorError: LoweringException: AssertionError:
   target: aten.round.default
   args[0]: TensorBox(StorageBox(
     InputBuffer(name='primals_1', layout=FixedLayout('cpu', torch.float32, size=[1000, 1000], stride=[1000, 1]))


@@ -184,12 +184,6 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
     @requires_cuda_and_triton
     @parametrize("serialize", [False, True])
     def test_flex_attention(self, serialize):
-        if serialize:
-            # TODO: Fixed in next PR
-            raise unittest.SkipTest(
-                "FlexAttentionBackward isn't marked cacheable even though it is"
-            )
-
         def _squared(score, b, h, m, n):
             return score * score
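With the skip removed, the serialize=True variant runs end to end. For context, a minimal sketch of the kind of workload this test exercises (not code from this PR; shapes and the compile call are illustrative, and `_squared` is the score_mod from the test above; the test itself requires CUDA and Triton):

import torch
from torch.nn.attention.flex_attention import flex_attention

def _squared(score, b, h, m, n):
    # score_mod from the test above: square every attention score
    return score * score

# Hypothetical (B, H, S, D) shapes; the CUDA device matches the test's
# requires_cuda_and_triton constraint.
q, k, v = (torch.randn(2, 4, 128, 64, device="cuda", requires_grad=True)
           for _ in range(3))

out = torch.compile(flex_attention)(q, k, v, score_mod=_squared)
out.sum().backward()  # dispatches to FlexAttentionBackward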
@@ -229,11 +223,6 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
     @requires_cuda_and_triton
     @parametrize("serialize", [False, True])
     def test_selective_ac_flex(self, serialize):
-        if serialize:
-            raise unittest.SkipTest(
-                "FlexAttentionBackward isn't marked cacheable even though it is"
-            )
-
         class FlexAttentionModule(torch.nn.Module):
             def __init__(self, hidden_size, num_heads):
                 super().__init__()
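Likewise for selective activation checkpointing: a hedged sketch of how flex attention is typically combined with selective AC (an illustration of the pattern, not this test's exact module or policy; `CheckpointPolicy` and `create_selective_checkpoint_contexts` are the public torch.utils.checkpoint API):

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Hypothetical policy: keep flex_attention results, recompute everything else.
def policy_fn(ctx, op, *args, **kwargs):
    if "flex_attention" in str(op):
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

module = torch.nn.Linear(64, 64)  # stand-in for the FlexAttentionModule above
x = torch.randn(8, 64, requires_grad=True)

out = checkpoint(
    module, x,
    use_reentrant=False,
    context_fn=lambda: create_selective_checkpoint_contexts(policy_fn),
)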


@@ -6618,7 +6618,7 @@ class TestLearnableBiases(InductorTestCase):
         )
         # Error in backwards
         with self.assertRaisesRegex(
-            torch._inductor.exc.LoweringException,
+            torch._inductor.exc.InductorError,
             "Using multiple indexing operations on the same tensor that requires gradients",
         ):
             self._check_outputs_and_grads(


@@ -112,7 +112,7 @@ flex_attention = FlexAttentionHOP()
 class FlexAttentionBackwardHOP(HigherOrderOperator):
     def __init__(self) -> None:
-        super().__init__("flex_attention_backward")
+        super().__init__("flex_attention_backward", cacheable=True)
 
     def __call__(
         self,
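This one-argument change is the substance of the PR: the backward HOP now advertises itself as safe for the FX graph cache. A minimal check (assuming the `cacheable()` accessor on HigherOrderOperator that pairs with the constructor keyword, and the module-level `flex_attention_backward` instance name):

from torch._higher_order_ops.flex_attention import flex_attention_backward

# With this PR, the backward HOP reports itself as cacheable, so serialized
# regional-inductor graphs that contain it are no longer rejected.
assert flex_attention_backward.cacheable()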


@@ -971,9 +971,14 @@ def _compile_fx_inner(
                 else "FX cache disabled or key generation failed"
             ),
         )
-        mb_compiled_graph = fx_codegen_and_compile(
-            gm, example_inputs, inputs_to_check, **graph_kwargs
-        )
+        try:
+            mb_compiled_graph = fx_codegen_and_compile(
+                gm, example_inputs, inputs_to_check, **graph_kwargs
+            )
+        except Exception as e:
+            raise InductorError(e, currentframe()).with_traceback(
+                e.__traceback__
+            ) from None
 
     # CACHE MISS: Compile the graph and save to cache
     elif cache_info["cache_state"] == "miss":
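The try/except added here wraps any compile-time failure in InductorError while keeping the inner traceback, which is what the updated expected output in the first file reflects. A standalone sketch of the same wrap-and-reraise pattern (generic Python; WrapperError and failing_lowering are hypothetical stand-ins, not PyTorch internals):

import traceback
from inspect import currentframe

class WrapperError(Exception):  # stand-in for torch._inductor.exc.InductorError
    def __init__(self, inner, frame):
        self.inner = inner
        super().__init__(str(inner))

def failing_lowering():
    raise AssertionError("lowering failed")

try:
    try:
        failing_lowering()
    except Exception as e:
        # Same shape as the diff above: wrap the exception, graft the inner
        # traceback on, and use `from None` to suppress the implicit
        # "During handling of the above exception..." chaining noise.
        raise WrapperError(e, currentframe()).with_traceback(e.__traceback__) from None
except WrapperError as err:
    traceback.print_exception(err)  # frames point into failing_lowering()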