[Attention] Always pad in preprocess_mask to avoid recompilations (#150403)

Motivation: consider the following script:

```
# demo.py
import torch
import json
from transformers import BertModel, BertConfig

CONFIG = """
{
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.6.0.dev0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}
"""

config = json.loads(CONFIG)
bert_config = BertConfig(**config)
model = BertModel(bert_config).half().cuda()

torch.compiler.reset()
torch.cuda.empty_cache()
compiled_fn = torch.compile(model)
vocab_size = 30522

for b in range(1, 3):
    for s in range(1, 10):
        print(f"🚀 {b} {s}")
        input_ids = torch.randint(0, vocab_size, (b, s)).cuda()
        attention_mask = torch.ones(b, s).cuda()

        with torch.no_grad():
            out = compiled_fn(input_ids, attention_mask).last_hidden_state
```

When we run it with:

```
time TORCH_LOGS=recompiles python demo.py
```

we see 7 recompilations, and it takes 2 minutes (fresh build) or 1 minute (cached build) on my machine.

One root cause of the recompilations is that there are guards checking the alignment (stride divisibility) of the inputs (see the patch), which leads to unexpected recompilations for the `(1, 4)`, `(1, 8)`, `(2, 4)` and `(2, 8)` inputs.
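
To make the guard behavior concrete, here is a minimal sketch, not the actual guard machinery: for a contiguous `(b, s)` mask, `stride(0) == s`, so a `stride % alignment == 0` style check such as the one in `aligned_tensor` flips between branches as `s` changes, and each flip means a different guard set. The alignment value 4 below is only an assumption for illustration, and in the demo the mask is first broadcast into BERT's extended attention mask, so the exact strides differ, but the mechanism is the same.

```
# Sketch only: shows how stride divisibility flips with the sequence length.
# ALIGNMENT = 4 is an assumed value for illustration, not necessarily the one
# used by the mem-efficient attention kernels.
import torch

ALIGNMENT = 4

for s in range(1, 10):
    mask = torch.ones(2, s)  # contiguous, so stride(0) == s
    print(s, mask.stride(0), mask.stride(0) % ALIGNMENT == 0)

# Only s = 4 and s = 8 hit the "already aligned" branch, which lines up with
# the (1, 4), (1, 8), (2, 4) and (2, 8) shapes listed above.
```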

In this patch, we always pad the inputs whose shape is unknown at compile time, to avoid the guards on alignment. It is fine to always pad the tensor: it does not change the semantics.
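
Conceptually, the padding is shape- and value-preserving: the mask is padded on its last dimension up to the alignment and then sliced back, so only the underlying strides change. Below is a minimal sketch of that idea; the helper name `pad_to_alignment` and the alignment value are illustrative, not the actual ATen padding helper used by `preprocess_mask`.

```
import torch
import torch.nn.functional as F

ALIGNMENT = 4  # assumed value, for illustration only

def pad_to_alignment(mask: torch.Tensor) -> torch.Tensor:
    # Pad the last dim up to a multiple of ALIGNMENT, then slice back:
    # same shape and same values, but the padded storage yields aligned strides.
    last = mask.size(-1)
    pad = (-last) % ALIGNMENT
    if pad == 0:
        return mask
    return F.pad(mask, (0, pad))[..., :last]

mask = torch.randn(2, 1, 64, 15)
padded = pad_to_alignment(mask)
assert padded.shape == mask.shape and torch.equal(padded, mask)
print(mask.stride())    # (960, 960, 15, 1)  -- the 15 fails a %-alignment check
print(padded.stride())  # (1024, 1024, 16, 1) -- every checked stride is now aligned
```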

Now there are only 3 recompilations, and it takes 1 minute (fresh build) or 17 seconds (cached build) on my machine.
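
If you want to confirm the recompilation count without parsing `TORCH_LOGS` output, one option is Dynamo's `CompileCounter` test backend, the same helper the new test below relies on. The following is a sketch that reuses `model` and `vocab_size` from the demo script above; the counted number will not match the `TORCH_LOGS=recompiles` count exactly, since every compiled frame is counted, but it should drop in the same way after this patch.

```
import torch
import torch._dynamo.testing

cnts = torch._dynamo.testing.CompileCounter()
compiled_fn = torch.compile(model, backend=cnts)  # `model` from the demo above

for b in range(1, 3):
    for s in range(1, 10):
        input_ids = torch.randint(0, vocab_size, (b, s)).cuda()
        attention_mask = torch.ones(b, s).cuda()
        with torch.no_grad():
            compiled_fn(input_ids, attention_mask)

# frame_count increases once per (re)compiled frame.
print(cnts.frame_count)
```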

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150403
Approved by: https://github.com/drisspg
Author: Chuanqi Xu
Date: 2025-04-14 04:18:19 +00:00
Committed by: PyTorch MergeBot
Commit: 0d09a33819 (parent: 9458b83729)
2 changed files with 52 additions and 2 deletions

```
@@ -572,7 +572,13 @@ std::optional<Tensor> convert_boolean_attn_mask_cudnn(const std::optional<Tensor
 template<int alignment>
 bool aligned_tensor(const at::Tensor& tensor){
   for(const auto i : c10::irange(tensor.dim() - 1)){
-    if(tensor.sym_stride(i) % alignment != 0){
+    auto stride = tensor.sym_stride(i).maybe_as_int();
+    // If the stride is unknown at compilation time, assume it is unaligned
+    // and always pad it. This is helpful to avoid unnecessary guards.
+    if (!stride)
+      return false;
+
+    if ((*stride) % alignment != 0){
       return false;
     }
   }
```

```
@@ -12,7 +12,7 @@ import torch.nn.functional as F
 from torch._dynamo.comptime import comptime
 from torch._dynamo.testing import CompileCounter, CompileCounterWithBackend, same
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
-from torch.testing._internal.common_utils import skipIfWindows
+from torch.testing._internal.common_utils import requires_cuda, skipIfWindows
 from torch.testing._internal.logging_utils import logs_to_string
@@ -62,6 +62,50 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(cnts.frame_count, 1)
         self.assertEqual(cnts.op_count, 2)

+    @requires_cuda
+    def test_no_recompilations_with_efficient_attention(self):
+        def fn(q, k, v, attn_mask):
+            from torch.nn.attention import sdpa_kernel, SDPBackend
+            from torch.nn.functional import scaled_dot_product_attention
+
+            with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
+                return scaled_dot_product_attention(
+                    q, k, v, attn_mask=attn_mask, scale=1.0
+                )
+
+        def make_q_k_v_mask(batch, num_heads, head_dim, seq_len_kv):
+            from collections import namedtuple
+            from functools import partial
+
+            dtype = torch.float16
+            device = "cuda"
+            make_tensor = partial(
+                torch.rand, device=device, dtype=dtype, requires_grad=True
+            )
+            seq_len_q = 64
+            SdpaShape = namedtuple(
+                "Sdpa_Shape", ["batch", "num_heads", "seq_len", "head_dim"]
+            )
+            query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim))
+            kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim)
+            key, value = make_tensor(kv_shape), make_tensor(kv_shape)
+            mask = torch.randn(
+                (batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype
+            )
+            return query, key, value, mask
+
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch.compile(fn, backend=cnts)
+
+        q, k, v, mask = make_q_k_v_mask(16, 16, 64, 15)
+        opt_fn(q, k, v, mask)
+
+        q, k, v, mask = make_q_k_v_mask(16, 16, 64, 16)
+        opt_fn(q, k, v, mask)
+
+        self.assertEqual(cnts.frame_count, 1)
+
     @unittest.expectedFailure  # array scalars decay to 0D arrays
     def test_builtin_max_min(self):
         # test unspecialized primitive max/min
```