mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add blas compare example (#47058)
Summary: Adds a standalone script which can be used to test different BLAS libraries. Right now I've deliberately kept it limited (only a couple BLAS libs and only GEMM and GEMV). It's easy enough to expand later. CC ngimel Pull Request resolved: https://github.com/pytorch/pytorch/pull/47058 Reviewed By: zhangguanheng66 Differential Revision: D25078946 Pulled By: robieta fbshipit-source-id: b5f7f7ec289d59c16c5370b7a6636c10a496b3ac
This commit is contained in:
parent
008f840e7a
commit
678fe9f077
3
mypy.ini
3
mypy.ini
|
|
@ -137,6 +137,9 @@ ignore_errors = True
|
|||
[mypy-torch.utils.hipify.hipify_python]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.utils.benchmark.examples.*]
|
||||
ignore_errors = True
|
||||
|
||||
[mypy-torch.nn.quantized.modules.batchnorm]
|
||||
ignore_errors = True
|
||||
|
||||
|
|
|
|||
230
torch/utils/benchmark/examples/blas_compare.py
Normal file
230
torch/utils/benchmark/examples/blas_compare.py
Normal file
|
|
@ -0,0 +1,230 @@
|
|||
import argparse
|
||||
import datetime
|
||||
import itertools as it
|
||||
import multiprocessing
|
||||
import multiprocessing.dummy
|
||||
import os
|
||||
import queue
|
||||
import pickle
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
|
||||
import blas_compare_setup
|
||||
|
||||
|
||||
# --- Benchmark sizing ---------------------------------------------------------
# Passed as `min_run_time` to `Timer.blocked_autorange` for every measurement.
MIN_RUN_TIME = 1
# Number of seeds (0 .. NUM_REPLICATES - 1) run for every BLAS configuration.
NUM_REPLICATES = 20
# Thread counts to benchmark; the core pool is rebuilt for each setting.
NUM_THREAD_SETTINGS = (1, 2, 4)

# --- Output locations (all under the shared working root) ---------------------
# Pickled measurements from every subprocess are appended to this file.
RESULT_FILE = os.path.join(blas_compare_setup.WORKING_ROOT, "blas_results.pkl")
# Directory that `fill_core_pool` recreates for per-worker scratch files.
SCRATCH_DIR = os.path.join(blas_compare_setup.WORKING_ROOT, "scratch")

# --- Configurations to compare ------------------------------------------------
# Each entry is (sub_label, conda env name, extra environment variables or None).
BLAS_CONFIGS = (
    ("MKL (2020.3)", blas_compare_setup.MKL_2020_3, None),
    ("MKL (2020.0)", blas_compare_setup.MKL_2020_0, None),
    ("OpenBLAS", blas_compare_setup.OPEN_BLAS, None),
)

# --- Shared state between worker threads --------------------------------------
# Held while appending to RESULT_FILE.
_RESULT_FILE_LOCK = threading.Lock()
# Queue of (core_str, result_file, num_threads) slots available to workers.
_WORKER_POOL = queue.Queue()
|
||||
def clear_worker_pool():
    """Empty the worker pool and delete everything it left on disk.

    Every queued slot has its result file removed, then the scratch
    directory is deleted if it exists.
    """
    while True:
        if _WORKER_POOL.empty():
            break
        _core_str, scratch_file, _num_threads = _WORKER_POOL.get_nowait()
        os.remove(scratch_file)

    scratch_present = os.path.exists(SCRATCH_DIR)
    if scratch_present:
        shutil.rmtree(SCRATCH_DIR)
|
||||
|
||||
|
||||
def fill_core_pool(n: int):
    """Reset the worker pool and refill it with slots of `n` cores each.

    Each slot placed on `_WORKER_POOL` is a tuple
    `(core_str, result_file, n)` where `core_str` is the CPU list handed to
    `taskset --cpu-list` and `result_file` is an empty scratch file inside
    `SCRATCH_DIR`.

    Args:
        n: Number of cores (and threads) each worker slot should own.
    """
    clear_worker_pool()
    os.makedirs(SCRATCH_DIR)

    # Reserve two cores so that bookkeeping does not interfere with runs.
    cpu_count = multiprocessing.cpu_count() - 2

    # Adjacent cores sometimes share cache, so we space out single core runs.
    step = max(n, 2)
    for i in range(0, cpu_count, step):
        # A comma list "a,b" names exactly two CPUs; a slot of n contiguous
        # cores must be expressed as the range "a-b" so that all n are usable.
        core_str = f"{i}" if n == 1 else f"{i}-{i + n - 1}"

        # `dir=` puts the file inside the scratch directory (a path passed as
        # `prefix=` would create a sibling of that directory instead). The
        # descriptor is closed immediately; only the path is needed later.
        fd, result_file = tempfile.mkstemp(suffix=".pkl", dir=SCRATCH_DIR)
        os.close(fd)

        _WORKER_POOL.put((core_str, result_file, n))
