pytorch/benchmarks/transformer/attention_bias_benchmarks.py
Aaron Orenstein 07669ed960 PEP585 update - benchmarks tools torchgen (#145101)
This is one of a series of PRs to update us to PEP585 (changing Dict -> dict, List -> list, etc.). Most of the PRs were completely automated with RUFF as follows:

Since RUFF considers UP006 an "unsafe" fix, we first need to enable unsafe fixes:

```
--- a/tools/linter/adapters/ruff_linter.py
+++ b/tools/linter/adapters/ruff_linter.py
@@ -313,6 +313,7 @@
                     "ruff",
                     "check",
                     "--fix-only",
+                    "--unsafe-fixes",
                     "--exit-zero",
                     *([f"--config={config}"] if config else []),
                     "--stdin-filename",
```

Then we need to tell RUFF to allow UP006 (this will be made permanent in a final PR once all of these have landed):

```
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,7 +40,7 @@

 [tool.ruff]
-target-version = "py38"
+target-version = "py39"
 line-length = 88
 src = ["caffe2", "torch", "torchgen", "functorch", "test"]

@@ -87,7 +87,6 @@
     "SIM116", # Disable Use a dictionary instead of consecutive `if` statements
     "SIM117",
     "SIM118",
-    "UP006", # keep-runtime-typing
     "UP007", # keep-runtime-typing
 ]
 select = [
```

Finally, running `lintrunner -a --take RUFF` fixes up the deprecated uses.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145101
Approved by: https://github.com/bobrenjc93
2025-01-18 05:05:07 +00:00

250 lines
7.4 KiB
Python

import itertools
from dataclasses import asdict, dataclass
from functools import partial
from typing import Callable, Union
import numpy as np
from tabulate import tabulate
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from torch.nn.attention.bias import CausalBias, CausalVariant
from torch.nn.parameter import Parameter
def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Return the median run time of ``func(*args, **kwargs)`` in microseconds.

    The callable is first invoked five times untimed (warmup). It is then
    measured with ``torch.utils.benchmark.Timer.adaptive_autorange`` using a
    minimum total run time of 0.1 s.
    """
    warmup_iters = 5
    for _ in range(warmup_iters):
        func(*args, **kwargs)

    timer = benchmark.Timer(
        stmt="func(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "func": func},
    )
    measurement = timer.adaptive_autorange(min_run_time=0.1)
    # Measurement.median is reported in seconds; scale to microseconds.
    return measurement.median * 1e6
@dataclass(frozen=True)
class ExperimentConfig:
    """Shape and dtype parameters for one attention-bias benchmark run."""

    batch_size: int
    num_heads: int
    q_seq_len: int
    k_seq_len: int
    embed_dim: int
    dtype: torch.dtype

    @property
    def head_dim(self) -> int:
        """Per-head width, ``embed_dim // num_heads`` (floor division)."""
        return self.embed_dim // self.num_heads

    def asdict(self) -> dict:
        """Return the dataclass fields as a dict, plus the derived ``head_dim``."""
        # Inside this method body the bare name ``asdict`` resolves to the
        # module-level ``dataclasses.asdict`` import, not to this method.
        dict_obj = asdict(self)
        dict_obj["head_dim"] = self.head_dim
        return dict_obj

    def get_entries(self) -> list:
        """Return the config values as a table row.

        Values are in the same order as the keys of ``asdict()``, so that
        method's keys can serve as column headers. ``Experiment.get_entries``
        relies on this method existing.
        """
        return list(self.asdict().values())
@dataclass(frozen=True)
class ExperimentResults:
    """Median timings, in microseconds, for one benchmark configuration."""

    # Time when attn_mask is the dense tensor from CausalBias._materialize.
    materialized_mask_time: float
    # Time when attn_mask is the CausalBias object itself.
    attn_mask_subclass_time: float

    def get_entries(self) -> list:
        """Return both timings as strings with two decimal places."""
        # ``.2f`` sets the precision to 2 digits; the previous ``2f`` only set
        # a minimum field width of 2 and still printed six decimals.
        return [
            f"{self.materialized_mask_time:.2f}",
            f"{self.attn_mask_subclass_time:.2f}",
        ]
@dataclass(frozen=True)
class Experiment:
    """A benchmark configuration paired with the timings measured for it."""

    config: ExperimentConfig
    results: ExperimentResults

    def get_entries(self) -> list:
        """Return one table row: the config entries followed by the result entries."""
        config_entries = self.config.get_entries()
        result_entries = self.results.get_entries()
        return config_entries + result_entries
def generate_inputs(
    batch_size, q_sequence_length, kv_sequence_length, embed_dim, dtype, device
):
    """Create random query, key and value tensors with ``torch.rand``.

    Returns:
        ``(query, key, value)``. ``query`` has shape
        ``(batch_size, q_sequence_length, embed_dim)``. ``key`` and ``value``
        each have shape ``(batch_size, kv_sequence_length, embed_dim)`` and
        are sampled independently, in the order query, key, value.
    """

    def _rand(seq_len):
        return torch.rand(
            (batch_size, seq_len, embed_dim), device=device, dtype=dtype
        )

    query = _rand(q_sequence_length)
    key = _rand(kv_sequence_length)
    value = _rand(kv_sequence_length)
    return query, key, value
class CompositeMHA(torch.nn.Module):
def __init__(self, num_heads, embed_dim, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.head_dim = embed_dim // num_heads
self.embed_dim = embed_dim
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
self.q_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs)
)
self.k_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs)
)
self.v_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs)
)
self.out_proj = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
self.num_heads = num_heads
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Union[torch.Tensor, CausalBias],
):
query_projected = F.linear(query, self.q_proj_weight)
key_projected = F.linear(key, self.k_proj_weight)
value_projected = F.linear(value, self.v_proj_weight)
query = query.view(
query_projected.size(0), -1, self.num_heads, self.head_dim
).transpose(1, 2)
key = key.view(
key_projected.size(0), -1, self.num_heads, self.head_dim
).transpose(1, 2)
value = value.view(
value_projected.size(0), -1, self.num_heads, self.head_dim
).transpose(1, 2)
attn = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=mask,
dropout_p=0.0,
)
attn = attn.transpose(1, 2).reshape(query.size(0), -1, self.embed_dim)
# Match return signature of nn.MHA
return F.linear(attn, self.out_proj)
def reset_parameters(self):
nn.init.xavier_uniform_(self.q_proj_weight)
nn.init.xavier_uniform_(self.k_proj_weight)
nn.init.xavier_uniform_(self.v_proj_weight)
nn.init.constant_(self.out_proj, 0.0)
def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
    """Time CompositeMHA on CUDA with a dense mask vs. a ``CausalBias`` mask.

    Builds a ``CompositeMHA`` for ``config`` and random inputs. It then
    benchmarks the module once with the tensor from
    ``CausalBias._materialize`` and once with the ``CausalBias`` object
    itself (lower-right causal variant). Finally it asserts that both calls
    produce matching outputs.

    NOTE(review): ``CompositeMHA.reset_parameters`` zero-initializes
    ``out_proj``, so the two outputs compared here are all zeros. The check
    cannot detect a mismatch in the attention itself.
    """
    device = torch.device("cuda")

    mha = CompositeMHA(config.num_heads, config.embed_dim, device, config.dtype)
    mha.reset_parameters()

    query, key, value = generate_inputs(
        batch_size=config.batch_size,
        q_sequence_length=config.q_seq_len,
        kv_sequence_length=config.k_seq_len,
        embed_dim=config.embed_dim,
        dtype=config.dtype,
        device=device,
    )

    causal_bias = CausalBias(
        CausalVariant.LOWER_RIGHT, config.q_seq_len, config.k_seq_len
    )
    dense_mask = causal_bias._materialize(device)

    materialized_us = benchmark_torch_function_in_microseconds(
        mha, query, key, value, dense_mask
    )
    subclass_us = benchmark_torch_function_in_microseconds(
        mha, query, key, value, causal_bias
    )

    # Both mask representations must yield the same attention output.
    torch.testing.assert_close(
        mha(query, key, value, dense_mask),
        mha(query, key, value, causal_bias),
    )

    return ExperimentResults(
        materialized_mask_time=materialized_us,
        attn_mask_subclass_time=subclass_us,
    )
