Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Mark FlexAttentionBackward as cacheable (#165996)
This probably should have been marked cacheable a long time ago; there is no reason for it not to be.

Test Plan: the new regional inductor tests for test_flex_attention are now serializable.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165996
Approved by: https://github.com/oulgen, https://github.com/zou3519, https://github.com/drisspg
This commit is contained in:
parent a60d9e1f6d
commit e4c01011c2
@@ -245,8 +245,7 @@ due to:
 Traceback (most recent call last):
   File "test_logging.py", line N, in throw
     raise AssertionError
-torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
-LoweringException: AssertionError:
+torch._inductor.exc.InductorError: LoweringException: AssertionError:
 target: aten.round.default
 args[0]: TensorBox(StorageBox(
   InputBuffer(name='primals_1', layout=FixedLayout('cpu', torch.float32, size=[1000, 1000], stride=[1000, 1]))
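The hunk above updates an expected-log docstring: a failing Inductor lowering now surfaces to the user as torch._inductor.exc.InductorError, whose message embeds the LoweringException, rather than torch._dynamo.exc.BackendCompilerFailed. A minimal sketch of the catch site this implies, using only the public torch.compile entry point; the lowering failure itself is injected by the test harness, so this snippet normally completes without raising:

import torch
import torch._inductor.exc  # for the InductorError type referenced in the new output


@torch.compile(backend="inductor")
def round_it(x):
    return torch.round(x)  # the op whose lowering the logging test forces to fail


try:
    round_it(torch.randn(1000, 1000))
except torch._inductor.exc.InductorError as err:
    # Only reached when a lowering really fails (the test patches one to raise).
    print(f"inductor lowering failed: {err}")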
@@ -184,12 +184,6 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
     @requires_cuda_and_triton
     @parametrize("serialize", [False, True])
     def test_flex_attention(self, serialize):
-        if serialize:
-            # TODO: Fixed in next PR
-            raise unittest.SkipTest(
-                "FlexAttentionBackward isn't marked cacheable even though it is"
-            )
-
         def _squared(score, b, h, m, n):
             return score * score
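For context on the test being un-skipped above: _squared is a score_mod callback. A minimal sketch of how such a callback is consumed by the public flex_attention API; the tensor shapes and the CUDA gate are illustrative, mirroring the test's @requires_cuda_and_triton decorator:

import torch
from torch.nn.attention.flex_attention import flex_attention


def _squared(score, b, h, m, n):
    # Applied elementwise to each attention score before softmax.
    return score * score


if torch.cuda.is_available():
    # (batch, heads, seq_len, head_dim); sizes chosen only for illustration.
    q, k, v = (torch.randn(2, 4, 128, 64, device="cuda") for _ in range(3))
    out = flex_attention(q, k, v, score_mod=_squared)
    print(out.shape)  # torch.Size([2, 4, 128, 64])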
@@ -229,11 +223,6 @@ class RegionalInductorTests(torch._inductor.test_case.TestCase):
     @requires_cuda_and_triton
     @parametrize("serialize", [False, True])
     def test_selective_ac_flex(self, serialize):
-        if serialize:
-            raise unittest.SkipTest(
-                "FlexAttentionBackward isn't marked cacheable even though it is"
-            )
-
         class FlexAttentionModule(torch.nn.Module):
             def __init__(self, hidden_size, num_heads):
                 super().__init__()
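The un-skipped test above exercises flex attention under selective activation checkpointing, so the backward pass re-enters FlexAttentionBackward during recomputation. A rough sketch of that shape of module; plain torch.utils.checkpoint stands in for selective checkpointing here, and all module and parameter names are illustrative rather than the test's actual code:

import torch
from torch.nn.attention.flex_attention import flex_attention
from torch.utils.checkpoint import checkpoint


class FlexAttentionModule(torch.nn.Module):
    def __init__(self, hidden_size: int, num_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = torch.nn.Linear(hidden_size, 3 * hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (
            t.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
            for t in (q, k, v)
        )
        # Recompute the attention forward during backward instead of saving
        # intermediates; backward then goes through FlexAttentionBackward again.
        out = checkpoint(flex_attention, q, k, v, use_reentrant=False)
        return out.transpose(1, 2).reshape(b, s, -1)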
@@ -6618,7 +6618,7 @@ class TestLearnableBiases(InductorTestCase):
         )
         # Error in backwards
         with self.assertRaisesRegex(
-            torch._inductor.exc.LoweringException,
+            torch._inductor.exc.InductorError,
             "Using multiple indexing operations on the same tensor that requires gradients",
         ):
             self._check_outputs_and_grads(
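Only the expected exception type changes in the hunk above; the regex still matches because assertRaisesRegex tests str(exception), and InductorError's message embeds the wrapped error's text (see the first hunk). A tiny self-contained illustration of that property, with stand-in exception classes rather than Inductor's real ones:

import unittest


class Inner(Exception):
    pass


class Wrapper(Exception):
    # Mimics an error type whose message embeds the wrapped exception's text.
    def __init__(self, inner: Exception) -> None:
        super().__init__(f"{type(inner).__name__}: {inner}")


class WrapperMessageTest(unittest.TestCase):
    def test_regex_matches_wrapped_message(self):
        with self.assertRaisesRegex(Wrapper, "requires gradients"):
            raise Wrapper(Inner("same tensor that requires gradients"))


if __name__ == "__main__":
    unittest.main()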
@@ -112,7 +112,7 @@ flex_attention = FlexAttentionHOP()

 class FlexAttentionBackwardHOP(HigherOrderOperator):
     def __init__(self) -> None:
-        super().__init__("flex_attention_backward")
+        super().__init__("flex_attention_backward", cacheable=True)

     def __call__(
         self,
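The hunk above is the substantive change of the commit: FlexAttentionBackwardHOP now passes cacheable=True to HigherOrderOperator, letting Inductor's FX graph cache serialize graphs that contain the backward op. A minimal sketch of the same opt-in on a toy operator; the cacheable() accessor used at the end is an assumption about HigherOrderOperator's surface, not something shown in this diff:

import torch
from torch._ops import HigherOrderOperator


class MyBackwardHOP(HigherOrderOperator):
    def __init__(self) -> None:
        # cacheable=True marks graphs containing this op as safe to cache,
        # mirroring the FlexAttentionBackwardHOP change above.
        super().__init__("my_backward_hop", cacheable=True)

    def __call__(self, fn, *args, **kwargs):
        return super().__call__(fn, *args, **kwargs)


my_backward_hop = MyBackwardHOP()
print(my_backward_hop.cacheable())  # assumed accessor; expected to report True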
@@ -971,9 +971,14 @@ def _compile_fx_inner(
                 else "FX cache disabled or key generation failed"
             ),
         )
-        mb_compiled_graph = fx_codegen_and_compile(
-            gm, example_inputs, inputs_to_check, **graph_kwargs
-        )
+        try:
+            mb_compiled_graph = fx_codegen_and_compile(
+                gm, example_inputs, inputs_to_check, **graph_kwargs
+            )
+        except Exception as e:
+            raise InductorError(e, currentframe()).with_traceback(
+                e.__traceback__
+            ) from None

     # CACHE MISS: Compile the graph and save to cache
     elif cache_info["cache_state"] == "miss":
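The last hunk wraps the compile call so any failure is re-raised as InductorError while keeping the original traceback and suppressing implicit exception chaining; the real code also records the raising frame via currentframe(). A generic sketch of that re-wrapping pattern, with illustrative names rather than Inductor's:

from types import FrameType
from typing import Optional


class CompileError(Exception):
    # Stand-in for InductorError: message embeds the wrapped exception's text.
    def __init__(self, inner: BaseException, frame: Optional[FrameType] = None) -> None:
        super().__init__(f"{type(inner).__name__}: {inner}")
        self.inner = inner
        self.frame = frame


def codegen_and_compile():
    raise AssertionError("lowering failed for aten.round.default")


def compile_inner():
    try:
        return codegen_and_compile()
    except Exception as e:
        # with_traceback keeps the frames of the original failure; `from None`
        # hides the "During handling of the above exception..." chaining.
        raise CompileError(e).with_traceback(e.__traceback__) from None


if __name__ == "__main__":
    compile_inner()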