mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor] split cpu vec isa to dedicate file (keep git history) (#129789)
This PR is the implemention of https://github.com/pytorch/pytorch/issues/124245#issuecomment-2197778902 plan 1 Changes: 1. Duplicate `codecache.py` to `cpu_vec_isa.py` with its `git history`. <img width="745" alt="image" src="https://github.com/pytorch/pytorch/assets/8433590/106533da-ce80-4825-8271-35ffb3141f92"> 2. Make `cpu_vec_isa.py` as dedicate file for CPU vec isa. It also good to extend for more archtectures and vec isa. 3. Update code for above changes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/129789 Approved by: https://github.com/jgong5, https://github.com/jansel
This commit is contained in:
parent
a676b7c5f3
commit
58f346c874
|
|
@ -18,7 +18,7 @@ from torch import nn
|
|||
from torch._C import FileCheck
|
||||
from torch._dynamo.testing import rand_strided
|
||||
from torch._dynamo.utils import same
|
||||
from torch._inductor import codecache, config, metrics, test_operators
|
||||
from torch._inductor import config, cpu_vec_isa, metrics, test_operators
|
||||
from torch._inductor.codegen.common import OptimizationContext
|
||||
from torch._inductor.codegen.cpp import (
|
||||
CppOverrides,
|
||||
|
|
@ -67,12 +67,12 @@ aten = torch.ops.aten
|
|||
check_model = test_torchinductor.check_model
|
||||
|
||||
requires_vectorization = unittest.skipUnless(
|
||||
codecache.valid_vec_isa_list(), "Does not support vectorization"
|
||||
cpu_vec_isa.valid_vec_isa_list(), "Does not support vectorization"
|
||||
)
|
||||
|
||||
|
||||
def check_metrics_vec_kernel_count(num_expected_vec_kernels):
|
||||
if codecache.valid_vec_isa_list():
|
||||
if cpu_vec_isa.valid_vec_isa_list():
|
||||
assert metrics.generated_cpp_vec_kernel_count == num_expected_vec_kernels
|
||||
|
||||
|
||||
|
|
@ -1583,14 +1583,14 @@ class CPUReproTests(TestCase):
|
|||
self.common(fn, (value,))
|
||||
|
||||
@unittest.skipIf(
|
||||
platform.machine() != "x86_64" or not codecache.valid_vec_isa_list(),
|
||||
platform.machine() != "x86_64" or not cpu_vec_isa.valid_vec_isa_list(),
|
||||
"Does not support vectorization or not x86_64 machine",
|
||||
)
|
||||
@patch("torch.cuda.is_available", lambda: False)
|
||||
def test_auto_simd(self):
|
||||
vec_amx = codecache.supported_vec_isa_list[0]
|
||||
vec_avx512 = codecache.supported_vec_isa_list[1]
|
||||
vec_avx2 = codecache.supported_vec_isa_list[2]
|
||||
vec_amx = cpu_vec_isa.supported_vec_isa_list[0]
|
||||
vec_avx512 = cpu_vec_isa.supported_vec_isa_list[1]
|
||||
vec_avx2 = cpu_vec_isa.supported_vec_isa_list[2]
|
||||
self.assertTrue(vec_amx.bit_width() == 512)
|
||||
self.assertTrue(vec_amx.nelements() == 16)
|
||||
self.assertTrue(vec_amx.nelements(torch.bfloat16) == 32)
|
||||
|
|
@ -1602,43 +1602,43 @@ class CPUReproTests(TestCase):
|
|||
self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16)
|
||||
|
||||
with config.patch({"cpp.simdlen": None}):
|
||||
isa = codecache.pick_vec_isa()
|
||||
if vec_amx in codecache.valid_vec_isa_list():
|
||||
isa = cpu_vec_isa.pick_vec_isa()
|
||||
if vec_amx in cpu_vec_isa.valid_vec_isa_list():
|
||||
self.assertTrue(isa == vec_amx)
|
||||
elif vec_avx512 in codecache.valid_vec_isa_list():
|
||||
elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list():
|
||||
self.assertTrue(isa == vec_avx512)
|
||||
else:
|
||||
self.assertTrue(isa == vec_avx2)
|
||||
|
||||
with config.patch({"cpp.simdlen": 0}):
|
||||
isa = codecache.pick_vec_isa()
|
||||
isa = cpu_vec_isa.pick_vec_isa()
|
||||
self.assertFalse(isa)
|
||||
|
||||
with config.patch({"cpp.simdlen": 1}):
|
||||
isa = codecache.pick_vec_isa()
|
||||
isa = cpu_vec_isa.pick_vec_isa()
|
||||
self.assertFalse(isa)
|
||||
|
||||
with config.patch({"cpp.simdlen": 257}):
|
||||
isa = codecache.pick_vec_isa()
|
||||
isa = cpu_vec_isa.pick_vec_isa()
|
||||
self.assertFalse(isa)
|
||||
|
||||
with config.patch({"cpp.simdlen": 513}):
|
||||
isa_list = codecache.valid_vec_isa_list()
|
||||
isa_list = cpu_vec_isa.valid_vec_isa_list()
|
||||
if vec_avx512 in isa_list:
|
||||
self.assertFalse(isa)
|
||||
|
||||
with config.patch({"cpp.simdlen": 512}):
|
||||
isa_list = codecache.valid_vec_isa_list()
|
||||
isa = codecache.pick_vec_isa()
|
||||
isa_list = cpu_vec_isa.valid_vec_isa_list()
|
||||
isa = cpu_vec_isa.pick_vec_isa()
|
||||
if vec_amx in isa_list:
|
||||
self.assertTrue(isa == vec_amx)
|
||||
elif vec_avx512 in isa_list:
|
||||
self.assertTrue(isa == vec_avx512)
|
||||
|
||||
with config.patch({"cpp.simdlen": 256}):
|
||||
isa_list = codecache.valid_vec_isa_list()
|
||||
isa_list = cpu_vec_isa.valid_vec_isa_list()
|
||||
if vec_avx2 in isa_list:
|
||||
isa = codecache.pick_vec_isa()
|
||||
isa = cpu_vec_isa.pick_vec_isa()
|
||||
self.assertTrue(isa == vec_avx2)
|
||||
|
||||
@requires_vectorization
|
||||
|
|
@ -1989,7 +1989,9 @@ class CPUReproTests(TestCase):
|
|||
x[0, 0] = torch.nan
|
||||
x[1, -1] = torch.nan
|
||||
|
||||
bit_widths = [isa._bit_width for isa in codecache.valid_vec_isa_list()] + [None]
|
||||
bit_widths = [isa._bit_width for isa in cpu_vec_isa.valid_vec_isa_list()] + [
|
||||
None
|
||||
]
|
||||
for item in bit_widths:
|
||||
with config.patch({"cpp.simdlen": item}):
|
||||
torch._dynamo.reset()
|
||||
|
|
@ -2007,7 +2009,7 @@ class CPUReproTests(TestCase):
|
|||
|
||||
return fn
|
||||
|
||||
bit_widths = [isa._bit_width for isa in codecache.valid_vec_isa_list()]
|
||||
bit_widths = [isa._bit_width for isa in cpu_vec_isa.valid_vec_isa_list()]
|
||||
ih = [16, 65]
|
||||
iw = ih
|
||||
oh = ih
|
||||
|
|
@ -2266,7 +2268,7 @@ class CPUReproTests(TestCase):
|
|||
graph_lowering
|
||||
):
|
||||
# The moset inner loop variable is used in the index_expr
|
||||
tiling_factor = codecache.pick_vec_isa().nelements(dtype=torch.float)
|
||||
tiling_factor = cpu_vec_isa.pick_vec_isa().nelements(dtype=torch.float)
|
||||
with CppVecKernelChecker(
|
||||
args=None, num_threads=1, tiling_factor=tiling_factor
|
||||
) as vec_checker:
|
||||
|
|
@ -2366,7 +2368,7 @@ class CPUReproTests(TestCase):
|
|||
):
|
||||
itervars = [sympy.Symbol("i"), sympy.Symbol("j"), sympy.Symbol("k")]
|
||||
|
||||
tiling_factor = codecache.pick_vec_isa().nelements(dtype=torch.float)
|
||||
tiling_factor = cpu_vec_isa.pick_vec_isa().nelements(dtype=torch.float)
|
||||
# The most inner loop variable is used in the index_expr
|
||||
with CppVecKernelChecker(
|
||||
args=None, num_threads=1, tiling_factor=tiling_factor
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import torch._dynamo.config as dynamo_config
|
|||
import torch._inductor.config as inductor_config
|
||||
import torch._inductor.select_algorithm as select_algorithm
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._inductor.codecache import VecAMX
|
||||
from torch._inductor.cpu_vec_isa import VecAMX
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
from torch.testing._internal.common_device_type import (
|
||||
dtypes,
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ except ImportError:
|
|||
)
|
||||
|
||||
import torch._inductor.config as config
|
||||
from torch._inductor import codecache, metrics
|
||||
from torch._inductor import cpu_vec_isa, metrics
|
||||
from torch._inductor.codegen import cpp_utils
|
||||
from torch._inductor.codegen.common import (
|
||||
get_scheduling_for_device,
|
||||
|
|
@ -146,7 +146,7 @@ class ExtensionBackendTests(TestCase):
|
|||
metrics.reset()
|
||||
opt_fn = torch.compile()(fn)
|
||||
_, code = run_and_get_cpp_code(opt_fn, x, y, z)
|
||||
if codecache.valid_vec_isa_list():
|
||||
if cpu_vec_isa.valid_vec_isa_list():
|
||||
load_expr = "loadu"
|
||||
else:
|
||||
load_expr = " = in_ptr0[static_cast<long>(i0)];"
|
||||
|
|
|
|||
|
|
@ -56,6 +56,12 @@ from torch._inductor.codegen.rocm.compile_command import (
|
|||
rocm_compile_command,
|
||||
rocm_compiler,
|
||||
)
|
||||
from torch._inductor.cpu_vec_isa import (
|
||||
get_compiler_version_info,
|
||||
invalid_vec_isa,
|
||||
pick_vec_isa,
|
||||
VecISA,
|
||||
)
|
||||
from torch._inductor.runtime.compile_tasks import (
|
||||
_module_to_triton_kernel,
|
||||
_reload_python_module,
|
||||
|
|
@ -1255,360 +1261,6 @@ def is_clang() -> bool:
|
|||
return bool(re.search(r"(clang|clang\+\+)", cpp_compiler()))
|
||||
|
||||
|
||||
def get_compiler_version_info(compiler):
|
||||
SUBPROCESS_DECODE_ARGS = ("oem",) if _IS_WINDOWS else ()
|
||||
env = os.environ.copy()
|
||||
env["LC_ALL"] = "C" # Don't localize output
|
||||
try:
|
||||
version_string = subprocess.check_output(
|
||||
[compiler, "-v"], stderr=subprocess.STDOUT, env=env
|
||||
).decode(*SUBPROCESS_DECODE_ARGS)
|
||||
except Exception as e:
|
||||
try:
|
||||
version_string = subprocess.check_output(
|
||||
[compiler, "--version"], stderr=subprocess.STDOUT, env=env
|
||||
).decode(*SUBPROCESS_DECODE_ARGS)
|
||||
except Exception as e:
|
||||
return ""
|
||||
# Mutiple lines to one line string.
|
||||
version_string = version_string.replace("\r", "_")
|
||||
version_string = version_string.replace("\n", "_")
|
||||
return version_string
|
||||
|
||||
|
||||
def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str:
|
||||
# ISA dry compile will cost about 1 sec time each startup time.
|
||||
# Please check the issue: https://github.com/pytorch/pytorch/issues/100378
|
||||
# Actually, dry compile is checking compile capability for ISA.
|
||||
# We just record the compiler version, isa options and pytorch version info,
|
||||
# and generated them to output binary hash path.
|
||||
# It would optimize and skip compile existing binary.
|
||||
compiler_info = get_compiler_version_info(cpp_compiler())
|
||||
torch_version = torch.__version__
|
||||
fingerprint = f"{compiler_info}={isa_flags}={torch_version}"
|
||||
return fingerprint
|
||||
|
||||
|
||||
class VecISA:
|
||||
_bit_width: int
|
||||
_macro: List[str]
|
||||
_arch_flags: str
|
||||
_dtype_nelements: Dict[torch.dtype, int]
|
||||
|
||||
# Note [Checking for Vectorized Support in Inductor]
|
||||
# TorchInductor CPU vectorization reuses PyTorch vectorization utility functions
|
||||
# Hence, TorchInductor would depend on Sleef* to accelerate mathematical functions
|
||||
# like exp, pow, sin, cos and etc.
|
||||
# But PyTorch and TorchInductor might use different compilers to build code. If
|
||||
# PyTorch uses gcc-7/g++-7 to build the release package, the libtorch_cpu.so
|
||||
# will not expose the Sleef* AVX512 symbols since gcc-7/g++-7 cannot pass
|
||||
# avx512 check in CMake - FindAVX.cmake. But TorchInductor install the latest
|
||||
# gcc/g++ compiler by default while it could support the AVX512 compilation.
|
||||
# Therefore, there would be a conflict sleef version between PyTorch and
|
||||
# TorchInductor. Hence, we dry-compile the following code to check whether current
|
||||
# HW platform and PyTorch both could support AVX512 or AVX2. And suppose ARM
|
||||
# also needs the logic
|
||||
# In fbcode however, we are using the same compiler for pytorch and for inductor codegen,
|
||||
# making the runtime check unnecessary.
|
||||
_avx_code = """
|
||||
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON)
|
||||
#include <ATen/cpu/vec/functional.h>
|
||||
#include <ATen/cpu/vec/vec.h>
|
||||
#endif
|
||||
|
||||
alignas(64) float in_out_ptr0[16] = {0.0};
|
||||
|
||||
extern "C" void __avx_chk_kernel() {
|
||||
auto tmp0 = at::vec::Vectorized<float>(1);
|
||||
auto tmp1 = tmp0.exp();
|
||||
tmp1.store(in_out_ptr0);
|
||||
}
|
||||
""" # noqa: B950
|
||||
|
||||
_avx_py_load = """
|
||||
import torch
|
||||
from ctypes import cdll
|
||||
cdll.LoadLibrary("__lib_path__")
|
||||
"""
|
||||
|
||||
def bit_width(self) -> int:
|
||||
return self._bit_width
|
||||
|
||||
def nelements(self, dtype: torch.dtype = torch.float) -> int:
|
||||
return self._dtype_nelements[dtype]
|
||||
|
||||
def build_macro(self) -> List[str]:
|
||||
return self._macro
|
||||
|
||||
def build_arch_flags(self) -> str:
|
||||
return self._arch_flags
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(str(self))
|
||||
|
||||
def check_build(self, code) -> bool:
|
||||
from torch._inductor.cpp_builder import CppBuilder, CppTorchOptions
|
||||
|
||||
key, input_path = write(
|
||||
code,
|
||||
"cpp",
|
||||
extra=_get_isa_dry_compile_fingerprint(self._arch_flags),
|
||||
)
|
||||
from filelock import FileLock
|
||||
|
||||
lock_dir = get_lock_dir()
|
||||
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
|
||||
with lock:
|
||||
output_dir = os.path.dirname(input_path)
|
||||
buid_options = CppTorchOptions(vec_isa=self, warning_all=False)
|
||||
x86_isa_help_builder = CppBuilder(
|
||||
key,
|
||||
[input_path],
|
||||
buid_options,
|
||||
output_dir,
|
||||
)
|
||||
try:
|
||||
# Check if the output file exist, and compile when not.
|
||||
output_path = x86_isa_help_builder.get_target_file_path()
|
||||
if not os.path.isfile(output_path):
|
||||
status, target_file = x86_isa_help_builder.build()
|
||||
|
||||
# Check build result
|
||||
subprocess.check_call(
|
||||
[
|
||||
sys.executable,
|
||||
"-c",
|
||||
VecISA._avx_py_load.replace("__lib_path__", output_path),
|
||||
],
|
||||
stderr=subprocess.DEVNULL,
|
||||
env={**os.environ, "PYTHONPATH": ":".join(sys.path)},
|
||||
)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@functools.lru_cache(None) # noqa: B019
|
||||
def __bool__(self) -> bool:
|
||||
if config.cpp.vec_isa_ok is not None:
|
||||
return config.cpp.vec_isa_ok
|
||||
|
||||
if config.is_fbcode():
|
||||
return True
|
||||
|
||||
return self.check_build(VecISA._avx_code)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class VecNEON(VecISA):
|
||||
_bit_width = 256 # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h
|
||||
_macro = ["CPU_CAPABILITY_NEON"]
|
||||
if sys.platform == "darwin" and platform.processor() == "arm":
|
||||
_macro.append("AT_BUILD_ARM_VEC256_WITH_SLEEF")
|
||||
_arch_flags = "" # Unused
|
||||
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "asimd" # detects the presence of advanced SIMD on armv8-a kernels
|
||||
|
||||
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class VecAVX512(VecISA):
|
||||
_bit_width = 512
|
||||
_macro = ["CPU_CAPABILITY_AVX512"]
|
||||
_arch_flags = (
|
||||
"-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma"
|
||||
if not _IS_WINDOWS
|
||||
else "/arch:AVX512"
|
||||
) # TODO: use cflags
|
||||
_dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "avx512"
|
||||
|
||||
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class VecAMX(VecAVX512):
|
||||
_arch_flags = VecAVX512._arch_flags + " -mamx-tile -mamx-bf16 -mamx-int8"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return super().__str__() + " amx_tile"
|
||||
|
||||
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
||||
|
||||
_amx_code = """
|
||||
#include <cstdint>
|
||||
#include <immintrin.h>
|
||||
|
||||
struct amx_tilecfg {
|
||||
uint8_t palette_id;
|
||||
uint8_t start_row;
|
||||
uint8_t reserved_0[14];
|
||||
uint16_t colsb[16];
|
||||
uint8_t rows[16];
|
||||
};
|
||||
|
||||
extern "C" void __amx_chk_kernel() {
|
||||
amx_tilecfg cfg = {0};
|
||||
_tile_loadconfig(&cfg);
|
||||
_tile_zero(0);
|
||||
_tile_dpbf16ps(0, 1, 2);
|
||||
_tile_dpbusd(0, 1, 2);
|
||||
}
|
||||
"""
|
||||
|
||||
@functools.lru_cache(None) # noqa: B019
|
||||
def __bool__(self) -> bool:
|
||||
if super().__bool__():
|
||||
if config.is_fbcode():
|
||||
return False
|
||||
if self.check_build(VecAMX._amx_code) and torch.cpu._init_amx():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class VecAVX2(VecISA):
|
||||
_bit_width = 256
|
||||
_macro = ["CPU_CAPABILITY_AVX2"]
|
||||
_arch_flags = (
|
||||
"-mavx2 -mfma -mf16c" if not _IS_WINDOWS else "/arch:AVX2"
|
||||
) # TODO: use cflags
|
||||
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "avx2"
|
||||
|
||||
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class VecZVECTOR(VecISA):
|
||||
_bit_width = 256
|
||||
_macro = [
|
||||
"CPU_CAPABILITY_ZVECTOR",
|
||||
"CPU_CAPABILITY=ZVECTOR",
|
||||
"HAVE_ZVECTOR_CPU_DEFINITION",
|
||||
]
|
||||
_arch_flags = "-mvx -mzvector"
|
||||
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "zvector"
|
||||
|
||||
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
||||
|
||||
|
||||
class InvalidVecISA(VecISA):
|
||||
_bit_width = 0
|
||||
_macro = [""]
|
||||
_arch_flags = ""
|
||||
_dtype_nelements = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "INVALID_VEC_ISA"
|
||||
|
||||
def __bool__(self) -> bool: # type: ignore[override]
|
||||
return False
|
||||
|
||||
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
||||
|
||||
|
||||
def x86_isa_checker() -> List[str]:
|
||||
supported_isa: List[str] = []
|
||||
|
||||
def _check_and_append_supported_isa(
|
||||
dest: List[str], isa_supported: bool, isa_name: str
|
||||
):
|
||||
if isa_supported:
|
||||
dest.append(isa_name)
|
||||
|
||||
Arch = platform.machine()
|
||||
"""
|
||||
Arch value is x86_64 on Linux, and the value is AMD64 on Windows.
|
||||
"""
|
||||
if Arch != "x86_64" and Arch != "AMD64":
|
||||
return supported_isa
|
||||
|
||||
avx2 = torch.cpu._is_cpu_support_avx2()
|
||||
avx512 = torch.cpu._is_cpu_support_avx512()
|
||||
amx_tile = torch.cpu._is_cpu_support_amx_tile()
|
||||
|
||||
_check_and_append_supported_isa(supported_isa, avx2, "avx2")
|
||||
_check_and_append_supported_isa(supported_isa, avx512, "avx512")
|
||||
_check_and_append_supported_isa(supported_isa, amx_tile, "amx_tile")
|
||||
|
||||
return supported_isa
|
||||
|
||||
|
||||
invalid_vec_isa = InvalidVecISA()
|
||||
supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON()]
|
||||
|
||||
|
||||
# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content
|
||||
# might have too much redundant content that is useless for ISA check. Hence,
|
||||
# we only cache some key isa information.
|
||||
@functools.lru_cache(None)
|
||||
def valid_vec_isa_list() -> List[VecISA]:
|
||||
isa_list: List[VecISA] = []
|
||||
if sys.platform == "darwin" and platform.processor() == "arm":
|
||||
isa_list.append(VecNEON())
|
||||
|
||||
if sys.platform not in ["linux", "win32"]:
|
||||
return isa_list
|
||||
|
||||
arch = platform.machine()
|
||||
if arch == "s390x":
|
||||
with open("/proc/cpuinfo") as _cpu_info:
|
||||
while True:
|
||||
line = _cpu_info.readline()
|
||||
if not line:
|
||||
break
|
||||
# process line
|
||||
featuresmatch = re.match(r"^features\s*:\s*(.*)$", line)
|
||||
if featuresmatch:
|
||||
for group in featuresmatch.groups():
|
||||
if re.search(r"[\^ ]+vxe[\$ ]+", group):
|
||||
isa_list.append(VecZVECTOR())
|
||||
break
|
||||
elif arch == "aarch64":
|
||||
isa_list.append(VecNEON())
|
||||
elif arch in ["x86_64", "AMD64"]:
|
||||
"""
|
||||
arch value is x86_64 on Linux, and the value is AMD64 on Windows.
|
||||
"""
|
||||
_cpu_supported_x86_isa = x86_isa_checker()
|
||||
for isa in supported_vec_isa_list:
|
||||
if all(flag in _cpu_supported_x86_isa for flag in str(isa).split()) and isa:
|
||||
isa_list.append(isa)
|
||||
|
||||
return isa_list
|
||||
|
||||
|
||||
def pick_vec_isa() -> VecISA:
|
||||
if config.is_fbcode():
|
||||
return VecAVX2()
|
||||
|
||||
_valid_vec_isa_list: List[VecISA] = valid_vec_isa_list()
|
||||
if not _valid_vec_isa_list:
|
||||
return invalid_vec_isa
|
||||
|
||||
# If the simdlen is None, it indicates determine the vectorization length automatically
|
||||
if config.cpp.simdlen is None:
|
||||
assert _valid_vec_isa_list
|
||||
return _valid_vec_isa_list[0]
|
||||
|
||||
for isa in _valid_vec_isa_list:
|
||||
if config.cpp.simdlen == isa.bit_width():
|
||||
return isa
|
||||
|
||||
return invalid_vec_isa
|
||||
|
||||
|
||||
def get_compile_only(compile_only: bool = True) -> str:
|
||||
return "-c" if compile_only else ""
|
||||
|
||||
|
|
@ -3537,7 +3189,7 @@ class ROCmCodeCache:
|
|||
"""
|
||||
if not cls._logged_compiler_version:
|
||||
cls._logged_compiler_version = True
|
||||
log.debug(get_compiler_version_info(rocm_compiler()))
|
||||
log.debug(get_compiler_version_info(str(rocm_compiler())))
|
||||
|
||||
key, input_path = cls.write(source_code, dst_file_ext)
|
||||
if key not in cls.cache:
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
|
|||
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
|
||||
from ..._dynamo.utils import counters
|
||||
|
||||
from .. import codecache, config, ir, metrics
|
||||
from .. import codecache, config, cpu_vec_isa, ir, metrics
|
||||
from ..codegen.wrapper import WrapperCodeGen
|
||||
from ..optimize_indexing import range_expressable_in_32_bits
|
||||
from ..scheduler import (
|
||||
|
|
@ -2090,7 +2090,7 @@ class CppVecKernel(CppKernel):
|
|||
tiling_dtype=torch.float,
|
||||
):
|
||||
super().__init__(args, num_threads)
|
||||
self.vec_isa = codecache.pick_vec_isa()
|
||||
self.vec_isa = cpu_vec_isa.pick_vec_isa()
|
||||
assert self.vec_isa
|
||||
if tiling_factor == 0:
|
||||
tiling_factor = self.vec_isa.nelements(dtype=tiling_dtype)
|
||||
|
|
@ -3044,7 +3044,7 @@ class CppKernelProxy(CppKernel):
|
|||
self.kernel_group = kernel_group
|
||||
self.loop_nest = None
|
||||
self.call_ranges = None
|
||||
self.picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa()
|
||||
self.picked_vec_isa: cpu_vec_isa.VecISA = cpu_vec_isa.pick_vec_isa()
|
||||
|
||||
def data_type_propagation(self, nodes):
|
||||
for _node in nodes:
|
||||
|
|
@ -4056,15 +4056,15 @@ class LoopLevel:
|
|||
kernel: Optional[CppKernel] = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Regarding the C++/OpenMP backend, `codecache.pick_vec_isa()` to check
|
||||
# Regarding the C++/OpenMP backend, `cpu_vec_isa.pick_vec_isa()` to check
|
||||
# vectorization ISA is a time-consuming and one-shot operation. It leads
|
||||
# to taking a longer time to import `codegen.cpp` package because the
|
||||
# `LoopLevel` of the package is decorated by `@dataclasses.dataclass` while
|
||||
# the decorator will invoke `codecache.pick_vec_isa()` to initialize the
|
||||
# the decorator will invoke `cpu_vec_isa.pick_vec_isa()` to initialize the
|
||||
# `simd_nelements` of the `LoopLevel`. It might introduce additional compilation
|
||||
# overhead to the Triton backend. Therefore, we moved the `simd_nelements` to
|
||||
# `__post_init__`
|
||||
picked_vec_isa: codecache.VecISA = codecache.pick_vec_isa()
|
||||
picked_vec_isa: cpu_vec_isa.VecISA = cpu_vec_isa.pick_vec_isa()
|
||||
self.simd_nelements: int = picked_vec_isa.nelements() if picked_vec_isa else 0
|
||||
|
||||
def get_kernels(self) -> List[CppKernel]:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import sympy
|
|||
import torch
|
||||
|
||||
from .. import ir
|
||||
from ..codecache import pick_vec_isa, VecAMX, VecAVX2, VecAVX512, VecISA
|
||||
from ..cpu_vec_isa import pick_vec_isa, VecAMX, VecAVX2, VecAVX512, VecISA
|
||||
from ..utils import IndentedBuffer, parallel_num_threads
|
||||
from ..virtualized import V
|
||||
from .common import KernelTemplate
|
||||
|
|
|
|||
|
|
@ -20,15 +20,16 @@ from typing import List, Sequence, Tuple, Union
|
|||
|
||||
import torch
|
||||
from torch._inductor import config, exc
|
||||
|
||||
# TODO: import below objects in function scope, in further optimization
|
||||
from torch._inductor.codecache import (
|
||||
_get_python_include_dirs,
|
||||
_LINKER_SCRIPT,
|
||||
_transform_cuda_paths,
|
||||
get_lock_dir,
|
||||
invalid_vec_isa,
|
||||
LOCK_TIMEOUT,
|
||||
VecISA,
|
||||
)
|
||||
from torch._inductor.cpu_vec_isa import invalid_vec_isa, VecISA
|
||||
from torch._inductor.runtime.runtime_utils import cache_dir
|
||||
|
||||
if config.is_fbcode():
|
||||
|
|
@ -134,6 +135,10 @@ def _get_cpp_compiler() -> str:
|
|||
return compiler
|
||||
|
||||
|
||||
def cpp_compiler() -> str:
|
||||
return _get_cpp_compiler()
|
||||
|
||||
|
||||
def _is_gcc(cpp_compiler) -> bool:
|
||||
return bool(re.search(r"(gcc|g\+\+)", cpp_compiler))
|
||||
|
||||
|
|
|
|||
372
torch/_inductor/cpu_vec_isa.py
Normal file
372
torch/_inductor/cpu_vec_isa.py
Normal file
|
|
@ -0,0 +1,372 @@
|
|||
import dataclasses
|
||||
import functools
|
||||
import os
|
||||
import platform
|
||||
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
import torch
|
||||
from torch._inductor import config
|
||||
|
||||
_IS_WINDOWS = sys.platform == "win32"
|
||||
|
||||
|
||||
# TODO: Move to cpp_builder, when optimize it.
|
||||
def get_compiler_version_info(compiler: str) -> str:
|
||||
SUBPROCESS_DECODE_ARGS = ("oem",) if _IS_WINDOWS else ()
|
||||
env = os.environ.copy()
|
||||
env["LC_ALL"] = "C" # Don't localize output
|
||||
try:
|
||||
version_string = subprocess.check_output(
|
||||
[compiler, "-v"], stderr=subprocess.STDOUT, env=env
|
||||
).decode(*SUBPROCESS_DECODE_ARGS)
|
||||
except Exception as e:
|
||||
try:
|
||||
version_string = subprocess.check_output(
|
||||
[compiler, "--version"], stderr=subprocess.STDOUT, env=env
|
||||
).decode(*SUBPROCESS_DECODE_ARGS)
|
||||
except Exception as e:
|
||||
return ""
|
||||
# Mutiple lines to one line string.
|
||||
version_string = version_string.replace("\r", "_")
|
||||
version_string = version_string.replace("\n", "_")
|
||||
return version_string
|
||||
|
||||
|
||||
def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str:
|
||||
# ISA dry compile will cost about 1 sec time each startup time.
|
||||
# Please check the issue: https://github.com/pytorch/pytorch/issues/100378
|
||||
# Actually, dry compile is checking compile capability for ISA.
|
||||
# We just record the compiler version, isa options and pytorch version info,
|
||||
# and generated them to output binary hash path.
|
||||
# It would optimize and skip compile existing binary.
|
||||
from torch._inductor.cpp_builder import cpp_compiler
|
||||
|
||||
compiler_info = get_compiler_version_info(cpp_compiler())
|
||||
torch_version = torch.__version__
|
||||
fingerprint = f"{compiler_info}={isa_flags}={torch_version}"
|
||||
return fingerprint
|
||||
|
||||
|
||||
class VecISA:
|
||||
_bit_width: int
|
||||
_macro: List[str]
|
||||
_arch_flags: str
|
||||
_dtype_nelements: Dict[torch.dtype, int]
|
||||
|
||||
# Note [Checking for Vectorized Support in Inductor]
|
||||
# TorchInductor CPU vectorization reuses PyTorch vectorization utility functions
|
||||
# Hence, TorchInductor would depend on Sleef* to accelerate mathematical functions
|
||||
# like exp, pow, sin, cos and etc.
|
||||
# But PyTorch and TorchInductor might use different compilers to build code. If
|
||||
# PyTorch uses gcc-7/g++-7 to build the release package, the libtorch_cpu.so
|
||||
# will not expose the Sleef* AVX512 symbols since gcc-7/g++-7 cannot pass
|
||||
# avx512 check in CMake - FindAVX.cmake. But TorchInductor install the latest
|
||||
# gcc/g++ compiler by default while it could support the AVX512 compilation.
|
||||
# Therefore, there would be a conflict sleef version between PyTorch and
|
||||
# TorchInductor. Hence, we dry-compile the following code to check whether current
|
||||
# HW platform and PyTorch both could support AVX512 or AVX2. And suppose ARM
|
||||
# also needs the logic
|
||||
# In fbcode however, we are using the same compiler for pytorch and for inductor codegen,
|
||||
# making the runtime check unnecessary.
|
||||
_avx_code = """
|
||||
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON)
|
||||
#include <ATen/cpu/vec/functional.h>
|
||||
#include <ATen/cpu/vec/vec.h>
|
||||
#endif
|
||||
|
||||
alignas(64) float in_out_ptr0[16] = {0.0};
|
||||
|
||||
extern "C" void __avx_chk_kernel() {
|
||||
auto tmp0 = at::vec::Vectorized<float>(1);
|
||||
auto tmp1 = tmp0.exp();
|
||||
tmp1.store(in_out_ptr0);
|
||||
}
|
||||
""" # noqa: B950
|
||||
|
||||
_avx_py_load = """
|
||||
import torch
|
||||
from ctypes import cdll
|
||||
cdll.LoadLibrary("__lib_path__")
|
||||
"""
|
||||
|
||||
def bit_width(self) -> int:
|
||||
return self._bit_width
|
||||
|
||||
def nelements(self, dtype: torch.dtype = torch.float) -> int:
|
||||
return self._dtype_nelements[dtype]
|
||||
|
||||
def build_macro(self) -> List[str]:
|
||||
return self._macro
|
||||
|
||||
def build_arch_flags(self) -> str:
|
||||
return self._arch_flags
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(str(self))
|
||||
|
||||
def check_build(self, code: str) -> bool:
|
||||
from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT, write
|
||||
from torch._inductor.cpp_builder import CppBuilder, CppTorchOptions
|
||||
|
||||
key, input_path = write(
|
||||
code,
|
||||
"cpp",
|
||||
extra=_get_isa_dry_compile_fingerprint(self._arch_flags),
|
||||
)
|
||||
from filelock import FileLock
|
||||
|
||||
lock_dir = get_lock_dir()
|
||||
lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
|
||||
with lock:
|
||||
output_dir = os.path.dirname(input_path)
|
||||
buid_options = CppTorchOptions(vec_isa=self, warning_all=False)
|
||||
x86_isa_help_builder = CppBuilder(
|
||||
key,
|
||||
[input_path],
|
||||
buid_options,
|
||||
output_dir,
|
||||
)
|
||||
try:
|
||||
# Check if the output file exist, and compile when not.
|
||||
output_path = x86_isa_help_builder.get_target_file_path()
|
||||
if not os.path.isfile(output_path):
|
||||
status, target_file = x86_isa_help_builder.build()
|
||||
|
||||
# Check build result
|
||||
subprocess.check_call(
|
||||
[
|
||||
sys.executable,
|
||||
"-c",
|
||||
VecISA._avx_py_load.replace("__lib_path__", output_path),
|
||||
],
|
||||
stderr=subprocess.DEVNULL,
|
||||
env={**os.environ, "PYTHONPATH": ":".join(sys.path)},
|
||||
)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@functools.lru_cache(None) # noqa: B019
|
||||
def __bool__(self) -> bool:
|
||||
if config.cpp.vec_isa_ok is not None:
|
||||
return config.cpp.vec_isa_ok
|
||||
|
||||
if config.is_fbcode():
|
||||
return True
|
||||
|
||||
return self.check_build(VecISA._avx_code)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class VecNEON(VecISA):
|
||||
_bit_width = 256 # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h
|
||||
_macro = ["CPU_CAPABILITY_NEON"]
|
||||
if sys.platform == "darwin" and platform.processor() == "arm":
|
||||
_macro.append("AT_BUILD_ARM_VEC256_WITH_SLEEF")
|
||||
_arch_flags = "" # Unused
|
||||
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "asimd" # detects the presence of advanced SIMD on armv8-a kernels
|
||||
|
||||
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class VecAVX512(VecISA):
|
||||
_bit_width = 512
|
||||
_macro = ["CPU_CAPABILITY_AVX512"]
|
||||
_arch_flags = (
|
||||
"-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma"
|
||||
if not _IS_WINDOWS
|
||||
else "/arch:AVX512"
|
||||
) # TODO: use cflags
|
||||
_dtype_nelements = {torch.float: 16, torch.bfloat16: 32, torch.float16: 32}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "avx512"
|
||||
|
||||
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class VecAMX(VecAVX512):
|
||||
_arch_flags = VecAVX512._arch_flags + " -mamx-tile -mamx-bf16 -mamx-int8"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return super().__str__() + " amx_tile"
|
||||
|
||||
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
||||
|
||||
_amx_code = """
|
||||
#include <cstdint>
|
||||
#include <immintrin.h>
|
||||
|
||||
struct amx_tilecfg {
|
||||
uint8_t palette_id;
|
||||
uint8_t start_row;
|
||||
uint8_t reserved_0[14];
|
||||
uint16_t colsb[16];
|
||||
uint8_t rows[16];
|
||||
};
|
||||
|
||||
extern "C" void __amx_chk_kernel() {
|
||||
amx_tilecfg cfg = {0};
|
||||
_tile_loadconfig(&cfg);
|
||||
_tile_zero(0);
|
||||
_tile_dpbf16ps(0, 1, 2);
|
||||
_tile_dpbusd(0, 1, 2);
|
||||
}
|
||||
"""
|
||||
|
||||
@functools.lru_cache(None) # noqa: B019
|
||||
def __bool__(self) -> bool:
|
||||
if super().__bool__():
|
||||
if config.is_fbcode():
|
||||
return False
|
||||
if self.check_build(VecAMX._amx_code) and torch.cpu._init_amx():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class VecAVX2(VecISA):
|
||||
_bit_width = 256
|
||||
_macro = ["CPU_CAPABILITY_AVX2"]
|
||||
_arch_flags = (
|
||||
"-mavx2 -mfma -mf16c" if not _IS_WINDOWS else "/arch:AVX2"
|
||||
) # TODO: use cflags
|
||||
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "avx2"
|
||||
|
||||
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class VecZVECTOR(VecISA):
|
||||
_bit_width = 256
|
||||
_macro = [
|
||||
"CPU_CAPABILITY_ZVECTOR",
|
||||
"CPU_CAPABILITY=ZVECTOR",
|
||||
"HAVE_ZVECTOR_CPU_DEFINITION",
|
||||
]
|
||||
_arch_flags = "-mvx -mzvector"
|
||||
_dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "zvector"
|
||||
|
||||
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
||||
|
||||
|
||||
class InvalidVecISA(VecISA):
|
||||
_bit_width = 0
|
||||
_macro = [""]
|
||||
_arch_flags = ""
|
||||
_dtype_nelements = {}
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "INVALID_VEC_ISA"
|
||||
|
||||
def __bool__(self) -> bool: # type: ignore[override]
|
||||
return False
|
||||
|
||||
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
||||
|
||||
|
||||
def x86_isa_checker() -> List[str]:
|
||||
supported_isa: List[str] = []
|
||||
|
||||
def _check_and_append_supported_isa(
|
||||
dest: List[str], isa_supported: bool, isa_name: str
|
||||
) -> None:
|
||||
if isa_supported:
|
||||
dest.append(isa_name)
|
||||
|
||||
Arch = platform.machine()
|
||||
"""
|
||||
Arch value is x86_64 on Linux, and the value is AMD64 on Windows.
|
||||
"""
|
||||
if Arch != "x86_64" and Arch != "AMD64":
|
||||
return supported_isa
|
||||
|
||||
avx2 = torch.cpu._is_cpu_support_avx2()
|
||||
avx512 = torch.cpu._is_cpu_support_avx512()
|
||||
amx_tile = torch.cpu._is_cpu_support_amx_tile()
|
||||
|
||||
_check_and_append_supported_isa(supported_isa, avx2, "avx2")
|
||||
_check_and_append_supported_isa(supported_isa, avx512, "avx512")
|
||||
_check_and_append_supported_isa(supported_isa, amx_tile, "amx_tile")
|
||||
|
||||
return supported_isa
|
||||
|
||||
|
||||
invalid_vec_isa = InvalidVecISA()
|
||||
supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON()]
|
||||
|
||||
|
||||
# Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content
|
||||
# might have too much redundant content that is useless for ISA check. Hence,
|
||||
# we only cache some key isa information.
|
||||
@functools.lru_cache(None)
|
||||
def valid_vec_isa_list() -> List[VecISA]:
|
||||
isa_list: List[VecISA] = []
|
||||
if sys.platform == "darwin" and platform.processor() == "arm":
|
||||
isa_list.append(VecNEON())
|
||||
|
||||
if sys.platform not in ["linux", "win32"]:
|
||||
return isa_list
|
||||
|
||||
arch = platform.machine()
|
||||
if arch == "s390x":
|
||||
with open("/proc/cpuinfo") as _cpu_info:
|
||||
while True:
|
||||
line = _cpu_info.readline()
|
||||
if not line:
|
||||
break
|
||||
# process line
|
||||
featuresmatch = re.match(r"^features\s*:\s*(.*)$", line)
|
||||
if featuresmatch:
|
||||
for group in featuresmatch.groups():
|
||||
if re.search(r"[\^ ]+vxe[\$ ]+", group):
|
||||
isa_list.append(VecZVECTOR())
|
||||
break
|
||||
elif arch == "aarch64":
|
||||
isa_list.append(VecNEON())
|
||||
elif arch in ["x86_64", "AMD64"]:
|
||||
"""
|
||||
arch value is x86_64 on Linux, and the value is AMD64 on Windows.
|
||||
"""
|
||||
_cpu_supported_x86_isa = x86_isa_checker()
|
||||
for isa in supported_vec_isa_list:
|
||||
if all(flag in _cpu_supported_x86_isa for flag in str(isa).split()) and isa:
|
||||
isa_list.append(isa)
|
||||
|
||||
return isa_list
|
||||
|
||||
|
||||
def pick_vec_isa() -> VecISA:
|
||||
if config.is_fbcode():
|
||||
return VecAVX2()
|
||||
|
||||
_valid_vec_isa_list: List[VecISA] = valid_vec_isa_list()
|
||||
if not _valid_vec_isa_list:
|
||||
return invalid_vec_isa
|
||||
|
||||
# If the simdlen is None, it indicates determine the vectorization length automatically
|
||||
if config.cpp.simdlen is None:
|
||||
assert _valid_vec_isa_list
|
||||
return _valid_vec_isa_list[0]
|
||||
|
||||
for isa in _valid_vec_isa_list:
|
||||
if config.cpp.simdlen == isa.bit_width():
|
||||
return isa
|
||||
|
||||
return invalid_vec_isa
|
||||
Loading…
Reference in New Issue
Block a user