def generate_experiment_configs() -> list[ExperimentConfig]:
    """Return every benchmark configuration in the sweep.

    Takes the cartesian product of batch size, head count, (q_seq_len,
    k_seq_len) pair, embedding width and dtype. The two sequence lengths
    vary together as a pair; they are not crossed with each other.
    """
    batch_sizes = [1, 8, 16, 128]
    head_counts = [16, 32]
    # (q_seq_len, k_seq_len) pairs; in every pair q_seq_len <= k_seq_len.
    seq_len_pairs = [(128, 256), (256, 416), (512, 4097), (1024, 2048), (1, 2048)]
    embed_dims = [2048, 4096]
    dtypes = [torch.bfloat16]

    return [
        ExperimentConfig(
            batch_size=bsz,
            num_heads=heads,
            q_seq_len=q_len,
            k_seq_len=k_len,
            embed_dim=dim,
            dtype=dt,
        )
        for bsz, heads, (q_len, k_len), dim, dt in itertools.product(
            batch_sizes, head_counts, seq_len_pairs, embed_dims, dtypes
        )
    ]
def calculate_speedup(results: ExperimentResults) -> float:
    """Return materialized-mask time divided by ``CausalBias`` time.

    A value above 1 means the ``CausalBias`` path was faster.
    """
    baseline = results.materialized_mask_time
    candidate = results.attn_mask_subclass_time
    return baseline / candidate
def print_results(results: list[Experiment]):
    """Print a table of the average, max and min speedups.

    Speedup is ``calculate_speedup`` (materialized-mask time / CausalBias
    time). The Max and Min rows also show the config that produced them.
    If ``results`` is empty, only a short notice is printed.
    """
    if not results:
        # np.argmax / np.argmin raise ValueError on an empty sequence.
        print("No results to report.")
        return

    speedups = [calculate_speedup(r.results) for r in results]
    max_idx = int(np.argmax(speedups))
    min_idx = int(np.argmin(speedups))

    max_config_dict = results[max_idx].config.asdict()
    min_config_dict = results[min_idx].config.asdict()

    table_data = [
        {
            "Type": "Average",
            "Speedup": np.mean(speedups),
            # Config columns are left blank (None) for the aggregate row.
            **dict.fromkeys(max_config_dict),
        },
        {"Type": "Max", "Speedup": speedups[max_idx], **max_config_dict},
        {"Type": "Min", "Speedup": speedups[min_idx], **min_config_dict},
    ]
    print(tabulate(table_data, headers="keys", tablefmt="pretty"))
def main():
    """Seed the RNGs, run every configuration, and print the speedup summary."""
    seed = 123
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Each experiment times CompositeMHA with a materialized mask tensor vs.
    # a CausalBias attn_mask for one configuration.
    results = [
        Experiment(config, run_single_experiment(config))
        for config in tqdm(generate_experiment_configs())
    ]
    print_results(results)


if __name__ == "__main__":
    main()