[GPT-fast] Support running a specific model or micro-benchmark (#143607)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143607
Approved by: https://github.com/BoyuanFeng, https://github.com/jerryzh168, https://github.com/huydhn
Yanbo Liang 2024-12-19 20:33:10 -08:00 committed by PyTorch MergeBot
parent 94737e8a2a
commit 792e6184c5
3 changed files with 52 additions and 54 deletions

File: benchmark.py

@ -4,12 +4,8 @@ import dataclasses
import json
import os
from generate import (
    get_arch_name,
    run_llama2_7b_bf16,
    run_llama2_7b_int8,
    run_mixtral_8x7b_int8,
)
from common import all_experiments, Experiment, register_experiment
from generate import get_arch_name
import torch
import torch.nn as nn
@ -22,18 +18,6 @@ WARMUP_ITER = 5
A100_40G_BF16_TFLOPS = 312
@dataclasses.dataclass
class Experiment:
    name: str
    metric: str
    target: float
    actual: float
    dtype: str
    device: str
    arch: str  # GPU name for CUDA or CPU arch for CPU
    is_model: bool = False
class SimpleMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dtype):
        super().__init__()
@ -52,6 +36,7 @@ class SimpleMLP(nn.Module):
        return x
@register_experiment(name="mlp_layer_norm_gelu")
def run_mlp_layer_norm_gelu(device: str = "cuda"):
    dtype_flops_utilization_map = {
        torch.bfloat16: "0.8",
@ -102,6 +87,7 @@ def run_mlp_layer_norm_gelu(device: str = "cuda"):
    return results
@register_experiment(name="layer_norm")
def run_layer_norm(device: str = "cuda"):
    dtype_memory_bandwidth_map = {
        torch.bfloat16: "950",
@ -145,6 +131,7 @@ def run_layer_norm(device: str = "cuda"):
    return results
@register_experiment(name="gather_gemv")
@torch._inductor.config.patch(coordinate_descent_tuning=True)
def run_gather_gemv(device: str = "cuda"):
    E = 8
@ -194,6 +181,7 @@ def run_gather_gemv(device: str = "cuda"):
    return results
@register_experiment(name="gemv")
@torch._inductor.config.patch(coordinate_descent_tuning=True)
def run_gemv(device: str = "cuda"):
    dtype_memory_bandwidth_map = {
@ -297,30 +285,20 @@ def output_json(output_file, headers, row):
DEFAULT_OUTPUT_FILE = "gpt_fast_benchmark.csv"
all_experiments = {
    # A list of GPT models: LLaMA, Mixtral, etc.
    # waiting for an A100-80G machine to be available in CI
    # https://github.com/pytorch/pytorch/actions/runs/12018005803/job/33503683582?pr=140627
    # before we can turn on autoquant,
    # or alternatively, we can save the model after autoquant and just load it here to track
    # the performance
    # run_llama2_7b_autoquant,
    run_llama2_7b_bf16,
    run_llama2_7b_int8,
    run_mixtral_8x7b_int8,
    # run_mixtral_8x7b_autoquant,
    # A list of micro-benchmarks.
    run_mlp_layer_norm_gelu,
    run_layer_norm,
    run_gather_gemv,
    run_gemv,
}
def main(output_file=DEFAULT_OUTPUT_FILE):
def main(output_file=DEFAULT_OUTPUT_FILE, only_model=None):
    results = []
    for func in all_experiments:
    if not only_model:
        experiments = all_experiments.values()
    else:
        if only_model not in all_experiments:
            print(
                f"Unknown model: {only_model}, all available models: {all_experiments.keys()}"
            )
        # only run the specified model
        experiments = [all_experiments[only_model]]
    for func in experiments:
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        except AssertionError:
@ -347,6 +325,10 @@ if __name__ == "__main__":
        default=DEFAULT_OUTPUT_FILE,
        help="Set the output CSV file to save the benchmark results",
    )
    parser.add_argument(
        "--only",
        help="Specify a model or micro-benchmark name to run exclusively",
    )
    args = parser.parse_args()
    main(output_file=args.output)
    main(output_file=args.output, only_model=args.only)
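
With this change, passing --only runs a single experiment by the name it was registered under (for example llama2_7b_bf16 or gather_gemv), while omitting the flag keeps the previous behavior of running every registered model and micro-benchmark. Assuming the harness is launched directly as benchmark.py (the invocation path is not part of this diff), usage would look roughly like:

    python benchmark.py --output gpt_fast_benchmark.csv      # run everything, as before
    python benchmark.py --only llama2_7b_bf16                # run one model
    python benchmark.py --only mlp_layer_norm_gelu           # run one micro-benchmark
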

File: common.py

@ -0,0 +1,26 @@
import dataclasses
from typing import Callable, Dict, Optional

all_experiments: Dict[str, Callable] = {}


@dataclasses.dataclass
class Experiment:
    name: str
    metric: str
    target: float
    actual: float
    dtype: str
    device: str
    arch: str  # GPU name for CUDA or CPU arch for CPU
    is_model: bool = False


def register_experiment(name: Optional[str] = None):
    def decorator(func):
        key = name or func.__name__
        all_experiments[key] = func
        return func

    return decorator
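
For reference, a minimal sketch of how the new registry is meant to be used; the experiment name and the numbers below are invented for illustration, while the decorator, the Experiment fields, and the all_experiments lookup come from this patch. Any function decorated with register_experiment is stored in all_experiments under its registered name, and that key is exactly what the new --only flag matches against in benchmark.py.

from common import Experiment, all_experiments, register_experiment


@register_experiment(name="toy_gemv")  # hypothetical micro-benchmark, not part of this PR
def run_toy_gemv(device: str = "cuda"):
    # Registered experiments return a list of Experiment records, mirroring
    # the run_* functions in generate.py and benchmark.py.
    return [
        Experiment(
            name="toy_gemv",
            metric="memory_bandwidth(GB/s)",  # illustrative metric label
            target=950.0,  # made-up target number
            actual=900.0,  # made-up measurement
            dtype="bfloat16",
            device=device,
            arch="A100",  # GPU name for CUDA or CPU arch for CPU
            is_model=False,  # micro-benchmark rather than a full model
        )
    ]


# Importing the defining module runs the decorator, so the --only lookup is a plain dict access:
assert all_experiments["toy_gemv"] is run_toy_gemv

This also explains the removal of the local "from benchmark import Experiment" lines in generate.py below: the dataclass and the registry now live in common.py, and benchmark.py only needs the keys of all_experiments to dispatch.
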

File: generate.py

@ -5,6 +5,7 @@ import time
from typing import Optional, Tuple
import torchao
from common import Experiment, register_experiment
from mixtral_moe_model import ConditionalFeedForward, Transformer as MixtralMoE
from mixtral_moe_quantize import (
    ConditionalFeedForwardInt8,
@ -295,9 +296,8 @@ def run_experiment(
# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
@register_experiment(name="llama2_7b_bf16")
def run_llama2_7b_bf16(device: str = "cuda"):
    from benchmark import Experiment
    model = GPTModelConfig(
        "Llama-2-7b-chat-hf",
        LLaMA,
@ -345,9 +345,8 @@ def run_llama2_7b_bf16(device: str = "cuda"):
# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
@register_experiment(name="llama2_7b_int8")
def run_llama2_7b_int8(device: str = "cuda"):
    from benchmark import Experiment
    model = GPTModelConfig(
        "Llama-2-7b-chat-hf",
        LLaMA,
@ -395,9 +394,8 @@ def run_llama2_7b_int8(device: str = "cuda"):
# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
@register_experiment(name="mixtral_8x7b_int8")
def run_mixtral_8x7b_int8(device: str = "cuda"):
    from benchmark import Experiment
    # We reduced the original number of layers from 32 to 16 to fit within CI memory limits.
    model = GPTModelConfig(
        "Mixtral-8x7B-v0.1",
@ -447,8 +445,6 @@ def run_mixtral_8x7b_int8(device: str = "cuda"):
# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_llama2_7b_autoquant(device: str = "cuda"):
    from benchmark import Experiment
    model = GPTModelConfig(
        "Llama-2-7b-chat-hf",
        LLaMA,
@ -497,8 +493,6 @@ def run_llama2_7b_autoquant(device: str = "cuda"):
# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_mixtral_8x7b_autoquant(device: str = "cuda"):
    from benchmark import Experiment
    # We reduced the original number of layers from 32 to 16 to fit within CI memory limits.
    model = GPTModelConfig(
        "Mixtral-8x7B-v0.1",
@ -548,8 +542,6 @@ def run_mixtral_8x7b_autoquant(device: str = "cuda"):
# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_llama2_7b_autoquant_v2(device: str = "cuda"):
    from benchmark import Experiment
    model = GPTModelConfig(
        "Llama-2-7b-chat-hf",
        LLaMA,
@ -599,8 +591,6 @@ def run_llama2_7b_autoquant_v2(device: str = "cuda"):
# token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
def run_mixtral_8x7b_autoquant_v2(device: str = "cuda"):
    from benchmark import Experiment
    # We reduced the original number of layers from 32 to 16 to fit within CI memory limits.
    model = GPTModelConfig(
        "Mixtral-8x7B-v0.1",