From 516e58965aba2a98d633b93924655d5ae1184c22 Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Sat, 25 Oct 2025 15:47:32 +0000
Subject: [PATCH] Revert "Export flex attention with kwargs and DTensor (#166045)"

This reverts commit de7fdfe41ad12aec719e3662be58ce9e9bf255a8.

Reverted https://github.com/pytorch/pytorch/pull/166045 on behalf of https://github.com/malfet due to Broke distributed tests, see https://hud.pytorch.org/hud/pytorch/pytorch/b55b779ad3062b91c64753132264a015378be506/1?per_page=50&name_filter=distributed&mergeEphemeralLF=true ([comment](https://github.com/pytorch/pytorch/pull/166045#issuecomment-3446850955))
---
 .../distributed/tensor/test_dtensor_export.py | 138 +++---------------
 test/inductor/test_flex_attention.py          | 135 -----------------
 torch/nn/attention/flex_attention.py          |  54 +------
 3 files changed, 20 insertions(+), 307 deletions(-)

diff --git a/test/distributed/tensor/test_dtensor_export.py b/test/distributed/tensor/test_dtensor_export.py
index b04c9324cb9..1f25090e576 100644
--- a/test/distributed/tensor/test_dtensor_export.py
+++ b/test/distributed/tensor/test_dtensor_export.py
@@ -22,7 +22,6 @@ from torch.distributed.tensor.parallel import (
     parallelize_module,
     RowwiseParallel,
 )
-from torch.nn.attention.flex_attention import create_block_mask, flex_attention
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
@@ -83,46 +82,7 @@ class SimpleModelAnnotated(torch.nn.Module):
         return self.mlp_1(x)
 
 
-class FlexAttentionModel(torch.nn.Module):
-    def __init__(self, device):
-        super().__init__()
-        self.proj_q = torch.nn.Linear(16, 128, device=device)
-        self.proj_k = torch.nn.Linear(16, 128, device=device)
-        self.proj_v = torch.nn.Linear(16, 128, device=device)
-        self.proj_out = torch.nn.Linear(128, 16, device=device)
-        self.num_heads = 8
-        self.head_dim = 16
-
-    def forward(self, x, *, block_mask=None):
-        batch_size, seq_len, embed_dim = x.shape
-        # Project to Q, K, V
-        q = self.proj_q(x)
-        k = self.proj_k(x)
-        v = self.proj_v(x)
-        # After colwise parallel, q/k/v are sharded on the last dimension
-        # Get the actual size after sharding
-        hidden_size = q.shape[-1]
-        num_heads_local = hidden_size // self.head_dim
-        # Reshape to (batch, num_heads, seq_len, head_dim)
-        q = q.view(batch_size, seq_len, num_heads_local, self.head_dim).transpose(1, 2)
-        k = k.view(batch_size, seq_len, num_heads_local, self.head_dim).transpose(1, 2)
-        v = v.view(batch_size, seq_len, num_heads_local, self.head_dim).transpose(1, 2)
-        # Apply flex_attention
-        attn_output_raw = flex_attention(q, k, v, block_mask=block_mask)
-        # Reshape back to (batch, seq_len, hidden_size)
-        attn_output = (
-            attn_output_raw.transpose(1, 2)
-            .contiguous()
-            .view(batch_size, seq_len, hidden_size)
-        )
-        # Output projection
-        output = self.proj_out(attn_output)
-        return output
-
-
-def strict_export_and_aot_export_joint_with_descriptors(model, args, kwargs=None):
-    if kwargs is None:
-        kwargs = {}
+def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
     # needed for stric export
     torch.utils._pytree.register_constant(DTensorSpec)
 
@@ -131,43 +91,36 @@ def strict_export_and_aot_export_joint_with_descriptors(model, args, kwargs=None
         install_free_tensors=True, inline_inbuilt_nn_modules=True
     ):
         with torch._export.utils._disable_aten_to_metadata_assertions():
-            ep = torch.export.export(model, args, kwargs, strict=True)
+            ep = torch.export.export(model, (inputs,), strict=True)
 
     # joint_gm produced here is missing the backward region, due to incompatiblility
     # between ep.module() and aot_export_joint_with_descriptors.
     # Keeping this here to show the issue.
-    return aot_export_joint_with_descriptors_alone(ep.module(), args, kwargs)
+    return aot_export_joint_with_descriptors_alone(ep.module(), inputs)
 
 
-def graph_capture_and_aot_export_joint_with_descriptors_v2(model, args, kwargs=None):
-    if kwargs is None:
-        kwargs = {}
-    gm = dynamo_graph_capture_for_export(model)(*args, **kwargs)
+def graph_capture_and_aot_export_joint_with_descriptors_v2(model, inputs):
+    gm = dynamo_graph_capture_for_export(model)(inputs)
     fake_mode = gm.meta.get("fake_mode", None)
     with tracing(TracingContext(fake_mode)):
