[GPT-fast] Support running a specific model or micro-benchmark (#143607)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143607
Approved by: https://github.com/BoyuanFeng, https://github.com/jerryzh168, https://github.com/huydhn
This commit is contained in:
parent 94737e8a2a
commit 792e6184c5
benchmarks/gpt_fast/benchmark.py
@@ -4,12 +4,8 @@ import dataclasses
 import json
 import os
 
-from generate import (
-    get_arch_name,
-    run_llama2_7b_bf16,
-    run_llama2_7b_int8,
-    run_mixtral_8x7b_int8,
-)
+from common import all_experiments, Experiment, register_experiment
+from generate import get_arch_name
 
 import torch
 import torch.nn as nn
@@ -22,18 +18,6 @@ WARMUP_ITER = 5
 A100_40G_BF16_TFLOPS = 312
 
 
-@dataclasses.dataclass
-class Experiment:
-    name: str
-    metric: str
-    target: float
-    actual: float
-    dtype: str
-    device: str
-    arch: str  # GPU name for CUDA or CPU arch for CPU
-    is_model: bool = False
-
-
 class SimpleMLP(nn.Module):
     def __init__(self, input_dim, hidden_dim, output_dim, dtype):
         super().__init__()
@@ -52,6 +36,7 @@ class SimpleMLP(nn.Module):
         return x
 
 
+@register_experiment(name="mlp_layer_norm_gelu")
 def run_mlp_layer_norm_gelu(device: str = "cuda"):
     dtype_flops_utilization_map = {
         torch.bfloat16: "0.8",
@@ -102,6 +87,7 @@ def run_mlp_layer_norm_gelu(device: str = "cuda"):
     return results
 
 
+@register_experiment(name="layer_norm")
 def run_layer_norm(device: str = "cuda"):
     dtype_memory_bandwidth_map = {
         torch.bfloat16: "950",
@@ -145,6 +131,7 @@ def run_layer_norm(device: str = "cuda"):
     return results
 
 
+@register_experiment(name="gather_gemv")
 @torch._inductor.config.patch(coordinate_descent_tuning=True)
 def run_gather_gemv(device: str = "cuda"):
     E = 8
@@ -194,6 +181,7 @@ def run_gather_gemv(device: str = "cuda"):
     return results
 
 
+@register_experiment(name="gemv")
 @torch._inductor.config.patch(coordinate_descent_tuning=True)
 def run_gemv(device: str = "cuda"):
     dtype_memory_bandwidth_map = {
@@ -297,30 +285,20 @@ def output_json(output_file, headers, row):
 
 DEFAULT_OUTPUT_FILE = "gpt_fast_benchmark.csv"
 
-all_experiments = {
-    # A list of GPT models: LlaMa, Mixtral, etc.
-    # waiting for A100-80G machine to be available in CI
-    # https://github.com/pytorch/pytorch/actions/runs/12018005803/job/33503683582?pr=140627
-    # before we can turn on autoquant
-    # or alterantively, we can save the model after autoquant and just load here to track
-    # the performance
-    # run_llama2_7b_autoquant,
-    run_llama2_7b_bf16,
-    run_llama2_7b_int8,
-    run_mixtral_8x7b_int8,
-    # run_mixtral_8x7b_autoquant,
-    # A list of micro-benchmarks.
-    run_mlp_layer_norm_gelu,
-    run_layer_norm,
-    run_gather_gemv,
-    run_gemv,
-}
-
 
-def main(output_file=DEFAULT_OUTPUT_FILE):
+def main(output_file=DEFAULT_OUTPUT_FILE, only_model=None):
     results = []
 
-    for func in all_experiments:
+    if not only_model:
+        experiments = all_experiments.values()
+    else:
+        if only_model not in all_experiments:
+            print(
+                f"Unknown model: {only_model}, all available models: {all_experiments.keys()}"
+            )
+        # only run the specified model
+        experiments = [all_experiments[only_model]]
+    for func in experiments:
         try:
             device = "cuda" if torch.cuda.is_available() else "cpu"
         except AssertionError:
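Editor's note: the only_model handling above reduces to a name lookup against the common.all_experiments registry. A minimal standalone sketch of that selection step follows; select_experiments is a hypothetical helper, and unlike the code above it exits on an unknown name instead of printing and then indexing the registry.

from typing import Callable, Dict, List, Optional


def select_experiments(
    registry: Dict[str, Callable], only_model: Optional[str] = None
) -> List[Callable]:
    # No name given: run every registered model and micro-benchmark.
    if not only_model:
        return list(registry.values())
    # Unknown name: report the available registry keys and stop.
    if only_model not in registry:
        raise SystemExit(
            f"Unknown model: {only_model}, all available models: {list(registry.keys())}"
        )
    # Only run the specified model or micro-benchmark.
    return [registry[only_model]]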
@@ -347,6 +325,10 @@ if __name__ == "__main__":
         default=DEFAULT_OUTPUT_FILE,
         help="Set the output CSV file to save the benchmark results",
     )
+    parser.add_argument(
+        "--only",
+        help="Specify a model or micro-benchmark name to run exclusively",
+    )
     args = parser.parse_args()
 
-    main(output_file=args.output)
+    main(output_file=args.output, only_model=args.only)
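Editor's note: with the flag wired into argparse, invocation might look like the following. The script path and the spelling --output for the existing CSV flag are assumptions inferred from this commit (args.output is what main() consumes); the names passed to --only are the registered experiment names shown in the hunks above.

# Run every registered experiment (unchanged default behavior).
python benchmarks/gpt_fast/benchmark.py

# Run a single micro-benchmark by its registered name.
python benchmarks/gpt_fast/benchmark.py --only mlp_layer_norm_gelu

# Run a single model experiment and write results to a custom CSV.
python benchmarks/gpt_fast/benchmark.py --only llama2_7b_bf16 --output /tmp/gpt_fast_benchmark.csv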
benchmarks/gpt_fast/common.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+import dataclasses
+from typing import Callable, Dict, Optional
+
+
+all_experiments: Dict[str, Callable] = {}
+
+
+@dataclasses.dataclass
+class Experiment:
+    name: str
+    metric: str
+    target: float
+    actual: float
+    dtype: str
+    device: str
+    arch: str  # GPU name for CUDA or CPU arch for CPU
+    is_model: bool = False
+
+
+def register_experiment(name: Optional[str] = None):
+    def decorator(func):
+        key = name or func.__name__
+        all_experiments[key] = func
+        return func
+
+    return decorator
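Editor's note: to illustrate how the new registry is meant to be used, here is a hypothetical micro-benchmark registered through the decorator above. The name "my_gemm", the toy workload, the metric string, and the placeholder target/actual values are illustrative only; real experiments in this suite time the workload and report measured numbers. It assumes, as the existing run_* functions appear to do, that an experiment returns a list of Experiment records, and it reuses get_arch_name from the existing generate module.

import torch

from common import Experiment, register_experiment
from generate import get_arch_name  # provided by the existing benchmark code


@register_experiment(name="my_gemm")
def run_my_gemm(device: str = "cuda"):
    # Toy workload; a real experiment would time this and derive a metric.
    a = torch.randn(1024, 1024, device=device)
    b = torch.randn(1024, 1024, device=device)
    torch.mm(a, b)
    return [
        Experiment(
            name="my_gemm",
            metric="memory_bandwidth(GB/s)",  # illustrative metric label
            target=0.0,  # placeholder target
            actual=0.0,  # placeholder measurement
            dtype=str(a.dtype),
            device=device,
            arch=get_arch_name(),
            is_model=False,
        )
    ]

Once this module is imported, "my_gemm" appears in common.all_experiments and can be selected with --only my_gemm.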
benchmarks/gpt_fast/generate.py
@@ -5,6 +5,7 @@ import time
 from typing import Optional, Tuple
 
 import torchao
+from common import Experiment, register_experiment
 from mixtral_moe_model import ConditionalFeedForward, Transformer as MixtralMoE
 from mixtral_moe_quantize import (
     ConditionalFeedForwardInt8,
@@ -295,9 +296,8 @@ def run_experiment(
 
 
 # token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
+@register_experiment(name="llama2_7b_bf16")
 def run_llama2_7b_bf16(device: str = "cuda"):
-    from benchmark import Experiment
-
     model = GPTModelConfig(
         "Llama-2-7b-chat-hf",
         LLaMA,
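Editor's note: with these @register_experiment decorators on the model runners, registration happens as a side effect of importing generate (which benchmark.py still does via "from generate import get_arch_name"), and the registry key is the registered name rather than the function name. A small illustrative check, assuming both modules are importable:

from common import all_experiments
from generate import run_llama2_7b_bf16  # importing generate runs the decorators

assert "llama2_7b_bf16" in all_experiments
assert all_experiments["llama2_7b_bf16"] is run_llama2_7b_bf16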
@@ -345,9 +345,8 @@ def run_llama2_7b_bf16(device: str = "cuda"):
 
 
 # token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
+@register_experiment(name="llama2_7b_int8")
 def run_llama2_7b_int8(device: str = "cuda"):
-    from benchmark import Experiment
-
     model = GPTModelConfig(
         "Llama-2-7b-chat-hf",
         LLaMA,
@@ -395,9 +394,8 @@ def run_llama2_7b_int8(device: str = "cuda"):
 
 
 # token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
+@register_experiment(name="mixtral_8x7b_int8")
 def run_mixtral_8x7b_int8(device: str = "cuda"):
-    from benchmark import Experiment
-
     # We reduced the original number of layers from 32 to 16 to adapt CI memory limitation.
     model = GPTModelConfig(
         "Mixtral-8x7B-v0.1",
@@ -447,8 +445,6 @@ def run_mixtral_8x7b_int8(device: str = "cuda"):
 
 # token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
 def run_llama2_7b_autoquant(device: str = "cuda"):
-    from benchmark import Experiment
-
     model = GPTModelConfig(
         "Llama-2-7b-chat-hf",
         LLaMA,
@@ -497,8 +493,6 @@ def run_llama2_7b_autoquant(device: str = "cuda"):
 
 # token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
 def run_mixtral_8x7b_autoquant(device: str = "cuda"):
-    from benchmark import Experiment
-
     # We reduced the original number of layers from 32 to 16 to adapt CI memory limitation.
     model = GPTModelConfig(
         "Mixtral-8x7B-v0.1",
@@ -548,8 +542,6 @@ def run_mixtral_8x7b_autoquant(device: str = "cuda"):
 
 # token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
 def run_llama2_7b_autoquant_v2(device: str = "cuda"):
-    from benchmark import Experiment
-
     model = GPTModelConfig(
         "Llama-2-7b-chat-hf",
         LLaMA,
@@ -599,8 +591,6 @@ def run_llama2_7b_autoquant_v2(device: str = "cuda"):
 
 # token_per_sec and memory_bandwidth target numbers are for A100-40GB, which are different from the typical A100-80GB.
 def run_mixtral_8x7b_autoquant_v2(device: str = "cuda"):
-    from benchmark import Experiment
-
     # We reduced the original number of layers from 32 to 16 to adapt CI memory limitation.
     model = GPTModelConfig(
         "Mixtral-8x7B-v0.1",
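Editor's note: the autoquant runners above are left unregistered in this commit; per the comments removed from benchmark.py, they are waiting on A100-80G CI capacity before they can be enabled. If they are enabled later, registration would presumably follow the same pattern; a hypothetical after-the-fact registration, not part of this commit, could look like:

from common import register_experiment
from generate import run_llama2_7b_autoquant

# register_experiment(name=...) returns a decorator, so it can also be applied directly.
register_experiment(name="llama2_7b_autoquant")(run_llama2_7b_autoquant)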