mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add attention benchmarking numbers to pytorch operator microbenchmarks (#164155)
This pull request introduces a standardized YAML-based configuration system for transformer attention benchmarks, making it easier to run and manage comprehensive performance tests. It adds example configs, and a wrapper script to convert YAML configs into CLI arguments for the benchmark runner. #### Next Steps: CI Enablement: This change would further lead to running the attention ops in CI for regression tracking. #### Developer flow: (Run locally) `python score_mod.py --config configs/config_test.yaml` #### Enabling CI run: https://github.com/pytorch/pytorch/pull/165915 Pull Request resolved: https://github.com/pytorch/pytorch/pull/164155 Approved by: https://github.com/jbschlosser
This commit is contained in:
parent
0d4992c170
commit
a9b29caeae
157
benchmarks/transformer/config_utils.py
Normal file
157
benchmarks/transformer/config_utils.py
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
"""Configuration utilities for parsing JSON and YAML config files."""
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
|
||||
def heads_input_type(s: str) -> tuple[int, int]:
|
||||
"""Convert string format 'Hq,Hkv' to tuple (Hq, Hkv)."""
|
||||
try:
|
||||
hq, hkv = map(int, s.split(","))
|
||||
return hq, hkv
|
||||
except Exception as e:
|
||||
raise ValueError("Heads must be Hq,Hkv") from e
|
||||
|
||||
|
||||
default_config = {
|
||||
"dynamic": False,
|
||||
"calculate_bwd": False,
|
||||
"dtype": "bfloat16",
|
||||
"b": [2, 8, 16],
|
||||
"nh": ["16,16", "16,2"],
|
||||
"s": [512, 1024, 4096],
|
||||
"d": [64, 128],
|
||||
"mods": ["noop", "causal", "alibi", "sliding_window"],
|
||||
"backend": ["efficient"],
|
||||
"max_autotune": False,
|
||||
"decoding": False,
|
||||
"kv_size": None,
|
||||
"throughput": True,
|
||||
"save_path": None,
|
||||
"output_json_for_dashboard": None,
|
||||
"benchmark_name": "PyTorch operator microbenchmark",
|
||||
}
|
||||
|
||||
|
||||
def load_config_file(config_path: str) -> dict:
|
||||
"""Load configuration from JSON or YAML file.
|
||||
|
||||
Automatically converts 'nh' field from strings to tuples.
|
||||
|
||||
Args:
|
||||
config_path: Path to the configuration file
|
||||
|
||||
Returns:
|
||||
Dictionary containing the configuration
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If config file doesn't exist
|
||||
ValueError: If config file format is invalid
|
||||
"""
|
||||
with open(config_path) as f:
|
||||
config_str = f.read()
|
||||
|
||||
# Try to load as JSON first
|
||||
try:
|
||||
config = json.loads(config_str)
|
||||
except json.JSONDecodeError:
|
||||
# Fall back to YAML parsing
|
||||
config = _parse_simple_yaml(config_str)
|
||||
|
||||
# Apply automatic conversions for 'nh' field
|
||||
if "nh" in config and isinstance(config["nh"], list):
|
||||
config["nh"] = [
|
||||
heads_input_type(h) if isinstance(h, str) else h for h in config["nh"]
|
||||
]
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _parse_simple_yaml(yaml_str: str) -> dict:
|
||||
"""Simple YAML parser for basic configs (without external dependencies).
|
||||
|
||||
Supports:
|
||||
- key: value pairs
|
||||
- booleans (true/false)
|
||||
- null values
|
||||
- integers and floats
|
||||
- strings (quoted and unquoted)
|
||||
- lists in JSON format [item1, item2, ...]
|
||||
- comments (lines starting with # or after #)
|
||||
|
||||
Args:
|
||||
yaml_str: YAML content as string
|
||||
|
||||
Returns:
|
||||
Dictionary containing parsed YAML content
|
||||
"""
|
||||
config = {}
|
||||
|
||||
for line in yaml_str.split("\n"):
|
||||
# Remove comments
|
||||
line = line.split("#")[0].strip()
|
||||
|
||||
if not line or ":" not in line:
|
||||
continue
|
||||
|
||||
key, value = line.split(":", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
# Parse value based on type
|
||||
if value.lower() == "true":
|
||||
config[key] = True
|
||||
elif value.lower() == "false":
|
||||
config[key] = False
|
||||
elif value.lower() in ("null", "none", ""):
|
||||
config[key] = None
|
||||
elif value.startswith("[") and value.endswith("]"):
|
||||
# Parse list - handle quoted strings properly
|
||||
pattern = r'"([^"]+)"|\'([^\']+)\'|([^,\[\]\s]+)'
|
||||
matches = re.findall(pattern, value[1:-1]) # Remove [ ]
|
||||
parsed_items = []
|
||||
for match in matches:
|
||||
# match is a tuple of (double_quoted, single_quoted, unquoted)
|
||||
item = match[0] or match[1] or match[2]
|
||||
item = item.strip()
|
||||
if item:
|
||||
try:
|
||||
parsed_items.append(int(item))
|
||||
except ValueError:
|
||||
parsed_items.append(item)
|
||||
config[key] = parsed_items
|
||||
elif value.startswith(('"', "'")):
|
||||
config[key] = value.strip("\"'")
|
||||
else:
|
||||
# Try to parse as number
|
||||
try:
|
||||
config[key] = int(value)
|
||||
except ValueError:
|
||||
try:
|
||||
config[key] = float(value)
|
||||
except ValueError:
|
||||
config[key] = value
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def print_default_config(output_format: str) -> None:
|
||||
"""Print a default configuration template in JSON or YAML format.
|
||||
|
||||
Args:
|
||||
output_format: Either "json" or "yaml"
|
||||
"""
|
||||
if output_format == "json":
|
||||
print(json.dumps(default_config, indent=2))
|
||||
else: # yaml
|
||||
for key, value in default_config.items():
|
||||
if value is None:
|
||||
print(f"{key}: null")
|
||||
elif isinstance(value, bool):
|
||||
print(f"{key}: {str(value).lower()}")
|
||||
elif isinstance(value, str):
|
||||
print(f'{key}: "{value}"')
|
||||
elif isinstance(value, list):
|
||||
print(f"{key}: {json.dumps(value)}")
|
||||
else:
|
||||
print(f"{key}: {value}")
|
||||
29
benchmarks/transformer/configs/config_basic.yaml
Normal file
29
benchmarks/transformer/configs/config_basic.yaml
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
# Basic benchmark configuration for PyTorch transformer benchmarks
|
||||
# Usage: python score_mod.py --config config_basic.yaml
|
||||
|
||||
# Core parameters
|
||||
dynamic: false
|
||||
calculate_bwd: true
|
||||
dtype: "bfloat16"
|
||||
|
||||
# Shape parameters - larger sweep
|
||||
b: [1, 2, 4, 8, 16] # batch sizes
|
||||
nh: ["16,16", "16,2", "32,32", "32,4"] # [query_heads,key_value_heads]
|
||||
s: [512, 1024, 2048, 4096, 8192] # sequence lengths
|
||||
d: [64, 128] # head dimensions (limited to 128 for Flash Attention/cuDNN compatibility)
|
||||
|
||||
# All attention types
|
||||
mods: ["noop", "causal", "rel", "head_bias", "alibi", "sliding_window", "prefix_lm", "softcap"]
|
||||
|
||||
# Multiple backends for comparison (SDPA + Flash Attention) - flex is always included internally
|
||||
backend: ["efficient", "math", "cudnn", "fav2"]
|
||||
max_autotune: true # Enable torch.compile with max-autotune for optimal performance
|
||||
|
||||
# Decoding and cache settings
|
||||
decoding: false
|
||||
kv_size: null
|
||||
|
||||
# Metrics and output
|
||||
throughput: true # Calculate memory bandwidth & TFLOPS
|
||||
save_path: "comprehensive_results.csv" # Save to CSV
|
||||
output_json_for_dashboard: "attn_bench_basic.json"
|
||||
|
|
@ -1,15 +1,19 @@
|
|||
import argparse
|
||||
import csv
|
||||
import gc
|
||||
import itertools
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import asdict, dataclass
|
||||
from functools import partial
|
||||
from typing import Optional, Union
|
||||
from functools import partial, wraps
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from config_utils import heads_input_type, load_config_file, print_default_config
|
||||
from tabulate import tabulate
|
||||
from tqdm import tqdm
|
||||
|
||||
|
|
@ -33,6 +37,96 @@ torch._dynamo.config.recompile_limit = 1000
|
|||
from torch._inductor.runtime.benchmarking import benchmarker
|
||||
|
||||
|
||||
def cleanup_memory():
|
||||
"""Aggressively free GPU memory"""
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def safe_backend(backend_name=None, return_dict=False):
|
||||
"""Decorator that wraps backend functions with error handling
|
||||
|
||||
Args:
|
||||
backend_name: Name of the backend for error messages
|
||||
return_dict: If True, returns dict of results for all backends (for run_single_experiment)
|
||||
If False, returns single ExperimentResults (for individual backend functions)
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(config, *args, **kwargs):
|
||||
try:
|
||||
return func(config, *args, **kwargs)
|
||||
except torch.OutOfMemoryError:
|
||||
print(
|
||||
f"[SKIP] OOM for {backend_name or func.__name__} with shape {config.shape}"
|
||||
)
|
||||
cleanup_memory()
|
||||
except RuntimeError as e:
|
||||
error_msg = str(e)
|
||||
if "out of resource" in error_msg or "OutOfMemoryError" in error_msg:
|
||||
print(
|
||||
f"[SKIP] Triton OOM for {backend_name or func.__name__} with shape {config.shape}"
|
||||
)
|
||||
cleanup_memory()
|
||||
elif "No valid triton configs" in error_msg:
|
||||
print(
|
||||
f"[SKIP] No valid Triton config for {backend_name or func.__name__} with shape {config.shape}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"[SKIP] Runtime error for {backend_name or func.__name__} with shape {config.shape}: {str(e)[:100]}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[SKIP] Error for {backend_name or func.__name__} with shape {config.shape}: {str(e)[:100]}"
|
||||
)
|
||||
|
||||
# Return appropriate NaN result based on function type
|
||||
if return_dict:
|
||||
# For run_single_experiment: return dict with NaN for all backends
|
||||
nan_result = ExperimentResults(
|
||||
fwd_time=float("nan"),
|
||||
bwd_time=float("nan") if config.calculate_bwd_time else None,
|
||||
)
|
||||
results = dict.fromkeys(config.backends, nan_result)
|
||||
results["flex"] = ExperimentResults(
|
||||
fwd_time=float("nan"),
|
||||
bwd_time=float("nan") if config.calculate_bwd_time else None,
|
||||
sparsity=None,
|
||||
)
|
||||
return results
|
||||
else:
|
||||
# For individual backend functions: return single ExperimentResults
|
||||
return ExperimentResults(
|
||||
fwd_time=float("nan"),
|
||||
bwd_time=float("nan") if config.calculate_bwd_time else None,
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# Type definitions
|
||||
Backend = Literal["math", "efficient", "cudnn", "fav2", "fav3", "fakv", "og-eager"]
|
||||
AttentionType = Literal[
|
||||
"noop",
|
||||
"causal",
|
||||
"rel",
|
||||
"head_bias",
|
||||
"alibi",
|
||||
"sliding_window",
|
||||
"document_mask",
|
||||
"prefix_lm",
|
||||
"softcap",
|
||||
]
|
||||
DtypeString = Literal["bfloat16", "float16", "float32"]
|
||||
SpeedupType = Literal["fwd", "bwd"]
|
||||
|
||||
|
||||
def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
|
||||
# warmup
|
||||
for _ in range(5):
|
||||
|
|
@ -48,6 +142,7 @@ class ExperimentConfig:
|
|||
calculate_bwd_time: bool
|
||||
cal_bandwidth: bool
|
||||
backends: list[str]
|
||||
max_autotune: bool
|
||||
|
||||
def __post_init__(self):
|
||||
assert len(self.shape) == 6, (
|
||||
|
|
@ -62,6 +157,7 @@ class ExperimentConfig:
|
|||
d.pop("cal_bandwidth", None)
|
||||
d["shape(B,Hq,M,Hkv,N,D)"] = d.pop("shape")
|
||||
d.pop("backends", None)
|
||||
d.pop("max_autotune", False)
|
||||
return d
|
||||
|
||||
|
||||
|
|
@ -209,6 +305,7 @@ def query_key_value_clones(
|
|||
return query_ref, key_ref, value_ref
|
||||
|
||||
|
||||
@safe_backend("SDPA")
|
||||
def run_single_backend_sdpa(
|
||||
config: ExperimentConfig,
|
||||
query: torch.Tensor,
|
||||
|
|
@ -223,6 +320,7 @@ def run_single_backend_sdpa(
|
|||
backend_context = get_backend_context(backend)
|
||||
with backend_context:
|
||||
_device = torch.device("cuda")
|
||||
|
||||
eager_sdpa = generate_eager_sdpa(
|
||||
config.attn_type, config.shape, config.dtype, block_mask, score_mod
|
||||
)
|
||||
|
|
@ -290,6 +388,7 @@ def run_single_backend_sdpa(
|
|||
)
|
||||
|
||||
|
||||
@safe_backend("FlashAttention")
|
||||
def run_single_backend_FA(
|
||||
config: ExperimentConfig,
|
||||
query: torch.Tensor,
|
||||
|
|
@ -301,9 +400,9 @@ def run_single_backend_FA(
|
|||
mask_kwargs,
|
||||
backend: str,
|
||||
) -> ExperimentResults:
|
||||
assert backend in ["fav2", "fav3", "fakv"]
|
||||
assert backend in ["fav3", "fakv"]
|
||||
# Generate callable for specific backend.
|
||||
if backend in ["fav2", "fav3"]:
|
||||
if backend in ["fav3"]:
|
||||
FA = generate_FA_callable(
|
||||
config.attn_type, config.shape, config.dtype, backend, **mask_kwargs
|
||||
)
|
||||
|
|
@ -354,10 +453,10 @@ def run_single_backend_FA(
|
|||
)
|
||||
|
||||
|
||||
@safe_backend("flex_attention", return_dict=True)
|
||||
def run_single_experiment(
|
||||
config: ExperimentConfig,
|
||||
dynamic=False,
|
||||
max_autotune=False,
|
||||
) -> dict[str, ExperimentResults]:
|
||||
device = torch.device("cuda")
|
||||
batch_size, q_heads, q_seq_len, kv_heads, kv_seq_len, head_dim = config.shape
|
||||
|
|
@ -377,7 +476,7 @@ def run_single_experiment(
|
|||
block_mask, mask_kwargs = generate_block_mask(config.attn_type, config.shape)
|
||||
kernel_options = get_kernel_options(config.attn_type, config.shape)
|
||||
|
||||
if max_autotune:
|
||||
if config.max_autotune:
|
||||
compiled_sdpa = torch.compile(
|
||||
flex_attention, dynamic=dynamic, mode="max-autotune-no-cudagraphs"
|
||||
)
|
||||
|
|
@ -407,7 +506,7 @@ def run_single_experiment(
|
|||
|
||||
results = {}
|
||||
for backend in config.backends:
|
||||
if backend in ["fav2", "fav3", "fakv"]:
|
||||
if backend in ["fav3", "fakv"]:
|
||||
results[backend] = run_single_backend_FA(
|
||||
config,
|
||||
query,
|
||||
|
|
@ -419,7 +518,7 @@ def run_single_experiment(
|
|||
mask_kwargs,
|
||||
backend,
|
||||
)
|
||||
else: # sdpa
|
||||
else: # sdpa (also supports fav2)
|
||||
results[backend] = run_single_backend_sdpa(
|
||||
config,
|
||||
query,
|
||||
|
|
@ -440,7 +539,7 @@ def run_single_experiment(
|
|||
sparsity = block_mask.sparsity() / 100.0 if block_mask is not None else 0.0
|
||||
sparsity = sparsity if config.attn_type != "document_mask" else 0.5
|
||||
|
||||
results["compiled"] = ExperimentResults(
|
||||
results["flex"] = ExperimentResults(
|
||||
fwd_time=forward_compiled_time,
|
||||
bwd_time=backward_compile_time if config.calculate_bwd_time else None,
|
||||
sparsity=sparsity,
|
||||
|
|
@ -501,15 +600,15 @@ def calculate_tflops(config: ExperimentConfig, results: ExperimentResults) -> fl
|
|||
softmax_flops = M * N * 2 # Not counting online softmax overhead
|
||||
o_flops = M * D * N * 2
|
||||
# Not counting split k overhead
|
||||
total_flops = B * Hq * (qk_flops + softmax_flops + o_flops) * (1 - results.sparsity)
|
||||
sparsity = results.sparsity if results.sparsity is not None else 0.0
|
||||
total_flops = B * Hq * (qk_flops + softmax_flops + o_flops) * (1 - sparsity)
|
||||
return total_flops / results.fwd_time / 1e6 # in TFLOPs/
|
||||
|
||||
|
||||
def get_average_speedups(results: list[Experiment], type: str, backend: str):
|
||||
# Calculate speedups
|
||||
speedups = [
|
||||
calculate_speedup(r.results["compiled"], r.results[backend], type)
|
||||
for r in results
|
||||
calculate_speedup(r.results["flex"], r.results[backend], type) for r in results
|
||||
]
|
||||
|
||||
# Find indices of max and min speedups
|
||||
|
|
@ -537,7 +636,7 @@ def get_average_speedups(results: list[Experiment], type: str, backend: str):
|
|||
def print_results(results: list[Experiment], save_path: Optional[str] = None):
|
||||
table_data = defaultdict(list)
|
||||
for experiment in results:
|
||||
backends = experiment.config.backends + ["compiled"]
|
||||
backends = experiment.config.backends + ["flex"]
|
||||
for key, value in experiment.asdict().items():
|
||||
if key in backends:
|
||||
if value.fwd_time:
|
||||
|
|
@ -550,45 +649,43 @@ def print_results(results: list[Experiment], save_path: Optional[str] = None):
|
|||
# Calculate speedups
|
||||
for backend in results[0].config.backends:
|
||||
fwd_speedups = [
|
||||
calculate_speedup(r.results["compiled"], r.results[backend], type="fwd")
|
||||
calculate_speedup(r.results["flex"], r.results[backend], type="fwd")
|
||||
for r in results
|
||||
]
|
||||
table_data[f"fwd_{backend}_speedup"] = fwd_speedups
|
||||
table_data[f"fwd_speedup_flex_over_{backend}"] = fwd_speedups
|
||||
|
||||
if results[0].config.calculate_bwd_time:
|
||||
for backend in results[0].config.backends:
|
||||
bwd_speedups = [
|
||||
calculate_speedup(r.results["compiled"], r.results[backend], type="bwd")
|
||||
calculate_speedup(r.results["flex"], r.results[backend], type="bwd")
|
||||
for r in results
|
||||
]
|
||||
table_data[f"bwd_{backend}_speedup"] = bwd_speedups
|
||||
table_data[f"bwd_speedup_flex_over_{backend}"] = bwd_speedups
|
||||
|
||||
# Calculate mem + computational throughput
|
||||
if results[0].config.cal_bandwidth:
|
||||
fwd_bandwidth = [
|
||||
calculate_bandwidth(r.config, r.results["compiled"], type="fwd")
|
||||
calculate_bandwidth(r.config, r.results["flex"], type="fwd")
|
||||
for r in results
|
||||
]
|
||||
table_data["fwd_mem_bw (TB/s)"] = fwd_bandwidth
|
||||
fwd_tflops = [
|
||||
calculate_tflops(r.config, r.results["compiled"]) for r in results
|
||||
]
|
||||
fwd_tflops = [calculate_tflops(r.config, r.results["flex"]) for r in results]
|
||||
table_data["TFlops/s"] = fwd_tflops
|
||||
|
||||
print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f"))
|
||||
|
||||
for backend in results[0].config.backends:
|
||||
if np.isnan(table_data[f"fwd_{backend}_speedup"]).all():
|
||||
if np.isnan(table_data[f"fwd_speedup_flex_over_{backend}"]).all():
|
||||
continue
|
||||
print("\n")
|
||||
print(f"FWD Speedups vs. {backend}".center(125, "="))
|
||||
print(f"FWD Speedup of Flex over {backend}".center(125, "="))
|
||||
print("\n")
|
||||
average_data = get_average_speedups(results, type="fwd", backend=backend)
|
||||
print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f"))
|
||||
|
||||
if results[0].config.calculate_bwd_time:
|
||||
print("\n")
|
||||
print(f"BWD Speedups vs. {backend}".center(125, "="))
|
||||
print(f"BWD Speedup of Flex over {backend}".center(125, "="))
|
||||
print("\n")
|
||||
average_data = get_average_speedups(results, type="bwd", backend=backend)
|
||||
print(
|
||||
|
|
@ -791,14 +888,14 @@ def get_backend_context(backend: str):
|
|||
Returns a context manager for the specified backend.
|
||||
Args:
|
||||
backend (str): The name of the backend to use.
|
||||
Valid options are 'fav2', 'cudnn', 'math', 'efficient', 'fav3', 'fakv', 'og-eager'.
|
||||
Valid options are 'math', 'efficient', 'cudnn', 'fav2', 'fav3', 'fakv', 'og-eager'.
|
||||
Returns:
|
||||
A context manager for the specified backend.
|
||||
Raises:
|
||||
ValueError: If an invalid backend is specified.
|
||||
"""
|
||||
backends = {
|
||||
"fav2": nullcontext(),
|
||||
"fav2": sdpa_kernel(SDPBackend.FLASH_ATTENTION),
|
||||
"cudnn": sdpa_kernel(SDPBackend.CUDNN_ATTENTION),
|
||||
"math": sdpa_kernel(SDPBackend.MATH),
|
||||
"efficient": sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION),
|
||||
|
|
@ -820,15 +917,7 @@ def generate_FA_callable(
|
|||
) -> Callable | None:
|
||||
if dtype not in [torch.float16, torch.bfloat16]:
|
||||
return None
|
||||
if backend == "fav2":
|
||||
try:
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
except ImportError:
|
||||
print(
|
||||
"Flash attention 2 is not installed. Please install it to run fav2 backend. "
|
||||
)
|
||||
raise
|
||||
elif backend == "fav3":
|
||||
if backend == "fav3":
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import (
|
||||
flash_attn_func,
|
||||
|
|
@ -1034,6 +1123,7 @@ def generate_experiment_configs(
|
|||
kv_cache_size: list[int],
|
||||
cal_bandwidth: bool,
|
||||
backends: list[str],
|
||||
max_autotune: bool,
|
||||
) -> list[ExperimentConfig]:
|
||||
assert not (calculate_bwd and decoding), "Decoding does not support backward"
|
||||
|
||||
|
|
@ -1077,52 +1167,333 @@ def generate_experiment_configs(
|
|||
calculate_bwd_time=calculate_bwd,
|
||||
cal_bandwidth=cal_bandwidth,
|
||||
backends=backends,
|
||||
max_autotune=max_autotune,
|
||||
)
|
||||
)
|
||||
|
||||
return all_configs
|
||||
|
||||
|
||||
def main(args):
|
||||
def _output_json_for_dashboard(
|
||||
experiments,
|
||||
output_file,
|
||||
benchmark_name="PyTorch operator microbenchmark",
|
||||
):
|
||||
"""
|
||||
Write the result into JSON format for PyTorch OSS dashboard.
|
||||
The JSON format is defined at
|
||||
https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
|
||||
|
||||
Args:
|
||||
experiments: List of experiment results
|
||||
output_file: Path to output JSON file
|
||||
benchmark_name: Name of the benchmark
|
||||
"""
|
||||
if not experiments:
|
||||
return
|
||||
|
||||
import math
|
||||
import platform
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
# Prepare headers and records for JSON output
|
||||
records = []
|
||||
for experiment in experiments:
|
||||
config = experiment.config
|
||||
results_dict = (
|
||||
experiment.results
|
||||
) # This is a dict: backend -> ExperimentResults
|
||||
|
||||
# Process each backend result
|
||||
for backend, results in results_dict.items():
|
||||
# Skip backends that were not run (NaN results)
|
||||
if math.isnan(results.fwd_time):
|
||||
continue
|
||||
|
||||
# Extract data from experiment
|
||||
test_name = f"{backend}_{config.attn_type}_"
|
||||
input_config = f"shape: {config.shape}, dtype: {config.dtype}"
|
||||
|
||||
# Determine mode based on backward pass
|
||||
mode = "training" if config.calculate_bwd_time else "inference"
|
||||
|
||||
# Extract dtype
|
||||
dtype = (
|
||||
str(config.dtype).split(".")[1]
|
||||
if "." in str(config.dtype)
|
||||
else str(config.dtype)
|
||||
)
|
||||
|
||||
# Determine device
|
||||
device = "cuda"
|
||||
|
||||
# Get device architecture
|
||||
device_arch = (
|
||||
torch.cuda.get_device_name(0)
|
||||
if device == "cuda"
|
||||
else platform.processor()
|
||||
if device == "cpu"
|
||||
else "unknown"
|
||||
)
|
||||
|
||||
# Create dataclasses for JSON structure
|
||||
@dataclass
|
||||
class BenchmarkInfo:
|
||||
name: str
|
||||
mode: Optional[str]
|
||||
dtype: str
|
||||
extra_info: dict[str, Any]
|
||||
|
||||
@dataclass
|
||||
class ModelInfo:
|
||||
name: str
|
||||
type: str
|
||||
origins: list[str]
|
||||
extra_info: dict[str, Any]
|
||||
|
||||
@dataclass
|
||||
class MetricInfo:
|
||||
name: str
|
||||
unit: str
|
||||
benchmark_values: list[float]
|
||||
target_value: Optional[float]
|
||||
|
||||
@dataclass
|
||||
class BenchmarkRecord:
|
||||
benchmark: BenchmarkInfo
|
||||
model: ModelInfo
|
||||
metric: MetricInfo
|
||||
|
||||
# Benchmark extra info
|
||||
benchmark_extra_info = {
|
||||
"input_config": input_config,
|
||||
"device": device,
|
||||
"arch": device_arch,
|
||||
"operator_name": backend,
|
||||
"attn_type": config.attn_type,
|
||||
"shape": str(config.shape),
|
||||
"max_autotune": config.max_autotune,
|
||||
}
|
||||
# Add record for forward latency
|
||||
record_fwd_latency = BenchmarkRecord(
|
||||
benchmark=BenchmarkInfo(
|
||||
name=benchmark_name,
|
||||
mode=mode,
|
||||
dtype=dtype,
|
||||
extra_info=benchmark_extra_info,
|
||||
),
|
||||
model=ModelInfo(
|
||||
name=test_name + str(config.shape),
|
||||
type="attention-benchmark",
|
||||
origins=["pytorch"],
|
||||
extra_info={
|
||||
"operator_name": backend,
|
||||
"attn_type": config.attn_type,
|
||||
},
|
||||
),
|
||||
metric=MetricInfo(
|
||||
name="forward latency",
|
||||
unit="us",
|
||||
benchmark_values=[results.fwd_time],
|
||||
target_value=None,
|
||||
),
|
||||
)
|
||||
records.append(asdict(record_fwd_latency))
|
||||
|
||||
# Add record for forward memory bandwidth (if available)
|
||||
if config.cal_bandwidth:
|
||||
record_fwd_bandwidth = BenchmarkRecord(
|
||||
benchmark=BenchmarkInfo(
|
||||
name=benchmark_name,
|
||||
mode=mode,
|
||||
dtype=dtype,
|
||||
extra_info=benchmark_extra_info,
|
||||
),
|
||||
model=ModelInfo(
|
||||
name=test_name + str(config.shape),
|
||||
type="attention-benchmark",
|
||||
origins=["pytorch"],
|
||||
extra_info={
|
||||
"operator_name": backend,
|
||||
},
|
||||
),
|
||||
metric=MetricInfo(
|
||||
name="memory bandwidth",
|
||||
unit="TB/s",
|
||||
benchmark_values=[calculate_bandwidth(config, results, "fwd")],
|
||||
target_value=None,
|
||||
),
|
||||
)
|
||||
records.append(asdict(record_fwd_bandwidth))
|
||||
|
||||
# Add record for forward TFLOPS (if available)
|
||||
if config.cal_bandwidth:
|
||||
record_fwd_tflops = BenchmarkRecord(
|
||||
benchmark=BenchmarkInfo(
|
||||
name=benchmark_name,
|
||||
mode=mode,
|
||||
dtype=dtype,
|
||||
extra_info=benchmark_extra_info,
|
||||
),
|
||||
model=ModelInfo(
|
||||
name=test_name + str(config.shape),
|
||||
type="attention-benchmark",
|
||||
origins=["pytorch"],
|
||||
extra_info={
|
||||
"operator_name": backend,
|
||||
},
|
||||
),
|
||||
metric=MetricInfo(
|
||||
name="tflops",
|
||||
unit="TFLOPS/s",
|
||||
benchmark_values=[calculate_tflops(config, results)],
|
||||
target_value=None,
|
||||
),
|
||||
)
|
||||
records.append(asdict(record_fwd_tflops))
|
||||
|
||||
# Add record for backward latency (if available and not NaN)
|
||||
if (
|
||||
config.calculate_bwd_time
|
||||
and results.bwd_time is not None
|
||||
and not math.isnan(results.bwd_time)
|
||||
):
|
||||
record_bwd_latency = BenchmarkRecord(
|
||||
benchmark=BenchmarkInfo(
|
||||
name=benchmark_name,
|
||||
mode=mode,
|
||||
dtype=dtype,
|
||||
extra_info=benchmark_extra_info,
|
||||
),
|
||||
model=ModelInfo(
|
||||
name=test_name + str(config.shape),
|
||||
type="attention-benchmark",
|
||||
origins=["pytorch"],
|
||||
extra_info={
|
||||
"operator_name": backend,
|
||||
},
|
||||
),
|
||||
metric=MetricInfo(
|
||||
name="backward latency",
|
||||
unit="us",
|
||||
benchmark_values=[results.bwd_time],
|
||||
target_value=None,
|
||||
),
|
||||
)
|
||||
records.append(asdict(record_bwd_latency))
|
||||
|
||||
# Write all records to the output file
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(records, f, indent=2)
|
||||
|
||||
|
||||
def main(
|
||||
dynamic: bool = False,
|
||||
calculate_bwd: bool = False,
|
||||
dtype: DtypeString = "bfloat16",
|
||||
b: list[int] | None = None,
|
||||
nh: list[str] | None = None,
|
||||
s: list[int] | None = None,
|
||||
d: list[int] | None = None,
|
||||
mods: list[AttentionType] | None = None,
|
||||
backend: list[Backend] | None = None,
|
||||
max_autotune: bool = False,
|
||||
decoding: bool = False,
|
||||
kv_size: Optional[list[int]] = None,
|
||||
throughput: bool = True,
|
||||
save_path: Optional[str] = None,
|
||||
output_json_for_dashboard: Optional[str] = None,
|
||||
benchmark_name: str = "PyTorch operator microbenchmark",
|
||||
) -> None:
|
||||
"""Run sweep over sizes and score mods for flex attention.
|
||||
|
||||
Usage Examples:
|
||||
# Use a yml config file
|
||||
python score_mod.py --config basic_config.yaml
|
||||
|
||||
# Use a json config file
|
||||
python score_mod.py --config my_config.json
|
||||
|
||||
# Generate a config template
|
||||
python score_mod.py --print-config json > my_config.json # For a json config
|
||||
python score_mod.py --print-config yaml > my_config.yaml # For a yaml config
|
||||
|
||||
# Override config with CLI args
|
||||
python score_mod.py --config my_config.json -dtype float16 --max-autotune
|
||||
|
||||
# Pure CLI usage
|
||||
python score_mod.py -b 4 8 -s 1024 2048 -mods causal alibi --backend efficient
|
||||
|
||||
Args:
|
||||
dynamic: Runs a dynamic shapes version of compiled flex attention
|
||||
calculate_bwd: Calculate backward pass times
|
||||
dtype: Data type for tensors (bfloat16, float16, float32)
|
||||
b: Batch sizes to benchmark
|
||||
nh: Number of query and key/value heads in format "Hq,Hkv"
|
||||
s: Sequence lengths to benchmark
|
||||
d: Head dimensions to benchmark
|
||||
mods: Score modifications: noop, causal, rel, head_bias, alibi, sliding_window, document_mask, prefix_lm, softcap
|
||||
backend: Backends for attention computation: math, efficient, cudnn, fav2, fav3, fakv, og-eager
|
||||
max_autotune: Turn on max-autotune optimization
|
||||
decoding: Benchmark decoding mode (query sequence length = 1)
|
||||
kv_size: Key/value cache size in MiB (ignores batch size if specified)
|
||||
throughput: Calculate kernel memory bandwidth & computational throughput (always True)
|
||||
save_path: Path to save the results CSV file
|
||||
output_json_for_dashboard: Path to save results in JSON format for PyTorch OSS dashboard
|
||||
benchmark_name: Name of the benchmark for dashboard output
|
||||
"""
|
||||
# Convert dtype string to torch dtype (if not already converted)
|
||||
import torch
|
||||
|
||||
if isinstance(dtype, str):
|
||||
dtype = getattr(torch, dtype)
|
||||
|
||||
# Always calculate throughput
|
||||
throughput = True
|
||||
print("Backend: ", backend)
|
||||
seed = 123
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
results = []
|
||||
for config in tqdm(
|
||||
generate_experiment_configs(
|
||||
args.calculate_bwd,
|
||||
args.dtype,
|
||||
args.b,
|
||||
args.nh,
|
||||
args.s,
|
||||
args.d,
|
||||
args.mods,
|
||||
args.decoding,
|
||||
args.kv_size,
|
||||
args.throughput,
|
||||
args.backend,
|
||||
)
|
||||
for experiment_count, config in enumerate(
|
||||
tqdm(
|
||||
generate_experiment_configs(
|
||||
calculate_bwd,
|
||||
dtype,
|
||||
b,
|
||||
nh,
|
||||
s,
|
||||
d,
|
||||
mods,
|
||||
decoding,
|
||||
kv_size,
|
||||
throughput,
|
||||
backend,
|
||||
max_autotune,
|
||||
)
|
||||
),
|
||||
start=1,
|
||||
):
|
||||
results.append(
|
||||
Experiment(
|
||||
config,
|
||||
run_single_experiment(
|
||||
config,
|
||||
dynamic=args.dynamic,
|
||||
max_autotune=args.max_autotune,
|
||||
dynamic=dynamic,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
print_results(results, args.save_path)
|
||||
# Periodic memory cleanup every 50 experiments
|
||||
if experiment_count % 50 == 0:
|
||||
cleanup_memory()
|
||||
|
||||
print_results(results, save_path)
|
||||
|
||||
def heads_input_type(s):
|
||||
try:
|
||||
hq, hkv = map(int, s.split(","))
|
||||
return hq, hkv
|
||||
except Exception as e:
|
||||
raise argparse.ArgumentTypeError("Heads must be Hq,Hkv") from e
|
||||
# Output JSON for dashboard if requested
|
||||
if output_json_for_dashboard:
|
||||
_output_json_for_dashboard(results, output_json_for_dashboard, benchmark_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
@ -1130,6 +1501,12 @@ if __name__ == "__main__":
|
|||
parser = argparse.ArgumentParser(
|
||||
description="Run sweep over sizes and score mods for flex attention"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
help="Path to JSON config file. CLI args override config file values.",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dynamic",
|
||||
action="store_true",
|
||||
|
|
@ -1199,8 +1576,49 @@ Ignores -b batch size and calculate batch size from kv size instead when specifi
|
|||
default=["efficient"],
|
||||
help="Backend to use for attention computation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-json-for-dashboard",
|
||||
type=str,
|
||||
help="Path to save results in JSON format for PyTorch OSS dashboard",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--benchmark-name",
|
||||
type=str,
|
||||
help="Name of the benchmark for dashboard output",
|
||||
default="PyTorch operator microbenchmark",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print-config",
|
||||
type=str,
|
||||
choices=["json", "yaml"],
|
||||
help="Print a default config template in JSON or YAML format and exit",
|
||||
default=None,
|
||||
)
|
||||
# Parse arguments
|
||||
args = parser.parse_args()
|
||||
args.dtype = getattr(torch, args.dtype)
|
||||
|
||||
main(args)
|
||||
# Handle --print-config
|
||||
if args.print_config:
|
||||
print_default_config(args.print_config)
|
||||
sys.exit(0)
|
||||
|
||||
# Load and merge config if provided
|
||||
if args.config:
|
||||
config = load_config_file(args.config)
|
||||
|
||||
# Merge config with CLI args (CLI args take precedence)
|
||||
json_args = argparse.Namespace()
|
||||
json_args.__dict__ = config
|
||||
args = parser.parse_args(namespace=json_args)
|
||||
|
||||
# Convert dtype string to torch dtype (only if it's still a string)
|
||||
if isinstance(args.dtype, str):
|
||||
args.dtype = getattr(torch, args.dtype)
|
||||
|
||||
# Remove config and print_config from args before passing to main
|
||||
args_dict = vars(args)
|
||||
args_dict.pop("config", None)
|
||||
args_dict.pop("print_config", None)
|
||||
|
||||
main(**args_dict)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user