|
||||
|
||||
|
||||
def _subprocess_main(seed=0, num_threads=1, sub_label="N/A", result_file=None, env=None):
    """Time `torch.mm` for a range of sizes and dtypes in the active conda env.

    Args:
        seed: Value passed to `torch.manual_seed` before inputs are generated.
        num_threads: Thread count given to each `Timer`.
        sub_label: Sub-label attached to every measurement.
        result_file: If not None, the list of measurements is pickled here.
        env: Path of the environment; only its final component is recorded
            on the measurements (None if empty).

    Raises:
        ValueError: If the imported `torch` does not live under `CONDA_PREFIX`.
    """
    import torch
    from torch.utils.benchmark import Timer

    conda_prefix = os.getenv("CONDA_PREFIX")
    assert conda_prefix
    torch_in_env = torch.__file__.startswith(conda_prefix)
    if not torch_in_env:
        raise ValueError(
            f"PyTorch mismatch: `import torch` resolved to `{torch.__file__}`, "
            f"which is not in the correct conda env: {conda_prefix}"
        )

    torch.manual_seed(seed)

    # Loop invariants.
    env_label = os.path.split(env or "")[1] or None
    dtypes = (("Single", torch.float32), ("Double", torch.float64))
    sizes = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 7, 96, 150, 225]

    measurements = []
    for n in sizes:
        problems = (
            # Square MatMul
            ((n, n), (n, n), "(n x n) x (n x n)", "Matrix-Matrix Product"),

            # Matrix-Vector product
            ((n, n), (n, 1), "(n x n) x (n x 1)", "Matrix-Vector Product"),
        )
        for dtype_name, dtype in dtypes:
            for x_shape, y_shape, shape_str, blas_type in problems:
                # Inputs are drawn in the order x, then y.
                x = torch.rand(x_shape, dtype=dtype)
                y = torch.rand(y_shape, dtype=dtype)
                timer = Timer(
                    stmt="torch.mm(x, y)",
                    label=f"torch.mm {shape_str} {blas_type} ({dtype_name})",
                    sub_label=sub_label,
                    description=f"n = {n}",
                    env=env_label,
                    globals={"x": x, "y": y},
                    num_threads=num_threads,
                )
                measurements.append(timer.blocked_autorange(min_run_time=MIN_RUN_TIME))

    if result_file is None:
        return

    with open(result_file, "wb") as f:
        pickle.dump(measurements, f)
|
||||
|
||||
|
||||
def run_subprocess(args):
    """Run one benchmark trial in a pinned subprocess and collect its results.

    A worker slot `(core_str, result_file, num_threads)` is taken from
    `_WORKER_POOL`, this script is re-invoked under `taskset` inside the
    requested conda env, and whatever the child wrote to the slot's result
    file is appended to `RESULT_FILE`. The slot is always returned to the
    pool afterwards.

    Args:
        args: Tuple `(seed, env, sub_label, extra_env_vars)`; `extra_env_vars`
            may be None or a dict merged over the default environment.
    """
    import shlex

    seed, env, sub_label, extra_env_vars = args
    core_str = None
    try:
        core_str, result_file, num_threads = _WORKER_POOL.get()

        # Truncate so that stale output from an earlier trial is never reused.
        with open(result_file, "wb"):
            pass

        env_vars = {
            "PATH": os.getenv("PATH"),
            "PYTHONPATH": os.getenv("PYTHONPATH") or "",

            # NumPy
            "OMP_NUM_THREADS": str(num_threads),
            "MKL_NUM_THREADS": str(num_threads),
            "NUMEXPR_NUM_THREADS": str(num_threads),
        }
        env_vars.update(extra_env_vars or {})

        # Everything interpolated into the shell string is quoted so that
        # paths or labels containing spaces/metacharacters stay one argument.
        proc = subprocess.run(
            f"source activate {shlex.quote(env)} && "
            f"taskset --cpu-list {core_str} "
            f"python {shlex.quote(os.path.abspath(__file__))} "
            "--DETAIL_in_subprocess "
            f"--DETAIL_seed {seed} "
            f"--DETAIL_num_threads {num_threads} "
            f"--DETAIL_sub_label {shlex.quote(sub_label)} "
            f"--DETAIL_result_file {shlex.quote(result_file)} "
            f"--DETAIL_env {shlex.quote(env)}",
            env=env_vars,
            stdout=subprocess.PIPE,
            shell=True
        )

        if proc.returncode:
            # Best effort: keep going, but do not let the trial vanish silently.
            print(
                f"\nWARNING: subprocess for {sub_label!r} (seed={seed}, "
                f"cores={core_str}) exited with code {proc.returncode}; "
                "its results may be missing.",
                file=sys.stderr,
            )

        with open(result_file, "rb") as f:
            result_bytes = f.read()

        with _RESULT_FILE_LOCK, \
                open(RESULT_FILE, "ab") as f:
            f.write(result_bytes)

    except KeyboardInterrupt:
        pass  # Handle ctrl-c gracefully.

    finally:
        # Only hand the slot back if one was actually acquired.
        if core_str is not None:
            _WORKER_POOL.put((core_str, result_file, num_threads))
