[Attention] Always pad in preprocess_mask to avoid recompilations (#150403)
Motivation: consider the following script:
```
# demo.py
import torch
import json
from transformers import BertModel, BertConfig

CONFIG = """
{
    "architectures": [
        "BertForMaskedLM"
    ],
    "attention_probs_dropout_prob": 0.1,
    "gradient_checkpointing": false,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 768,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "layer_norm_eps": 1e-12,
    "max_position_embeddings": 512,
    "model_type": "bert",
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
    "pad_token_id": 0,
    "position_embedding_type": "absolute",
    "transformers_version": "4.6.0.dev0",
    "type_vocab_size": 2,
    "use_cache": true,
    "vocab_size": 30522
}
"""

config = json.loads(CONFIG)
bloom_config = BertConfig(**config)
model = BertModel(bloom_config).half().cuda()

torch.compiler.reset()
torch.cuda.empty_cache()
compiled_fn = torch.compile(model)

vocab_size = 30522
for b in range(1, 3):
    for s in range(1, 10):
        print(f"🚀 {b} {s}")
        input_ids = torch.randint(0, vocab_size, (b, s)).cuda()
        attention_mask = torch.ones(b, s).cuda()
        with torch.no_grad():
            out = compiled_fn(input_ids, attention_mask).last_hidden_state
```
When we run it with:
```
time TORCH_LOGS=recompiles python demo.py
```
we can see 7 recompilations, and the script takes about 2 minutes (fresh build) or 1 minute (cached build) on my machine.
One root cause of the recompilations is that there are guards checking the alignment of the inputs' strides (see the patch), which leads to unexpected recompilations for the `(1, 4)`, `(1, 8)`, `(2, 4)`, and `(2, 8)` inputs.
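For reference, the alignment check behind those guards looks roughly like the following in Python (a sketch mirroring the C++ `aligned_tensor` template in the diff below; the alignment value of 8 and the `(batch, 1, 1, seq_len)` mask shape here are only illustrative, since the exact values depend on the backend and model):
```
import torch

def is_aligned(t: torch.Tensor, alignment: int) -> bool:
    # Every stride except the innermost one must be a multiple of `alignment`.
    return all(t.stride(i) % alignment == 0 for i in range(t.dim() - 1))

# A BERT-style extended attention mask has shape (batch, 1, 1, seq_len), so all
# of its outer strides equal seq_len. Whether the check passes therefore flips
# whenever seq_len crosses a multiple of the alignment, and under dynamic shapes
# each flip fails a guard and triggers a recompilation.
for s in (4, 7, 8):
    mask = torch.zeros(1, 1, 1, s, dtype=torch.float16)
    print(s, is_aligned(mask, alignment=8))  # 4 -> False, 7 -> False, 8 -> True
```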
In this patch, we always pad the inputs when their strides are not known at compile time, which avoids the guards on alignment. Always padding is safe: it does not change the semantics.
With this change there are only 3 recompilations, and the script takes 1 minute (fresh build) or 17 seconds (cached build) on my machine.
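The padding trick itself is cheap and semantics-preserving: pad the last dimension up to a multiple of the alignment, then slice the logical size back, so the values are untouched while every outer stride becomes aligned. Below is a minimal Python sketch of that idea (the real code is the C++ padding path used by `preprocess_mask`; `pad_last_dim_for_alignment` and the alignment of 8 are only illustrative):
```
import torch
import torch.nn.functional as F

def pad_last_dim_for_alignment(mask: torch.Tensor, alignment: int = 8) -> torch.Tensor:
    last = mask.size(-1)
    pad = (-last) % alignment
    if pad == 0:
        return mask
    # Pad on the right with zeros, then narrow back to the original length:
    # the returned view has the same contents as `mask`, but every stride
    # except the innermost one is now a multiple of `alignment`.
    return F.pad(mask, (0, pad)).narrow(-1, 0, last)

m = torch.randn(2, 12, 9, 9, dtype=torch.float16)
aligned = pad_last_dim_for_alignment(m)
assert torch.equal(aligned, m)                                      # same values
assert all(aligned.stride(i) % 8 == 0 for i in range(m.dim() - 1))  # aligned strides
```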
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150403
Approved by: https://github.com/drisspg
parent 9458b83729
commit 0d09a33819
```
@@ -572,7 +572,13 @@ std::optional<Tensor> convert_boolean_attn_mask_cudnn(const std::optional<Tensor
 template<int alignment>
 bool aligned_tensor(const at::Tensor& tensor){
   for(const auto i : c10::irange(tensor.dim() - 1)){
-    if(tensor.sym_stride(i) % alignment != 0){
+    auto stride = tensor.sym_stride(i).maybe_as_int();
+    // If the stride is unknown at compilation time, assume it is unaligned
+    // and always pad it. This is helpful to avoid unnecessary guards.
+    if (!stride)
+      return false;
+
+    if((*stride) % alignment != 0){
       return false;
     }
   }
```
```
@@ -12,7 +12,7 @@ import torch.nn.functional as F
 from torch._dynamo.comptime import comptime
 from torch._dynamo.testing import CompileCounter, CompileCounterWithBackend, same
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
-from torch.testing._internal.common_utils import skipIfWindows
+from torch.testing._internal.common_utils import requires_cuda, skipIfWindows
 from torch.testing._internal.logging_utils import logs_to_string
```
```
@@ -62,6 +62,50 @@ class UnspecTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(cnts.frame_count, 1)
         self.assertEqual(cnts.op_count, 2)
 
+    @requires_cuda
+    def test_no_recompilations_with_efficient_attention(self):
+        def fn(q, k, v, attn_mask):
+            from torch.nn.attention import sdpa_kernel, SDPBackend
+            from torch.nn.functional import scaled_dot_product_attention
+
+            with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
+                return scaled_dot_product_attention(
+                    q, k, v, attn_mask=attn_mask, scale=1.0
+                )
+
+        def make_q_k_v_mask(batch, num_heads, head_dim, seq_len_kv):
+            from collections import namedtuple
+            from functools import partial
+
+            dtype = torch.float16
+            device = "cuda"
+            make_tensor = partial(
+                torch.rand, device=device, dtype=dtype, requires_grad=True
+            )
+            seq_len_q = 64
+            SdpaShape = namedtuple(
+                "Sdpa_Shape", ["batch", "num_heads", "seq_len", "head_dim"]
+            )
+            query = make_tensor(SdpaShape(batch, num_heads, seq_len_q, head_dim))
+            kv_shape = SdpaShape(batch, num_heads, seq_len_kv, head_dim)
+            key, value = make_tensor(kv_shape), make_tensor(kv_shape)
+            mask = torch.randn(
+                (batch, num_heads, seq_len_q, seq_len_kv), device=device, dtype=dtype
+            )
+
+            return query, key, value, mask
+
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch.compile(fn, backend=cnts)
+
+        q, k, v, mask = make_q_k_v_mask(16, 16, 64, 15)
+        opt_fn(q, k, v, mask)
+
+        q, k, v, mask = make_q_k_v_mask(16, 16, 64, 16)
+        opt_fn(q, k, v, mask)
+
+        self.assertEqual(cnts.frame_count, 1)
+
     @unittest.expectedFailure  # array scalars decay to 0D arrays
     def test_builtin_max_min(self):
         # test unspecialized primitive max/min
```