Revert "Export flex attention with kwargs and DTensor (#166045)"
This reverts commit de7fdfe41a. Reverted https://github.com/pytorch/pytorch/pull/166045 on behalf of https://github.com/malfet due to Broke distributed tests, see b55b779ad3/1 ([comment](https://github.com/pytorch/pytorch/pull/166045#issuecomment-3446850955))
This commit is contained in:
parent b55b779ad3
commit 516e58965a
@@ -22,7 +22,6 @@ from torch.distributed.tensor.parallel import (
     parallelize_module,
     RowwiseParallel,
 )
-from torch.nn.attention.flex_attention import create_block_mask, flex_attention
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
@@ -83,46 +82,7 @@ class SimpleModelAnnotated(torch.nn.Module):
         return self.mlp_1(x)
 
 
-class FlexAttentionModel(torch.nn.Module):
-    def __init__(self, device):
-        super().__init__()
-        self.proj_q = torch.nn.Linear(16, 128, device=device)
-        self.proj_k = torch.nn.Linear(16, 128, device=device)
-        self.proj_v = torch.nn.Linear(16, 128, device=device)
-        self.proj_out = torch.nn.Linear(128, 16, device=device)
-        self.num_heads = 8
-        self.head_dim = 16
-
-    def forward(self, x, *, block_mask=None):
-        batch_size, seq_len, embed_dim = x.shape
-        # Project to Q, K, V
-        q = self.proj_q(x)
-        k = self.proj_k(x)
-        v = self.proj_v(x)
-        # After colwise parallel, q/k/v are sharded on the last dimension
-        # Get the actual size after sharding
-        hidden_size = q.shape[-1]
-        num_heads_local = hidden_size // self.head_dim
-        # Reshape to (batch, num_heads, seq_len, head_dim)
-        q = q.view(batch_size, seq_len, num_heads_local, self.head_dim).transpose(1, 2)
-        k = k.view(batch_size, seq_len, num_heads_local, self.head_dim).transpose(1, 2)
-        v = v.view(batch_size, seq_len, num_heads_local, self.head_dim).transpose(1, 2)
-        # Apply flex_attention
-        attn_output_raw = flex_attention(q, k, v, block_mask=block_mask)
-        # Reshape back to (batch, seq_len, hidden_size)
-        attn_output = (
-            attn_output_raw.transpose(1, 2)
-            .contiguous()
-            .view(batch_size, seq_len, hidden_size)
-        )
-        # Output projection
-        output = self.proj_out(attn_output)
-        return output
-
-
-def strict_export_and_aot_export_joint_with_descriptors(model, args, kwargs=None):
-    if kwargs is None:
-        kwargs = {}
+def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
     # needed for stric export
     torch.utils._pytree.register_constant(DTensorSpec)
 
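The removed `FlexAttentionModel` routes its mask through the keyword-only `block_mask` argument of `flex_attention`, which is the kwargs path the reverted PR was exercising. For orientation only, a minimal non-distributed sketch of that call pattern (toy shapes, single device; not code from the PR):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

device = "cuda" if torch.cuda.is_available() else "cpu"
# Toy sizes: batch 2, 8 heads, sequence length 128, head dim 16.
q = torch.randn(2, 8, 128, 16, device=device)
k = torch.randn(2, 8, 128, 16, device=device)
v = torch.randn(2, 8, 128, 16, device=device)

block_mask = create_block_mask(causal_mask, B=2, H=8, Q_LEN=128, KV_LEN=128, device=device)
out = flex_attention(q, k, v, block_mask=block_mask)  # block_mask passed as a keyword
print(out.shape)  # torch.Size([2, 8, 128, 16])
```

Without `torch.compile` this falls back to a slow reference implementation, which is enough for illustration.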
@@ -131,43 +91,36 @@ def strict_export_and_aot_export_joint_with_descriptors(model, args, kwargs=None
         install_free_tensors=True, inline_inbuilt_nn_modules=True
     ):
         with torch._export.utils._disable_aten_to_metadata_assertions():
-            ep = torch.export.export(model, args, kwargs, strict=True)
+            ep = torch.export.export(model, (inputs,), strict=True)
 
     # joint_gm produced here is missing the backward region, due to incompatiblility
     # between ep.module() and aot_export_joint_with_descriptors.
     # Keeping this here to show the issue.
-    return aot_export_joint_with_descriptors_alone(ep.module(), args, kwargs)
+    return aot_export_joint_with_descriptors_alone(ep.module(), inputs)
 
 
-def graph_capture_and_aot_export_joint_with_descriptors_v2(model, args, kwargs=None):
-    if kwargs is None:
-        kwargs = {}
-    gm = dynamo_graph_capture_for_export(model)(*args, **kwargs)
+def graph_capture_and_aot_export_joint_with_descriptors_v2(model, inputs):
+    gm = dynamo_graph_capture_for_export(model)(inputs)
     fake_mode = gm.meta.get("fake_mode", None)
     with tracing(TracingContext(fake_mode)):
-        return aot_export_joint_with_descriptors_alone(gm, args, kwargs)
+        return aot_export_joint_with_descriptors_alone(gm, inputs)
 
 
-def graph_capture_and_aot_export_joint_with_descriptors(model, args, kwargs=None):
-    if kwargs is None:
-        kwargs = {}
+def graph_capture_and_aot_export_joint_with_descriptors(model, inputs):
     with torch._dynamo.config.patch(install_free_tensors=True):
         # TODO: switch to use the official graph_capture API once it is ready
-        gm = _dynamo_graph_capture_for_export(model)(*args, **kwargs)
+        gm = _dynamo_graph_capture_for_export(model)(inputs)
     fake_mode = gm.meta.get("fake_mode", None)
     with tracing(TracingContext(fake_mode)):
-        return aot_export_joint_with_descriptors_alone(gm, args, kwargs)
+        return aot_export_joint_with_descriptors_alone(gm, inputs)
 
 
-def aot_export_joint_with_descriptors_alone(model, args, kwargs=None):
-    if kwargs is None:
-        kwargs = {}
+def aot_export_joint_with_descriptors_alone(model, inputs):
     with contextlib.ExitStack() as stack:
         joint_with_descriptors = aot_export_joint_with_descriptors(
             stack,
             model,
-            args,
-            kwargs,
+            (inputs,),
        )
         return joint_with_descriptors.graph_module
 
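The helpers above are switched back from an `(args, kwargs)` calling convention to a single `inputs` tensor wrapped into a one-element tuple. For background, `torch.export.export` itself accepts positional arguments as a tuple and keyword arguments as a dict; a standalone sketch with a hypothetical `ToyModel` (not part of this test file):

```python
import torch

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(16, 16)

    def forward(self, x, *, scale):
        # Keyword-only argument, loosely analogous to block_mask in the reverted test.
        return self.lin(x) * scale

model = ToyModel()
args = (torch.randn(2, 16),)
kwargs = {"scale": torch.tensor(2.0)}

# Positional args go in as a tuple, keyword args as a dict.
ep = torch.export.export(model, args, kwargs)
print(ep.module()(*args, **kwargs).shape)  # torch.Size([2, 16])
```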
@@ -215,8 +168,8 @@ class DTensorExportTest(TestCase):
         }
         tp_model = parallelize_module(model, mesh_2d["tp"], parallelize_plan)
 
-        inp = torch.rand(20, 10, device=self.device_type)
-        inputs = (distribute_tensor(inp, mesh_2d["tp"], placements=[Replicate()]),)
+        inputs = torch.rand(20, 10, device=self.device_type)
+        inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])
 
         joint_gm = export_fn(tp_model, inputs)
         fw_gm, bw_gm = min_cut_rematerialization_partition(
@@ -399,10 +352,9 @@ class DTensorExportTest(TestCase):
         }
         tp_model = parallelize_module(model, mesh_2d["tp"], parallelize_plan)
 
-        inp = torch.rand(20, 10, device=self.device_type)
-        inp_dtensor = distribute_tensor(inp, mesh_2d["tp"], placements=[Replicate()])
-        torch._dynamo.mark_dynamic(inp_dtensor, 0, min=5, max=100)
-        inputs = (inp_dtensor,)
+        inputs = torch.rand(20, 10, device=self.device_type)
+        inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])
+        torch._dynamo.mark_dynamic(inputs, 0, min=5, max=100)
 
         joint_gm = export_fn(tp_model, inputs)
 
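The restored lines call `torch._dynamo.mark_dynamic(inputs, 0, min=5, max=100)` to declare the batch dimension dynamic within a bounded range before graph capture. A hedged, standalone illustration of the same hint using `torch.compile` on a plain tensor (toy module; the test applies it to a DTensor):

```python
import torch

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x) @ x.T

compiled = torch.compile(Tiny())

x = torch.rand(20, 10)
# Hint that dim 0 is dynamic and bounded to [5, 100].
torch._dynamo.mark_dynamic(x, 0, min=5, max=100)
print(compiled(x).shape)                   # torch.Size([20, 20])
print(compiled(torch.rand(50, 10)).shape)  # should reuse the dynamically shaped graph
```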
@@ -438,67 +390,15 @@ class DTensorExportTest(TestCase):
         z = torch.randn(16, 16)
         y_dtensor = distribute_tensor(y, device_mesh, placements=[Replicate()])
         z_dtensor = DTensor.from_local(z, device_mesh, placements=[Partial()])
-        inputs = (x_dtensor, y_dtensor, z_dtensor)
 
         # Run model to verify it works
-        output = model(*inputs)
+        output = model(x_dtensor, y_dtensor, z_dtensor)
         with torch._dynamo.config.patch(install_free_tensors=True):
             # TODO: switch to use the official graph_capture API once it is ready
-            gm = export_fn(model)(*inputs)
-        output_gm = gm(*inputs)
+            gm = export_fn(model)(x_dtensor, y_dtensor, z_dtensor)
+        output_gm = gm(x_dtensor, y_dtensor, z_dtensor)
         self.assertEqual(output, output_gm)
 
-    def test_flex_attention_dtensor_export(self):
-        device_mesh = init_device_mesh(self.device_type, mesh_shape=(self.world_size,))
-        model = FlexAttentionModel(self.device_type)
-
-        # Parallelize the model: shard on head dimension
-        # proj_q, proj_k, proj_v are colwise parallel (output is sharded on head dimension)
-        # proj_out is rowwise parallel (input is sharded, output needs reduction)
-        parallelize_plan = {
-            "proj_q": ColwiseParallel(),
-            "proj_k": ColwiseParallel(),
-            "proj_v": ColwiseParallel(),
-            "proj_out": RowwiseParallel(),
-        }
-        tp_model = parallelize_module(model, device_mesh, parallelize_plan)
-        batch_size = 4
-        seq_len = 64
-        embed_dim = 16
-        num_heads = 8
-
-        # Input tensor replicated across all devices
-        inp = torch.randn(batch_size, seq_len, embed_dim, device=self.device_type)
-        inputs = (distribute_tensor(inp, device_mesh, placements=[Replicate()]),)
-
-        def causal_mask(b, h, q_idx, kv_idx):
-            return q_idx >= kv_idx
-
-        block_mask = create_block_mask(
-            causal_mask,
-            batch_size,
-            num_heads,
-            seq_len,
-            seq_len,
-            device=self.device_type,
-        )
-
-        flex_kwargs = {"block_mask": block_mask}
-
-        joint_gm = graph_capture_and_aot_export_joint_with_descriptors(
-            tp_model, inputs, flex_kwargs
-        )
-
-        self.assertTrue(
-            _count_op(joint_gm, torch.ops.higher_order.flex_attention),
-            1,
-        )
-
-        self.assertTrue(
-            _count_op(joint_gm, torch.ops.higher_order.flex_attention_backward),
-            2,
-        )
-
 
 instantiate_parametrized_tests(DTensorExportTest)
 
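The restored test bodies build a replicated DTensor input with `distribute_tensor` on a device mesh and apply a column-wise/row-wise `parallelize_plan`, just as the removed flex-attention test did. A single-process sketch of that setup, under the assumptions of a CPU device, the Gloo backend, world size 1, and a toy `nn.Sequential` standing in for the test models:

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, distribute_tensor
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Single-process "distributed" setup so the sketch runs standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 16),
)
# Shard the first Linear column-wise and the last row-wise, as in the tests.
plan = {"0": ColwiseParallel(), "2": RowwiseParallel()}
tp_model = parallelize_module(model, mesh, plan)

# Replicated DTensor input, mirroring distribute_tensor(..., placements=[Replicate()]).
x = distribute_tensor(torch.randn(4, 16), mesh, placements=[Replicate()])
print(tp_model(x).shape)  # torch.Size([4, 16])

dist.destroy_process_group()
```

With world size 1 the sharding is trivial, but the API surface (mesh, plan, placements) is the same one the tests exercise.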
@@ -5734,141 +5734,6 @@ BlockMask(shape=(1, 1, 2048, 2048), sparsity=46.88%,
 
         self.assertEqual(flex_output, sdpa_output, atol=1e-3, rtol=1e-3)
 
-    @supported_platform
-    def test_pytree_flatten_unflatten(self, device):
-        """Test that BlockMask can be correctly flattened and unflattened using class methods."""
-
-        def causal_mask(b, h, q_idx, kv_idx):
-            return q_idx >= kv_idx
-
-        # Create a BlockMask with various attributes set
-        block_mask = create_block_mask(
-            causal_mask, B=2, H=4, Q_LEN=512, KV_LEN=512, device=device
-        )
-
-        # Flatten and unflatten using class methods
-        tensors, context = block_mask._flatten()
-        reconstructed_mask = BlockMask._unflatten(tensors, context)
-
-        # Verify the reconstructed mask has the same attributes
-        self.assertEqual(reconstructed_mask.shape, block_mask.shape)
-        self.assertEqual(reconstructed_mask.sparsity(), block_mask.sparsity())
-
-        # Verify all tensor attributes are equal (using _TENSOR_ATTRS)
-        for attr_name in BlockMask._TENSOR_ATTRS:
-            original_value = getattr(block_mask, attr_name)
-            reconstructed_value = getattr(reconstructed_mask, attr_name)
-
-            if original_value is None:
-                self.assertIsNone(
-                    reconstructed_value,
-                    f"Tensor attribute {attr_name} should be None but got {reconstructed_value}",
-                )
-            else:
-                self.assertIsInstance(
-                    original_value,
-                    torch.Tensor,
-                    f"Expected {attr_name} to be a Tensor",
-                )
-                self.assertTrue(
-                    torch.equal(original_value, reconstructed_value),
-                    f"Tensor attribute {attr_name} not equal after reconstruction",
-                )
-
-        # Verify all context attributes are equal (using _CONTEXT_ATTRS)
-        for attr_name in BlockMask._CONTEXT_ATTRS:
-            original_value = getattr(block_mask, attr_name)
-            reconstructed_value = getattr(reconstructed_mask, attr_name)
-
-            self.assertEqual(
-                original_value,
-                reconstructed_value,
-                f"Context attribute {attr_name} not equal after reconstruction",
-            )
-
-    @supported_platform
-    def test_pytree_flatten_with_keys(self, device):
-        """Test that BlockMask._flatten_with_keys works correctly for tracing."""
-
-        def causal_mask(b, h, q_idx, kv_idx):
-            return q_idx >= kv_idx
-
-        block_mask = create_block_mask(
-            causal_mask, B=2, H=4, Q_LEN=512, KV_LEN=512, device=device
-        )
-
-        tensors_with_keys, context_with_keys = block_mask._flatten_with_keys()
-
-        self.assertEqual(len(tensors_with_keys), len(BlockMask._TENSOR_ATTRS))
-        self.assertEqual(len(context_with_keys), len(BlockMask._CONTEXT_ATTRS))
-
-        from torch.utils._pytree import GetAttrKey
-
-        for key, tensor in tensors_with_keys:
-            self.assertIsInstance(key, GetAttrKey)
-            self.assertIsNotNone(key)
-
-        for key, value in context_with_keys:
-            self.assertIsInstance(key, GetAttrKey)
-            self.assertIsNotNone(key)
-
-    @supported_platform
-    def test_pytree_preserves_new_attributes(self, device):
-        """
-        Test that BlockMask._TENSOR_ATTRS and _CONTEXT_ATTRS are correctly defined
-        and that flatten/unflatten preserves all attributes in these lists.
-
-        """
-
-        def causal_mask(b, h, q_idx, kv_idx):
-            return q_idx >= kv_idx
-
-        block_mask = create_block_mask(
-            causal_mask, B=2, H=4, Q_LEN=512, KV_LEN=512, device=device
-        )
-
-        # Flatten and unflatten using class methods
-        tensors, context = block_mask._flatten()
-        reconstructed_mask = BlockMask._unflatten(tensors, context)
-
-        # Verify the number of tensors and context values matches the attribute lists
-        self.assertEqual(
-            len(tensors),
-            len(BlockMask._TENSOR_ATTRS),
-            "Number of tensors should match _TENSOR_ATTRS length",
-        )
-        self.assertEqual(
-            len(context),
-            len(BlockMask._CONTEXT_ATTRS),
-            "Number of context values should match _CONTEXT_ATTRS length",
-        )
-
-        # Verify all attributes from the lists exist and are equal after reconstruction
-        for attr_name in BlockMask._TENSOR_ATTRS + BlockMask._CONTEXT_ATTRS:
-            self.assertTrue(
-                hasattr(reconstructed_mask, attr_name),
-                f"Reconstructed mask missing attribute: {attr_name}",
-            )
-            original_value = getattr(block_mask, attr_name)
-            reconstructed_value = getattr(reconstructed_mask, attr_name)
-
-            if isinstance(original_value, torch.Tensor):
-                self.assertTrue(
-                    torch.equal(original_value, reconstructed_value),
-                    f"Tensor attribute {attr_name} not equal after reconstruction",
-                )
-            elif original_value is None:
-                self.assertIsNone(
-                    reconstructed_value,
-                    f"Attribute {attr_name} should be None but got {reconstructed_value}",
-                )
-            else:
-                self.assertEqual(
-                    original_value,
-                    reconstructed_value,
-                    f"Attribute {attr_name} not equal after reconstruction",
-                )
-
 
 @large_tensor_test_class("2GB", device=test_device[0])
 class TestPagedAttention(InductorTestCase):
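The removed tests exercised private pytree hooks on `BlockMask` (`_flatten`, `_unflatten`, `_TENSOR_ATTRS`, `_CONTEXT_ATTRS`); the public attributes they compare, `shape` and `sparsity()`, are untouched by the revert. A small standalone sketch of that public surface (CPU device assumed):

```python
from torch.nn.attention.flex_attention import create_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(
    causal_mask, B=2, H=4, Q_LEN=512, KV_LEN=512, device="cpu"
)
print(block_mask.shape)       # (2, 4, 512, 512)
print(block_mask.sparsity())  # percentage of blocks the kernel can skip
```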
@@ -34,7 +34,7 @@ from torch.fx.experimental.proxy_tensor import (
     _temp_remove_pre_dispatch_torch_function_mode,
 )
 from torch.nn.attention._utils import _validate_sdpa_input
-from torch.utils._pytree import GetAttrKey, register_pytree_node, tree_map_only
+from torch.utils._pytree import tree_map_only
 
 
 # Private debug flag to disable internal compilation wrapping for debugging purposes.
@@ -519,24 +519,6 @@ class BlockMask:
     BLOCK_SIZE: tuple[int, int]
     mask_mod: _mask_mod_signature
 
-    # Attribute lists for pytree flatten/unflatten
-    _TENSOR_ATTRS = [
-        "kv_num_blocks",
-        "kv_indices",
-        "full_kv_num_blocks",
-        "full_kv_indices",
-        "q_num_blocks",
-        "q_indices",
-        "full_q_num_blocks",
-        "full_q_indices",
-    ]
-
-    _CONTEXT_ATTRS = [
-        "seq_lengths",
-        "BLOCK_SIZE",
-        "mask_mod",
-    ]
-
     def __init__(
         self,
         seq_lengths: tuple[int, int],
@@ -931,31 +913,6 @@ class BlockMask:
         )
         return BlockMask(*mapped_attributes)
 
-    def _flatten(self):
-        """Flatten BlockMask into a list of tensors and context."""
-        tensors = tuple(getattr(self, attr) for attr in self._TENSOR_ATTRS)
-        context = tuple(getattr(self, attr) for attr in self._CONTEXT_ATTRS)
-        return tensors, context
-
-    @classmethod
-    def _unflatten(cls, tensors, context):
-        """Unflatten tensors and context back into a BlockMask."""
-        kwargs = {
-            **dict(zip(cls._CONTEXT_ATTRS, context)),
-            **dict(zip(cls._TENSOR_ATTRS, tensors)),
-        }
-        return cls(**kwargs)
-
-    def _flatten_with_keys(self):
-        """Flatten BlockMask with keys for better tracing."""
-        tensors = tuple(
-            (GetAttrKey(attr), getattr(self, attr)) for attr in self._TENSOR_ATTRS
-        )
-        context = tuple(
-            (GetAttrKey(attr), getattr(self, attr)) for attr in self._CONTEXT_ATTRS
-        )
-        return tensors, context
-
 
 def _broadcast_to_dim(x, dim):
     while x.dim() < dim:
@@ -1647,12 +1604,3 @@ def flex_attention(
     return _finalize_outputs(
         out, lse, max_scores, return_aux=return_aux, return_lse=return_lse
     )
-
-
-register_pytree_node(
-    BlockMask,
-    BlockMask._flatten,
-    BlockMask._unflatten,
-    flatten_with_keys_fn=BlockMask._flatten_with_keys,
-    serialized_type_name="torch.nn.attention.flex_attention.BlockMask",
-)
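This last hunk drops the module-level `register_pytree_node` call that made `BlockMask` traceable as a pytree container. For context, a toy sketch of the general registration protocol in `torch.utils._pytree`, using a hypothetical `Pair` class rather than `BlockMask` (which, after this revert, no longer defines `_flatten`/`_unflatten`):

```python
import torch
from torch.utils._pytree import register_pytree_node, tree_flatten, tree_unflatten

class Pair:
    def __init__(self, first, second, label):
        self.first = first    # tensor leaf
        self.second = second  # tensor leaf
        self.label = label    # static context, not a leaf

    @staticmethod
    def _flatten(pair):
        return (pair.first, pair.second), pair.label

    @staticmethod
    def _unflatten(children, label):
        return Pair(*children, label)

register_pytree_node(
    Pair,
    Pair._flatten,
    Pair._unflatten,
    serialized_type_name="example.Pair",
)

p = Pair(torch.ones(2), torch.zeros(2), label="demo")
leaves, spec = tree_flatten(p)
print([t.shape for t in leaves])       # two tensor leaves
q = tree_unflatten([t + 1 for t in leaves], spec)
print(q.label, q.first)                # context preserved, leaves transformed
```

The flatten function returns the tensor children plus static context, and the unflatten function rebuilds the instance from them, which is the same shape of contract the removed `BlockMask` hooks implemented.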