|
||||
|
||||
|
||||
def _compare_main():
    """Load every pickled batch from `RESULT_FILE` and print a comparison table.

    The result file is a concatenation of pickles, so batches are read one
    after another until the end of the file is reached.
    """
    measurements = []
    with open(RESULT_FILE, "rb") as f:
        while True:
            try:
                batch = pickle.load(f)
            except EOFError:
                break
            measurements.extend(batch)

    from torch.utils.benchmark import Compare

    table = Compare(measurements)
    table.trim_significant_figures()
    table.colorize()
    table.print()
|
||||
|
||||
|
||||
def main():
    """Run every trial for every thread setting, then print the comparison.

    The result file is truncated first. For each entry of
    `NUM_THREAD_SETTINGS` the core pool is rebuilt and all
    (seed, BLAS config) trials are dispatched through a thread pool while a
    progress line is printed. Finally the script re-invokes itself in compare
    mode inside one of the conda environments.
    """
    with open(RESULT_FILE, "wb"):
        pass

    for num_threads in NUM_THREAD_SETTINGS:
        fill_core_pool(num_threads)
        workers = _WORKER_POOL.qsize()

        trials = [
            (
                seed,
                os.path.join(blas_compare_setup.WORKING_ROOT, env),
                sub_label,
                extra_env_vars,
            )
            for seed in range(NUM_REPLICATES)
            for sub_label, env, extra_env_vars in BLAS_CONFIGS
        ]
        n = len(trials)

        with multiprocessing.dummy.Pool(workers) as pool:
            start_time = time.time()
            finished = pool.imap(run_subprocess, trials)
            for n_trials_done, _ in enumerate(finished, start=1):
                elapsed = time.time() - start_time
                time_per_result = elapsed / n_trials_done
                eta = int((n - n_trials_done) * time_per_result)
                progress = f"\r{n_trials_done} / {n} ETA:{datetime.timedelta(seconds=eta)}"
                print(progress.ljust(80), end="")
                sys.stdout.flush()

        total = datetime.timedelta(seconds=int(time.time() - start_time))
        print(f"\r{n} / {n} Total time: {total}")
    print()

    # Any env will do, it just needs to have torch for benchmark utils.
    first_env = BLAS_CONFIGS[0][1]
    env_path = os.path.join(blas_compare_setup.WORKING_ROOT, first_env)
    compare_cmd = (
        f"source activate {env_path} && "
        f"python {os.path.abspath(__file__)} "
        "--DETAIL_in_compare"
    )
    subprocess.run(compare_cmd, shell=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # The --DETAIL_* flags exist so this script can re-invoke itself; they
    # select the subprocess/compare modes and do not tune the main loop.
    parser = argparse.ArgumentParser()
    for mode_flag in ("--DETAIL_in_subprocess", "--DETAIL_in_compare"):
        parser.add_argument(mode_flag, action="store_true")
    for int_flag in ("--DETAIL_seed", "--DETAIL_num_threads"):
        parser.add_argument(int_flag, type=int, default=None)
    parser.add_argument("--DETAIL_sub_label", type=str, default="N/A")
    for str_flag in ("--DETAIL_result_file", "--DETAIL_env"):
        parser.add_argument(str_flag, type=str, default=None)
    args = parser.parse_args()

    if args.DETAIL_in_subprocess:
        # Benchmark mode: run the measurements inside the activated env.
        try:
            _subprocess_main(
                seed=args.DETAIL_seed,
                num_threads=args.DETAIL_num_threads,
                sub_label=args.DETAIL_sub_label,
                result_file=args.DETAIL_result_file,
                env=args.DETAIL_env,
            )
        except KeyboardInterrupt:
            pass  # Handle ctrl-c gracefully.
    elif args.DETAIL_in_compare:
        # Compare mode: print the table from the accumulated results.
        _compare_main()
    else:
        # Default: orchestrate all trials.
        main()
|
||||
221
torch/utils/benchmark/examples/blas_compare_setup.py
Normal file
221
torch/utils/benchmark/examples/blas_compare_setup.py
Normal file
|
|
@ -0,0 +1,221 @@
|
|||
import collections
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
try:
|
||||
import conda.cli.python_api
|
||||
from conda.cli.python_api import Commands as conda_commands
|
||||
except ImportError:
|
||||
# blas_compare.py will fail to import these when it's inside a conda env,
|
||||
# but that's fine as it only wants the constants.
|
||||
pass
|
||||
|
||||
|
||||
# Root directory under which every conda environment is created.
WORKING_ROOT = "/tmp/pytorch_blas_compare_environments"

# Environment names (also used as directory names under WORKING_ROOT).
MKL_2020_3 = "mkl_2020_3"
MKL_2020_0 = "mkl_2020_0"
OPEN_BLAS = "open_blas"
EIGEN = "eigen"

# Environment variables applied to every environment, as "NAME=value" strings.
GENERIC_ENV_VARS = (
    "USE_CUDA=0",
    "USE_ROCM=0",
)

# Packages installed into every environment before any BLAS-specific ones.
BASE_PKG_DEPS = (
    "cffi",
    "cmake",
    "hypothesis",
    "ninja",
    "numpy",
    "pyyaml",
    "setuptools",
    "typing_extensions",
)
|
||||
|
||||
|
||||
# Description of one conda environment to build and validate.
SubEnvSpec = collections.namedtuple(
    "SubEnvSpec",
    [
        # Packages installed together with BASE_PKG_DEPS.
        "generic_installs",
        # Either empty, or a (channel, packages) pair installed with `-c channel`.
        "special_installs",
        # "NAME=value" strings set on the environment.
        "environment_variables",

        # Validate install.
        "expected_blas_symbols",
        "expected_mkl_version",
    ],
)
|
||||
|
||||
|
||||
# Environments to build, keyed by environment name, in build order.
SUB_ENVS = {}

SUB_ENVS[MKL_2020_3] = SubEnvSpec(
    generic_installs=(),
    special_installs=("intel", ("mkl=2020.3", "mkl-include=2020.3")),
    environment_variables=("BLAS=MKL",) + GENERIC_ENV_VARS,
    expected_blas_symbols=("mkl_blas_sgemm",),
    expected_mkl_version="2020.0.3",
)

SUB_ENVS[MKL_2020_0] = SubEnvSpec(
    generic_installs=(),
    special_installs=("intel", ("mkl=2020.0", "mkl-include=2020.0")),
    environment_variables=("BLAS=MKL",) + GENERIC_ENV_VARS,
    expected_blas_symbols=("mkl_blas_sgemm",),
    expected_mkl_version="2020.0.0",
)

SUB_ENVS[OPEN_BLAS] = SubEnvSpec(
    generic_installs=("openblas",),
    special_installs=(),
    environment_variables=("BLAS=OpenBLAS",) + GENERIC_ENV_VARS,
    expected_blas_symbols=("exec_blas",),
    expected_mkl_version=None,
)

# NOTE: an Eigen environment is not enabled. A spec for it would look like:
#   SUB_ENVS[EIGEN] = SubEnvSpec(
#       generic_installs=(),
#       special_installs=(),
#       environment_variables=("BLAS=Eigen",) + GENERIC_ENV_VARS,
#       expected_blas_symbols=(),
#       expected_mkl_version=None,
#   )
|
||||
|
||||
|
||||
def conda_run(*args):
    """Run a conda command through the Python API and return its stdout.

    Args:
        *args: Forwarded unchanged to `conda.cli.python_api.run_command`.

    Raises:
        OSError: If the command reports a non-zero return code.
    """
    outcome = conda.cli.python_api.run_command(*args)
    stdout, stderr, retcode = outcome
    if not retcode:
        return stdout

    raise OSError(f"conda error: {str(args)} retcode: {retcode}\n{stderr}")
|
||||
|
||||
|
||||
def _run_shell(command, failure_header):
    """Run `command` through the shell with captured output.

    Returns the completed process; raises OSError (with the captured stdout
    and stderr under `failure_header`) if the exit code is non-zero.
    """
    proc = subprocess.run(
        command,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if proc.returncode:
        raise OSError(
            f"{failure_header}\n"
            f" stdout: {proc.stdout.decode('utf-8')}\n"
            f" stderr: {proc.stderr.decode('utf-8')}"
        )
    return proc


def main():
    """Create, populate, build and validate one conda env per `SUB_ENVS` entry.

    The working root is wiped and recreated. For each environment: the env is
    created, activation is tested, packages are installed, environment
    variables are set and verified, PyTorch is built from the enclosing git
    checkout, and the resulting build is checked for the expected BLAS.

    Raises:
        OSError: If a conda command, activation, variable setup, the PyTorch
            build, or the configuration check exits with a non-zero code.
        AssertionError: If the environment or build does not match the spec.
    """
    if os.path.exists(WORKING_ROOT):
        print("Cleaning: removing old working root.")
        shutil.rmtree(WORKING_ROOT)
    os.makedirs(WORKING_ROOT)

    git_root = subprocess.check_output(
        "git rev-parse --show-toplevel",
        shell=True,
        cwd=os.path.dirname(os.path.realpath(__file__))
    ).decode("utf-8").strip()

    for env_name, env_spec in SUB_ENVS.items():
        env_path = os.path.join(WORKING_ROOT, env_name)
        print(f"Creating env: {env_name}: ({env_path})")
        conda_run(
            conda_commands.CREATE,
            "--no-default-packages",
            "--prefix", env_path,
            "python=3",
        )

        print("Testing that env can be activated:")
        _run_shell(
            f"source activate {env_path}",
            "Failed to source base environment:",
        )

        print("Installing packages:")
        conda_run(
            conda_commands.INSTALL,
            "--prefix", env_path,
            *(BASE_PKG_DEPS + env_spec.generic_installs)
        )

        if env_spec.special_installs:
            channel, channel_deps = env_spec.special_installs
            print(f"Installing packages from channel: {channel}")
            conda_run(
                conda_commands.INSTALL,
                "--prefix", env_path,
                "-c", channel, *channel_deps
            )

        if env_spec.environment_variables:
            print("Setting environment variables.")

            # This does not appear to be possible using the python API.
            _run_shell(
                f"source activate {env_path} && "
                f"conda env config vars set {' '.join(env_spec.environment_variables)}",
                "Failed to set environment variables:",
            )

            # Check that they were actually set correctly.
            actual_env_vars = subprocess.run(
                f"source activate {env_path} && env",
                shell=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            ).stdout.decode("utf-8").strip().splitlines()
            for e in env_spec.environment_variables:
                assert e in actual_env_vars, f"{e} not in envs"

        print(f"Building PyTorch for env: `{env_name}`")
        # We have to re-run during each build to pick up the new
        # build config settings.
        # A failed build is reported here rather than surfacing later as a
        # confusing failure of the configuration check.
        _run_shell(
            f"source activate {env_path} && "
            f"cd {git_root} && "
            "python setup.py install --cmake",
            "Failed to build PyTorch:",
        )

        print("Checking configuration:")
        check_run = _run_shell(
            # Shameless abuse of `python -c ...`
            f"source activate {env_path} && "
            "python -c \""
            "import torch;"
            "from torch.utils.benchmark import Timer;"
            "print(torch.__config__.show());"
            "setup = 'x=torch.ones((128, 128));y=torch.ones((128, 128))';"
            "counts = Timer('torch.mm(x, y)', setup).collect_callgrind(collect_baseline=False);"
            "stats = counts.as_standardized().stats(inclusive=True);"
            "print(stats.filter(lambda l: 'blas' in l.lower()))\"",
            "Failed to check configuration:",
        )
        check_run_stdout = check_run.stdout.decode('utf-8')
        print(check_run_stdout)

        for e in env_spec.environment_variables:
            if "BLAS" in e:
                assert e in check_run_stdout, f"PyTorch build did not respect `BLAS=...`: {e}"

        for s in env_spec.expected_blas_symbols:
            assert s in check_run_stdout, f"Expected BLAS symbol not found: {s}"

        if env_spec.expected_mkl_version is not None:
            assert f"- Intel(R) Math Kernel Library Version {env_spec.expected_mkl_version}" in check_run_stdout

        print(f"Build complete: {env_name}")
|
||||
|
||||
|
||||
# Build the environments only when executed as a script; importing this
# module (as blas_compare.py does) just exposes the constants.
if __name__ == "__main__":
    main()
|
||||
Loading…
Reference in New Issue
Block a user