mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add blas compare example (#47058)
Summary: Adds a standalone script which can be used to test different BLAS libraries. Right now I've deliberately kept it limited (only a couple of BLAS libraries, and only GEMM and GEMV); it's easy enough to expand later.
CC ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47058
Reviewed By: zhangguanheng66
Differential Revision: D25078946
Pulled By: robieta
fbshipit-source-id: b5f7f7ec289d59c16c5370b7a6636c10a496b3ac
This commit is contained in:
parent
008f840e7a
commit
678fe9f077
3
mypy.ini
3
mypy.ini
|
|
@ -137,6 +137,9 @@ ignore_errors = True
|
||||||
[mypy-torch.utils.hipify.hipify_python]
|
[mypy-torch.utils.hipify.hipify_python]
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
|
[mypy-torch.utils.benchmark.examples.*]
|
||||||
|
ignore_errors = True
|
||||||
|
|
||||||
[mypy-torch.nn.quantized.modules.batchnorm]
|
[mypy-torch.nn.quantized.modules.batchnorm]
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
|
|
|
||||||
230
torch/utils/benchmark/examples/blas_compare.py
Normal file
230
torch/utils/benchmark/examples/blas_compare.py
Normal file
|
|
@ -0,0 +1,230 @@
|
||||||
|
import argparse
|
||||||
|
import datetime
|
||||||
|
import itertools as it
|
||||||
|
import multiprocessing
|
||||||
|
import multiprocessing.dummy
|
||||||
|
import os
|
||||||
|
import queue
|
||||||
|
import pickle
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
import blas_compare_setup
|
||||||
|
|
||||||
|
|
||||||
|
# Passed as `min_run_time` to `blocked_autorange` for every measurement.
MIN_RUN_TIME = 1
# Number of seeds (replicates) scheduled for each BLAS configuration.
NUM_REPLICATES = 20
# Thread counts to sweep; the worker pool is rebuilt for each setting.
NUM_THREAD_SETTINGS = (1, 2, 4)

# SCRATCH_DIR is the scratch location for per-worker temporary files;
# RESULT_FILE accumulates the pickled measurements of all workers.
SCRATCH_DIR = os.path.join(blas_compare_setup.WORKING_ROOT, "scratch")
RESULT_FILE = os.path.join(blas_compare_setup.WORKING_ROOT, "blas_results.pkl")

# Each entry is (sub_label, conda env name, extra environment variables).
BLAS_CONFIGS = (
    ("MKL (2020.3)", blas_compare_setup.MKL_2020_3, None),
    ("MKL (2020.0)", blas_compare_setup.MKL_2020_0, None),
    ("OpenBLAS", blas_compare_setup.OPEN_BLAS, None),
)

# Guards appends to RESULT_FILE made from the worker threads.
_RESULT_FILE_LOCK = threading.Lock()
# Queue of idle workers, stored as (core_str, result_file, num_threads).
_WORKER_POOL = queue.Queue()
|
||||||
|
def clear_worker_pool():
    """Drain the worker pool and delete its on-disk scratch state.

    The result file of every queued worker is removed, after which the
    scratch directory is deleted recursively if it exists.
    """
    while True:
        try:
            worker = _WORKER_POOL.get_nowait()
        except queue.Empty:
            break
        _core_str, scratch_file, _num_threads = worker
        os.remove(scratch_file)

    if os.path.exists(SCRATCH_DIR):
        shutil.rmtree(SCRATCH_DIR)
|
||||||
|
|
||||||
|
|
||||||
|
def fill_core_pool(n: int):
    """Rebuild the worker pool for runs that use `n` threads each.

    Any existing pool and scratch directory are cleared first. Every worker
    put into the pool is a `(core_str, result_file, n)` tuple, where
    `core_str` is a `taskset --cpu-list` specification and `result_file` is
    the path of an empty temporary `.pkl` file inside SCRATCH_DIR.

    Args:
        n: Number of threads (and therefore cores) assigned to each worker.
    """
    clear_worker_pool()
    os.makedirs(SCRATCH_DIR)

    # Reserve two cores so that bookkeeping does not interfere with runs.
    cpu_count = multiprocessing.cpu_count() - 2

    # Adjacent cores sometimes share cache, so we space out single core runs.
    step = max(n, 2)
    for i in range(0, cpu_count, step):
        # `taskset --cpu-list` reads "a,b" as exactly two CPUs; a dash is
        # required to pin a worker to the whole contiguous block of `n` cores.
        core_str = f"{i}" if n == 1 else f"{i}-{i + n - 1}"

        # `prefix` only affects the file name; `dir` is what actually places
        # the temporary file inside the scratch directory.
        fd, result_file = tempfile.mkstemp(suffix=".pkl", dir=SCRATCH_DIR)
        os.close(fd)  # Only the path is needed; do not leak the descriptor.
        _WORKER_POOL.put((core_str, result_file, n))
|
||||||
|
|
||||||
|
|
||||||
|
def _subprocess_main(seed=0, num_threads=1, sub_label="N/A", result_file=None, env=None):
    """Benchmark `torch.mm` over a sweep of sizes, shapes and dtypes.

    For every size `n`, a square matrix-matrix product and a matrix-vector
    product are timed in both float32 and float64 with `blocked_autorange`.

    Args:
        seed: Value passed to `torch.manual_seed` before inputs are drawn.
        num_threads: Thread count given to each `Timer`.
        sub_label: Sub-label attached to each measurement.
        result_file: If not None, the collected results are pickled here.
        env: Environment path; only its last component is recorded.

    Raises:
        ValueError: If the imported `torch` does not live under CONDA_PREFIX.
    """
    import torch
    from torch.utils.benchmark import Timer

    conda_prefix = os.getenv("CONDA_PREFIX")
    assert conda_prefix
    torch_path = torch.__file__
    if not torch_path.startswith(conda_prefix):
        raise ValueError(
            f"PyTorch mismatch: `import torch` resolved to `{torch_path}`, "
            f"which is not in the correct conda env: {conda_prefix}"
        )

    torch.manual_seed(seed)

    env_label = os.path.split(env or "")[1] or None
    dtypes = (("Single", torch.float32), ("Double", torch.float64))
    sizes = (4, 8, 16, 32, 64, 128, 256, 512, 1024, 7, 96, 150, 225)

    results = []
    for n in sizes:
        shapes = (
            # Square MatMul
            ((n, n), (n, n), "(n x n) x (n x n)", "Matrix-Matrix Product"),

            # Matrix-Vector product
            ((n, n), (n, 1), "(n x n) x (n x 1)", "Matrix-Vector Product"),
        )
        # dtype is the outer loop and shape the inner one, matching the
        # order in which random inputs are drawn from the seeded generator.
        for dtype_name, dtype in dtypes:
            for x_shape, y_shape, shape_str, blas_type in shapes:
                x = torch.rand(x_shape, dtype=dtype)
                y = torch.rand(y_shape, dtype=dtype)
                timer = Timer(
                    stmt="torch.mm(x, y)",
                    label=f"torch.mm {shape_str} {blas_type} ({dtype_name})",
                    sub_label=sub_label,
                    description=f"n = {n}",
                    env=env_label,
                    globals={"x": x, "y": y},
                    num_threads=num_threads,
                )
                results.append(timer.blocked_autorange(min_run_time=MIN_RUN_TIME))

    if result_file is None:
        return

    with open(result_file, "wb") as f:
        pickle.dump(results, f)
|
||||||
|
|
||||||
|
|
||||||
|
def run_subprocess(args):
    """Run one benchmark trial in a pinned subprocess and store its output.

    A worker (core list, scratch file, thread count) is borrowed from the
    pool, this script is re-invoked under `taskset` inside the requested
    conda env, and whatever the child wrote to the scratch file is appended
    to RESULT_FILE. The worker is always returned to the pool afterwards.

    Args:
        args: Tuple of `(seed, env, sub_label, extra_env_vars)`.
    """
    seed, env, sub_label, extra_env_vars = args
    core_str = None
    try:
        core_str, result_file, num_threads = _WORKER_POOL.get()

        # Truncate the scratch file so stale results are never re-read.
        open(result_file, "wb").close()

        env_vars = {
            "PATH": os.getenv("PATH"),
            "PYTHONPATH": os.getenv("PYTHONPATH") or "",
        }
        # NumPy
        for thread_var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS"):
            env_vars[thread_var] = str(num_threads)
        env_vars.update(extra_env_vars or {})

        command = " ".join([
            f"source activate {env} &&",
            f"taskset --cpu-list {core_str}",
            f"python {os.path.abspath(__file__)}",
            "--DETAIL_in_subprocess",
            f"--DETAIL_seed {seed}",
            f"--DETAIL_num_threads {num_threads}",
            f"--DETAIL_sub_label '{sub_label}'",
            f"--DETAIL_result_file {result_file}",
            f"--DETAIL_env {env}",
        ])
        subprocess.run(command, env=env_vars, stdout=subprocess.PIPE, shell=True)

        with open(result_file, "rb") as scratch:
            result_bytes = scratch.read()

        # Appends come from several threads, so they are serialized.
        with _RESULT_FILE_LOCK:
            with open(RESULT_FILE, "ab") as combined:
                combined.write(result_bytes)

    except KeyboardInterrupt:
        pass  # Handle ctrl-c gracefully.

    finally:
        # Only hand the worker back if one was actually acquired.
        if core_str is not None:
            _WORKER_POOL.put((core_str, result_file, num_threads))
|
||||||
|
|
||||||
|
|
||||||
|
def _compare_main():
    """Load every result list stored in RESULT_FILE and print a comparison.

    RESULT_FILE holds a sequence of back-to-back pickles, so objects are
    loaded until the end of the file is reached.
    """
    results = []
    with open(RESULT_FILE, "rb") as f:
        try:
            while True:
                results.extend(pickle.load(f))
        except EOFError:
            # Raised by `pickle.load` once no further pickles remain.
            pass

    from torch.utils.benchmark import Compare

    comparison = Compare(results)
    comparison.trim_significant_figures()
    comparison.colorize()
    comparison.print()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run the full benchmark sweep, then print the comparison table.

    RESULT_FILE is emptied, and for each entry of NUM_THREAD_SETTINGS the
    worker pool is rebuilt and NUM_REPLICATES trials per BLAS configuration
    are dispatched through a thread pool. Finally this script is re-run with
    `--DETAIL_in_compare` inside one of the benchmark environments.
    """
    # Start from an empty result file.
    open(RESULT_FILE, "wb").close()

    for num_threads in NUM_THREAD_SETTINGS:
        fill_core_pool(num_threads)
        workers = _WORKER_POOL.qsize()

        trials = [
            (seed, os.path.join(blas_compare_setup.WORKING_ROOT, env), sub_label, extra_env_vars)
            for seed in range(NUM_REPLICATES)
            for sub_label, env, extra_env_vars in BLAS_CONFIGS
        ]
        n = len(trials)

        with multiprocessing.dummy.Pool(workers) as pool:
            start_time = time.time()
            for n_trials_done, _ in enumerate(pool.imap(run_subprocess, trials), 1):
                time_per_result = (time.time() - start_time) / n_trials_done
                eta = int((n - n_trials_done) * time_per_result)
                progress = f"\r{n_trials_done} / {n} ETA:{datetime.timedelta(seconds=eta)}"
                print(progress.ljust(80), end="")
                sys.stdout.flush()
            total = datetime.timedelta(seconds=int(time.time() - start_time))
            print(f"\r{n} / {n} Total time: {total}")
        print()

    # Any env will do, it just needs to have torch for benchmark utils.
    env_path = os.path.join(blas_compare_setup.WORKING_ROOT, BLAS_CONFIGS[0][1])
    compare_command = (
        f"source activate {env_path} && "
        f"python {os.path.abspath(__file__)} "
        "--DETAIL_in_compare"
    )
    subprocess.run(compare_command, shell=True)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # These flags are for subprocess control, not controlling the main loop.
    parser = argparse.ArgumentParser()
    for mode_flag in ("--DETAIL_in_subprocess", "--DETAIL_in_compare"):
        parser.add_argument(mode_flag, action="store_true")
    for int_flag in ("--DETAIL_seed", "--DETAIL_num_threads"):
        parser.add_argument(int_flag, type=int, default=None)
    parser.add_argument("--DETAIL_sub_label", type=str, default="N/A")
    for str_flag in ("--DETAIL_result_file", "--DETAIL_env"):
        parser.add_argument(str_flag, type=str, default=None)
    args = parser.parse_args()

    # Dispatch: benchmark child, comparison child, or the top-level driver.
    if args.DETAIL_in_subprocess:
        try:
            _subprocess_main(
                seed=args.DETAIL_seed,
                num_threads=args.DETAIL_num_threads,
                sub_label=args.DETAIL_sub_label,
                result_file=args.DETAIL_result_file,
                env=args.DETAIL_env,
            )
        except KeyboardInterrupt:
            pass  # Handle ctrl-c gracefully.
    elif args.DETAIL_in_compare:
        _compare_main()
    else:
        main()
|
||||||
221
torch/utils/benchmark/examples/blas_compare_setup.py
Normal file
221
torch/utils/benchmark/examples/blas_compare_setup.py
Normal file
|
|
@ -0,0 +1,221 @@
|
||||||
|
import collections
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
try:
|
||||||
|
import conda.cli.python_api
|
||||||
|
from conda.cli.python_api import Commands as conda_commands
|
||||||
|
except ImportError:
|
||||||
|
# blas_compare.py will fail to import these when it's inside a conda env,
|
||||||
|
# but that's fine as it only wants the constants.
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# Directory under which every benchmark conda environment is created.
WORKING_ROOT = "/tmp/pytorch_blas_compare_environments"

# Names of the individual environments (sub-directories of WORKING_ROOT).
MKL_2020_3 = "mkl_2020_3"
MKL_2020_0 = "mkl_2020_0"
OPEN_BLAS = "open_blas"
EIGEN = "eigen"

# "NAME=value" settings applied to every environment.
GENERIC_ENV_VARS = ("USE_CUDA=0", "USE_ROCM=0")

# Packages installed into every environment before the PyTorch build.
BASE_PKG_DEPS = (
    "cffi",
    "cmake",
    "hypothesis",
    "ninja",
    "numpy",
    "pyyaml",
    "setuptools",
    "typing_extensions",
)
|
||||||
|
|
||||||
|
|
||||||
|
# Specification of one conda environment / PyTorch build.
SubEnvSpec = collections.namedtuple(
    "SubEnvSpec",
    [
        # Extra packages installed alongside BASE_PKG_DEPS.
        "generic_installs",
        # Either empty or a `(channel, packages)` pair for a second install.
        "special_installs",
        # "NAME=value" strings set on the environment.
        "environment_variables",

        # Validate install.
        "expected_blas_symbols",
        "expected_mkl_version",
    ],
)
|
||||||
|
|
||||||
|
|
||||||
|
# Environment name -> build specification, in the order they are created.
SUB_ENVS = {
    # MKL builds pull `mkl` / `mkl-include` from the "intel" channel.
    MKL_2020_3: SubEnvSpec(
        generic_installs=(),
        special_installs=("intel", ("mkl=2020.3", "mkl-include=2020.3")),
        environment_variables=("BLAS=MKL",) + GENERIC_ENV_VARS,
        expected_blas_symbols=("mkl_blas_sgemm",),
        expected_mkl_version="2020.0.3",
    ),

    MKL_2020_0: SubEnvSpec(
        generic_installs=(),
        special_installs=("intel", ("mkl=2020.0", "mkl-include=2020.0")),
        environment_variables=("BLAS=MKL",) + GENERIC_ENV_VARS,
        expected_blas_symbols=("mkl_blas_sgemm",),
        expected_mkl_version="2020.0.0",
    ),

    # OpenBLAS needs no special channel and has no MKL version to verify.
    OPEN_BLAS: SubEnvSpec(
        generic_installs=("openblas",),
        special_installs=(),
        environment_variables=("BLAS=OpenBLAS",) + GENERIC_ENV_VARS,
        expected_blas_symbols=("exec_blas",),
        expected_mkl_version=None,
    ),

    # EIGEN: SubEnvSpec(
    #     generic_installs=(),
    #     special_installs=(),
    #     environment_variables=("BLAS=Eigen",) + GENERIC_ENV_VARS,
    #     expected_blas_symbols=(),
    # ),
}
|
||||||
|
|
||||||
|
|
||||||
|
def conda_run(*args):
    """Run a conda command through its Python API and return its stdout.

    Args:
        *args: Forwarded unchanged to `conda.cli.python_api.run_command`.

    Raises:
        OSError: If the command reports a non-zero return code.
    """
    outcome = conda.cli.python_api.run_command(*args)
    stdout, stderr, retcode = outcome
    if not retcode:
        return stdout

    raise OSError(f"conda error: {str(args)} retcode: {retcode}\n{stderr}")
|
||||||
|
|
||||||
|
|
||||||
|
def _run_shell(command):
    """Run `command` through the shell, capturing stdout and stderr."""
    return subprocess.run(
        command,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )


def _raise_on_failure(proc, message):
    """Raise OSError with `message` and the captured output if `proc` failed."""
    if proc.returncode:
        raise OSError(
            f"{message}:\n"
            f" stdout: {proc.stdout.decode('utf-8')}\n"
            f" stderr: {proc.stderr.decode('utf-8')}"
        )


def main():
    """Create every environment in SUB_ENVS and build PyTorch into each one.

    WORKING_ROOT is wiped and recreated. For each environment the function
    creates the conda env, installs its packages, sets and verifies its
    environment variables, builds PyTorch from the enclosing git checkout,
    and finally checks that the build reports the expected BLAS backend,
    BLAS symbols and (where given) MKL version.

    Raises:
        OSError: If activating the env, setting its variables, building
            PyTorch or running the configuration check fails.
        AssertionError: If a validation of the resulting build fails.
    """
    if os.path.exists(WORKING_ROOT):
        print("Cleaning: removing old working root.")
        shutil.rmtree(WORKING_ROOT)
    os.makedirs(WORKING_ROOT)

    git_root = subprocess.check_output(
        "git rev-parse --show-toplevel",
        shell=True,
        cwd=os.path.dirname(os.path.realpath(__file__))
    ).decode("utf-8").strip()

    for env_name, env_spec in SUB_ENVS.items():
        env_path = os.path.join(WORKING_ROOT, env_name)
        print(f"Creating env: {env_name}: ({env_path})")
        conda_run(
            conda_commands.CREATE,
            "--no-default-packages",
            "--prefix", env_path,
            "python=3",
        )

        print("Testing that env can be activated:")
        base_source = _run_shell(f"source activate {env_path}")
        _raise_on_failure(base_source, "Failed to source base environment")

        print("Installing packages:")
        conda_run(
            conda_commands.INSTALL,
            "--prefix", env_path,
            *(BASE_PKG_DEPS + env_spec.generic_installs)
        )

        if env_spec.special_installs:
            channel, channel_deps = env_spec.special_installs
            print(f"Installing packages from channel: {channel}")
            conda_run(
                conda_commands.INSTALL,
                "--prefix", env_path,
                "-c", channel, *channel_deps
            )

        if env_spec.environment_variables:
            print("Setting environment variables.")

            # This does not appear to be possible using the python API.
            env_set = _run_shell(
                f"source activate {env_path} && "
                f"conda env config vars set {' '.join(env_spec.environment_variables)}"
            )
            _raise_on_failure(env_set, "Failed to set environment variables")

            # Check that they were actually set correctly.
            actual_env_vars = _run_shell(
                f"source activate {env_path} && env"
            ).stdout.decode("utf-8").strip().splitlines()
            for e in env_spec.environment_variables:
                assert e in actual_env_vars, f"{e} not in envs"

        print(f"Building PyTorch for env: `{env_name}`")
        # We have to re-run during each build to pick up the new
        # build config settings.
        build_run = _run_shell(
            f"source activate {env_path} && "
            f"cd {git_root} && "
            "python setup.py install --cmake"
        )
        # Fail here, with the build output, rather than later in the
        # configuration check where the real cause would be hidden.
        _raise_on_failure(build_run, "Failed to build PyTorch")

        print("Checking configuration:")
        check_run = _run_shell(
            # Shameless abuse of `python -c ...`
            f"source activate {env_path} && "
            "python -c \""
            "import torch;"
            "from torch.utils.benchmark import Timer;"
            "print(torch.__config__.show());"
            "setup = 'x=torch.ones((128, 128));y=torch.ones((128, 128))';"
            "counts = Timer('torch.mm(x, y)', setup).collect_callgrind(collect_baseline=False);"
            "stats = counts.as_standardized().stats(inclusive=True);"
            "print(stats.filter(lambda l: 'blas' in l.lower()))\""
        )
        _raise_on_failure(check_run, "Failed to check configuration")
        check_run_stdout = check_run.stdout.decode('utf-8')
        print(check_run_stdout)

        for e in env_spec.environment_variables:
            if "BLAS" in e:
                assert e in check_run_stdout, f"PyTorch build did not respect `BLAS=...`: {e}"

        for s in env_spec.expected_blas_symbols:
            assert s in check_run_stdout, f"Expected BLAS symbol not found in callgrind stats: {s}"

        if env_spec.expected_mkl_version is not None:
            expected_mkl_line = f"- Intel(R) Math Kernel Library Version {env_spec.expected_mkl_version}"
            assert expected_mkl_line in check_run_stdout, (
                f"Expected MKL version not reported by build: {env_spec.expected_mkl_version}"
            )

        print(f"Build complete: {env_name}")


if __name__ == "__main__":
    main()
|
||||||
Loading…
Reference in New Issue
Block a user