-        return aot_export_joint_with_descriptors_alone(gm, args, kwargs)
+        return aot_export_joint_with_descriptors_alone(gm, inputs)
 
 
-def graph_capture_and_aot_export_joint_with_descriptors(model, args, kwargs=None):
-    if kwargs is None:
-        kwargs = {}
+def graph_capture_and_aot_export_joint_with_descriptors(model, inputs):
     with torch._dynamo.config.patch(install_free_tensors=True):
         # TODO: switch to use the official graph_capture API once it is ready
-        gm = _dynamo_graph_capture_for_export(model)(*args, **kwargs)
+        gm = _dynamo_graph_capture_for_export(model)(inputs)
     fake_mode = gm.meta.get("fake_mode", None)
     with tracing(TracingContext(fake_mode)):
-        return aot_export_joint_with_descriptors_alone(gm, args, kwargs)
+        return aot_export_joint_with_descriptors_alone(gm, inputs)
 
 
-def aot_export_joint_with_descriptors_alone(model, args, kwargs=None):
-    if kwargs is None:
-        kwargs = {}
+def aot_export_joint_with_descriptors_alone(model, inputs):
     with contextlib.ExitStack() as stack:
         joint_with_descriptors = aot_export_joint_with_descriptors(
             stack,
             model,
-            args,
-            kwargs,
+            (inputs,),
         )
         return joint_with_descriptors.graph_module
 
@@ -215,8 +168,8 @@ class DTensorExportTest(TestCase):
         }
         tp_model = parallelize_module(model, mesh_2d["tp"], parallelize_plan)
 
-        inp = torch.rand(20, 10, device=self.device_type)
-        inputs = (distribute_tensor(inp, mesh_2d["tp"], placements=[Replicate()]),)
+        inputs = torch.rand(20, 10, device=self.device_type)
+        inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])
         joint_gm = export_fn(tp_model, inputs)
 
         fw_gm, bw_gm = min_cut_rematerialization_partition(
@@ -399,10 +352,9 @@ class DTensorExportTest(TestCase):
         }
         tp_model = parallelize_module(model, mesh_2d["tp"], parallelize_plan)
 
-        inp = torch.rand(20, 10, device=self.device_type)
-        inp_dtensor = distribute_tensor(inp, mesh_2d["tp"], placements=[Replicate()])
-        torch._dynamo.mark_dynamic(inp_dtensor, 0, min=5, max=100)
-        inputs = (inp_dtensor,)
+        inputs = torch.rand(20, 10, device=self.device_type)
+        inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])
+        torch._dynamo.mark_dynamic(inputs, 0, min=5, max=100)
 
         joint_gm = export_fn(tp_model, inputs)
 
@@ -438,67 +390,15 @@ class DTensorExportTest(TestCase):
         z = torch.randn(16, 16)
         y_dtensor = distribute_tensor(y, device_mesh, placements=[Replicate()])
         z_dtensor = DTensor.from_local(z, device_mesh, placements=[Partial()])
-        inputs = (x_dtensor, y_dtensor, z_dtensor)
 
         # Run model to verify it works
-        output = model(*inputs)
+        output = model(x_dtensor, y_dtensor, z_dtensor)
 
         with torch._dynamo.config.patch(install_free_tensors=True):
             # TODO: switch to use the official graph_capture API once it is ready
-            gm = export_fn(model)(*inputs)
-            output_gm = gm(*inputs)
+            gm = export_fn(model)(x_dtensor, y_dtensor, z_dtensor)
+            output_gm = gm(x_dtensor, y_dtensor, z_dtensor)
         self.assertEqual(output, output_gm)
 
-    def test_flex_attention_dtensor_export(self):
-        device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
-        model = FlexAttentionModel(self.device_type)
-
-        # Parallelize the model: shard on head dimension
-        # proj_q, proj_k, proj_v are colwise parallel (output is sharded on head dimension)
-        # proj_out is rowwise parallel (input is sharded, output needs reduction)
-        parallelize_plan = {
-            "proj_q": ColwiseParallel(),
-            "proj_k": ColwiseParallel(),
-            "proj_v": ColwiseParallel(),
-            "proj_out": RowwiseParallel(),
-        }
-        tp_model = parallelize_module(model, device_mesh, parallelize_plan)
-        batch_size = 4
-        seq_len = 64
-        embed_dim = 16
-        num_heads = 8
-
-        # Input tensor replicated across all devices
-        inp = torch.randn(batch_size, seq_len, embed_dim, device=self.device_type)
-        inputs = (distribute_tensor(inp, device_mesh, placements=[Replicate()]),)
-
-        def causal_mask(b, h, q_idx, kv_idx):
-            return q_idx >= kv_idx
-
-        block_mask = create_block_mask(
-            causal_mask,
-            batch_size,
-            num_heads,
-            seq_len,
-            seq_len,
-            device=self.device_type,
-        )
-
-        flex_kwargs = {"block_mask": block_mask}
-
-        joint_gm = graph_capture_and_aot_export_joint_with_descriptors(
-            tp_model, inputs, flex_kwargs
-        )
-
-        self.assertTrue(
-            _count_op(joint_gm, torch.ops.higher_order.flex_attention),
-            1,
-        )
-
-        self.assertTrue(
-            _count_op(joint_gm, torch.ops.higher_order.flex_attention_backward),
-            2,
-        )
-
 
 instantiate_parametrized_tests(DTensorExportTest)
diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index f4b65356d3d..529bbaf8267 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -5734,141 +5734,6 @@ BlockMask(shape=(1,\s1,\s2048,\s2048),\ssparsity=46.88%,\s
         self.assertEqual(flex_output, sdpa_output, atol=1e-3, rtol=1e-3)
 
-    @supported_platform
-    def test_pytree_flatten_unflatten(self, device):
-        """Test that BlockMask can be correctly flattened and unflattened using class methods."""
-
-        def causal_mask(b, h, q_idx, kv_idx):
-            return q_idx >= kv_idx
-
-        # Create a BlockMask with various attributes set
-        block_mask = create_block_mask(
-            causal_mask, B=2, H=4, Q_LEN=512, KV_LEN=512, device=device
-        )
-
-        # Flatten and unflatten using class methods
-        tensors, context = block_mask._flatten()
-        reconstructed_mask = BlockMask._unflatten(tensors, context)
-
-        # Verify the reconstructed mask has the same attributes
-        self.assertEqual(reconstructed_mask.shape, block_mask.shape)
-        self.assertEqual(reconstructed_mask.sparsity(), block_mask.sparsity())
-
-        # Verify all tensor attributes are equal (using _TENSOR_ATTRS)
-        for attr_name in BlockMask._TENSOR_ATTRS:
-            original_value = getattr(block_mask, attr_name)
-            reconstructed_value = getattr(reconstructed_mask, attr_name)
-
-            if original_value is None:
-                self.assertIsNone(
-                    reconstructed_value,
-                    f"Tensor attribute {attr_name} should be None but got {reconstructed_value}",
-                )
-            else:
-                self.assertIsInstance(
-                    original_value,
-                    torch.Tensor,
-                    f"Expected {attr_name} to be a Tensor",
-                )
-                self.assertTrue(
-                    torch.equal(original_value, reconstructed_value),
-                    f"Tensor attribute {attr_name} not equal after reconstruction",
-                )
-
-        # Verify all context attributes are equal (using _CONTEXT_ATTRS)
-        for attr_name in BlockMask._CONTEXT_ATTRS:
-            original_value = getattr(block_mask, attr_name)
-            reconstructed_value = getattr(reconstructed_mask, attr_name)
-
-            self.assertEqual(
-                original_value,
-                reconstructed_value,
-                f"Context attribute {attr_name} not equal after reconstruction",
-            )
-
-    @supported_platform
-    def test_pytree_flatten_with_keys(self, device):
-        """Test that BlockMask._flatten_with_keys works correctly for tracing."""
-
-        def causal_mask(b, h, q_idx, kv_idx):
-            return q_idx >= kv_idx
-
-        block_mask = create_block_mask(
-            causal_mask, B=2, H=4, Q_LEN=512, KV_LEN=512, device=device
-        )
-
-        tensors_with_keys, context_with_keys = block_mask._flatten_with_keys()
-
-        self.assertEqual(len(tensors_with_keys), len(BlockMask._TENSOR_ATTRS))
-        self.assertEqual(len(context_with_keys), len(BlockMask._CONTEXT_ATTRS))
-
-        from torch.utils._pytree import GetAttrKey
-
-        for key, tensor in tensors_with_keys:
-            self.assertIsInstance(key, GetAttrKey)
-            self.assertIsNotNone(key)
-
-        for key, value in context_with_keys:
-            self.assertIsInstance(key, GetAttrKey)
-            self.assertIsNotNone(key)
-
-    @supported_platform
-    def test_pytree_preserves_new_attributes(self, device):
-        """
-        Test that BlockMask._TENSOR_ATTRS and _CONTEXT_ATTRS are correctly defined
-        and that flatten/unflatten preserves all attributes in these lists.
-
-        """
-
-        def causal_mask(b, h, q_idx, kv_idx):
-            return q_idx >= kv_idx
-
-        block_mask = create_block_mask(
-            causal_mask, B=2, H=4, Q_LEN=512, KV_LEN=512, device=device
-        )
-
-        # Flatten and unflatten using class methods
-        tensors, context = block_mask._flatten()
-        reconstructed_mask = BlockMask._unflatten(tensors, context)
-
-        # Verify the number of tensors and context values matches the attribute lists
-        self.assertEqual(
-            len(tensors),
-            len(BlockMask._TENSOR_ATTRS),
-            "Number of tensors should match _TENSOR_ATTRS length",
-        )
-        self.assertEqual(
-            len(context),
-            len(BlockMask._CONTEXT_ATTRS),
-            "Number of context values should match _CONTEXT_ATTRS length",
-        )
-
-        # Verify all attributes from the lists exist and are equal after reconstruction
-        for attr_name in BlockMask._TENSOR_ATTRS + BlockMask._CONTEXT_ATTRS:
-            self.assertTrue(
-                hasattr(reconstructed_mask, attr_name),
-                f"Reconstructed mask missing attribute: {attr_name}",
-            )
-            original_value = getattr(block_mask, attr_name)
-            reconstructed_value = getattr(reconstructed_mask, attr_name)
-
-            if isinstance(original_value, torch.Tensor):
-                self.assertTrue(
-                    torch.equal(original_value, reconstructed_value),
-                    f"Tensor attribute {attr_name} not equal after reconstruction",
-                )
-            elif original_value is None:
-                self.assertIsNone(
-                    reconstructed_value,
-                    f"Attribute {attr_name} should be None but got {reconstructed_value}",
-                )
-            else:
-                self.assertEqual(
-                    original_value,
-                    reconstructed_value,
-                    f"Attribute {attr_name} not equal after reconstruction",
-                )
-
 
 @large_tensor_test_class("2GB", device=test_device[0])
 class TestPagedAttention(InductorTestCase):
diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py
index b886fc4072f..b68b010ef43 100644
--- a/torch/nn/attention/flex_attention.py
+++ b/torch/nn/attention/flex_attention.py
@@ -34,7 +34,7 @@ from torch.fx.experimental.proxy_tensor import (
     _temp_remove_pre_dispatch_torch_function_mode,
 )
 from torch.nn.attention._utils import _validate_sdpa_input
-from torch.utils._pytree import GetAttrKey, register_pytree_node, tree_map_only
+from torch.utils._pytree import tree_map_only
 
 
 # Private debug flag to disable internal compilation wrapping for debugging purposes.
@@ -519,24 +519,6 @@ class BlockMask:
     BLOCK_SIZE: tuple[int, int]
     mask_mod: _mask_mod_signature
 
-    # Attribute lists for pytree flatten/unflatten
-    _TENSOR_ATTRS = [
-        "kv_num_blocks",
-        "kv_indices",
-        "full_kv_num_blocks",
-        "full_kv_indices",
-        "q_num_blocks",
-        "q_indices",
-        "full_q_num_blocks",
-        "full_q_indices",
-    ]
-
-    _CONTEXT_ATTRS = [
-        "seq_lengths",
-        "BLOCK_SIZE",
-        "mask_mod",
-    ]
-
     def __init__(
         self,
         seq_lengths: tuple[int, int],
@@ -931,31 +913,6 @@ class BlockMask:
         )
         return BlockMask(*mapped_attributes)
 
-    def _flatten(self):
-        """Flatten BlockMask into a list of tensors and context."""
-        tensors = tuple(getattr(self, attr) for attr in self._TENSOR_ATTRS)
-        context = tuple(getattr(self, attr) for attr in self._CONTEXT_ATTRS)
-        return tensors, context
-
-    @classmethod
-    def _unflatten(cls, tensors, context):
-        """Unflatten tensors and context back into a BlockMask."""
-        kwargs = {
-            **dict(zip(cls._CONTEXT_ATTRS, context)),
-            **dict(zip(cls._TENSOR_ATTRS, tensors)),
-        }
-        return cls(**kwargs)
-
-    def _flatten_with_keys(self):
-        """Flatten BlockMask with keys for better tracing."""
-        tensors = tuple(
-            (GetAttrKey(attr), getattr(self, attr)) for attr in self._TENSOR_ATTRS
-        )
-        context = tuple(
-            (GetAttrKey(attr), getattr(self, attr)) for attr in self._CONTEXT_ATTRS
-        )
-        return tensors, context
-
 
 def _broadcast_to_dim(x, dim):
     while x.dim() < dim:
@@ -1647,12 +1604,3 @@ def flex_attention(
     return _finalize_outputs(
         out, lse, max_scores, return_aux=return_aux, return_lse=return_lse
     )
-
-
-register_pytree_node(
-    BlockMask,
-    BlockMask._flatten,
-    BlockMask._unflatten,
-    flatten_with_keys_fn=BlockMask._flatten_with_keys,
-    serialized_type_name="torch.nn.attention.flex_attention.BlockMask",
-)