mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
xpu: get xpu arch flags at runtime in cpp_extensions (#152192)
This commit moves query for xpu arch flags to runtime when building SYCL extensions which allows to adjust `TORCH_XPU_ARCH_LIST` at python script level. That's handy for example in ci test which gives a try few variants of the list. CC: @malfet, @jingxu10, @EikanWang, @guangyey Pull Request resolved: https://github.com/pytorch/pytorch/pull/152192 Approved by: https://github.com/guangyey, https://github.com/gujinghui, https://github.com/albanD
This commit is contained in:
parent
9fa07340fd
commit
aca2c99a65
|
|
@ -3,8 +3,10 @@
|
||||||
import glob
|
import glob
|
||||||
import locale
|
import locale
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
|
import string
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
@ -116,11 +118,13 @@ class TestCppExtensionJIT(common.TestCase):
|
||||||
# 2 * sigmoid(0) = 2 * 0.5 = 1
|
# 2 * sigmoid(0) = 2 * 0.5 = 1
|
||||||
self.assertEqual(z, torch.ones_like(z))
|
self.assertEqual(z, torch.ones_like(z))
|
||||||
|
|
||||||
@unittest.skipIf(not (TEST_XPU), "XPU not found")
|
def _test_jit_xpu_extension(self):
|
||||||
def test_jit_xpu_extension(self):
|
name = "torch_test_xpu_extension_"
|
||||||
# NOTE: The name of the extension must equal the name of the module.
|
# randomizing name for the case when we test building few extensions
|
||||||
|
# in a row using this function
|
||||||
|
name += "".join(random.sample(string.ascii_letters, 5))
|
||||||
module = torch.utils.cpp_extension.load(
|
module = torch.utils.cpp_extension.load(
|
||||||
name="torch_test_xpu_extension",
|
name=name,
|
||||||
sources=[
|
sources=[
|
||||||
"cpp_extensions/xpu_extension.sycl",
|
"cpp_extensions/xpu_extension.sycl",
|
||||||
],
|
],
|
||||||
|
|
@ -136,6 +140,31 @@ class TestCppExtensionJIT(common.TestCase):
|
||||||
# 2 * sigmoid(0) = 2 * 0.5 = 1
|
# 2 * sigmoid(0) = 2 * 0.5 = 1
|
||||||
self.assertEqual(z, torch.ones_like(z))
|
self.assertEqual(z, torch.ones_like(z))
|
||||||
|
|
||||||
|
@unittest.skipIf(not (TEST_XPU), "XPU not found")
|
||||||
|
def test_jit_xpu_extension(self):
|
||||||
|
# NOTE: this test can be affected by setting TORCH_XPU_ARCH_LIST
|
||||||
|
self._test_jit_xpu_extension()
|
||||||
|
|
||||||
|
@unittest.skipIf(not (TEST_XPU), "XPU not found")
|
||||||
|
def test_jit_xpu_archlists(self):
|
||||||
|
# NOTE: in this test we explicitly test few different options
|
||||||
|
# for TORCH_XPU_ARCH_LIST. Setting TORCH_XPU_ARCH_LIST in the
|
||||||
|
# environment before the test won't affect it.
|
||||||
|
archlists = [
|
||||||
|
"", # expecting JIT compilation
|
||||||
|
",".join(torch.xpu.get_arch_list()),
|
||||||
|
]
|
||||||
|
old_envvar = os.environ.get("TORCH_XPU_ARCH_LIST", None)
|
||||||
|
try:
|
||||||
|
for al in archlists:
|
||||||
|
os.environ["TORCH_XPU_ARCH_LIST"] = al
|
||||||
|
self._test_jit_xpu_extension()
|
||||||
|
finally:
|
||||||
|
if old_envvar is None:
|
||||||
|
os.environ.pop("TORCH_XPU_ARCH_LIST")
|
||||||
|
else:
|
||||||
|
os.environ["TORCH_XPU_ARCH_LIST"] = old_envvar
|
||||||
|
|
||||||
@unittest.skipIf(not TEST_MPS, "MPS not found")
|
@unittest.skipIf(not TEST_MPS, "MPS not found")
|
||||||
def test_mps_extension(self):
|
def test_mps_extension(self):
|
||||||
module = torch.utils.cpp_extension.load(
|
module = torch.utils.cpp_extension.load(
|
||||||
|
|
|
||||||
|
|
@ -291,16 +291,25 @@ def _get_sycl_arch_list():
|
||||||
# If arch list returned by _get_sycl_arch_list() is empty, then sycl kernels will be compiled
|
# If arch list returned by _get_sycl_arch_list() is empty, then sycl kernels will be compiled
|
||||||
# for default spir64 target and avoid device specific compilations entirely. Further, kernels
|
# for default spir64 target and avoid device specific compilations entirely. Further, kernels
|
||||||
# will be JIT compiled at runtime.
|
# will be JIT compiled at runtime.
|
||||||
|
def _get_sycl_target_flags():
|
||||||
|
if _get_sycl_arch_list() != '':
|
||||||
|
return ['-fsycl-targets=spir64_gen,spir64']
|
||||||
|
return ['']
|
||||||
|
|
||||||
|
def _get_sycl_device_flags():
|
||||||
|
arch_list = _get_sycl_arch_list()
|
||||||
|
if arch_list != '':
|
||||||
|
return [f'-Xs "-device {arch_list}"']
|
||||||
|
return ['']
|
||||||
|
|
||||||
_COMMON_SYCL_FLAGS = [
|
_COMMON_SYCL_FLAGS = [
|
||||||
'-fsycl',
|
'-fsycl',
|
||||||
'-fsycl-targets=spir64_gen,spir64' if _get_sycl_arch_list() != '' else '',
|
|
||||||
]
|
]
|
||||||
|
|
||||||
_SYCL_DLINK_FLAGS = [
|
_SYCL_DLINK_FLAGS = [
|
||||||
*_COMMON_SYCL_FLAGS,
|
*_COMMON_SYCL_FLAGS,
|
||||||
'-fsycl-link',
|
'-fsycl-link',
|
||||||
'--offload-compress',
|
'--offload-compress',
|
||||||
f'-Xs "-device {_get_sycl_arch_list()}"' if _get_sycl_arch_list() != '' else '',
|
|
||||||
]
|
]
|
||||||
|
|
||||||
JIT_EXTENSION_VERSIONER = ExtensionVersioner()
|
JIT_EXTENSION_VERSIONER = ExtensionVersioner()
|
||||||
|
|
@ -812,6 +821,7 @@ class BuildExtension(build_ext):
|
||||||
sycl_dlink_post_cflags = None
|
sycl_dlink_post_cflags = None
|
||||||
if with_sycl:
|
if with_sycl:
|
||||||
sycl_cflags = extra_cc_cflags + common_cflags + _COMMON_SYCL_FLAGS
|
sycl_cflags = extra_cc_cflags + common_cflags + _COMMON_SYCL_FLAGS
|
||||||
|
sycl_cflags += _get_sycl_target_flags()
|
||||||
if isinstance(extra_postargs, dict):
|
if isinstance(extra_postargs, dict):
|
||||||
sycl_post_cflags = extra_postargs['sycl']
|
sycl_post_cflags = extra_postargs['sycl']
|
||||||
else:
|
else:
|
||||||
|
|
@ -829,6 +839,7 @@ class BuildExtension(build_ext):
|
||||||
sycl_cflags = [shlex.quote(f) for f in sycl_cflags]
|
sycl_cflags = [shlex.quote(f) for f in sycl_cflags]
|
||||||
sycl_cflags += _wrap_sycl_host_flags(host_cflags)
|
sycl_cflags += _wrap_sycl_host_flags(host_cflags)
|
||||||
sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS
|
sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS
|
||||||
|
sycl_dlink_post_cflags += _get_sycl_device_flags()
|
||||||
sycl_post_cflags = [shlex.quote(f) for f in sycl_post_cflags]
|
sycl_post_cflags = [shlex.quote(f) for f in sycl_post_cflags]
|
||||||
|
|
||||||
_write_ninja_file_and_compile_objects(
|
_write_ninja_file_and_compile_objects(
|
||||||
|
|
@ -2688,6 +2699,7 @@ def _write_ninja_file_to_build_library(path,
|
||||||
|
|
||||||
if with_sycl:
|
if with_sycl:
|
||||||
sycl_cflags = cflags + _COMMON_SYCL_FLAGS
|
sycl_cflags = cflags + _COMMON_SYCL_FLAGS
|
||||||
|
sycl_cflags += _get_sycl_target_flags()
|
||||||
sycl_cflags += extra_sycl_cflags
|
sycl_cflags += extra_sycl_cflags
|
||||||
_append_sycl_std_if_no_std_present(sycl_cflags)
|
_append_sycl_std_if_no_std_present(sycl_cflags)
|
||||||
host_cflags = cflags
|
host_cflags = cflags
|
||||||
|
|
@ -2696,6 +2708,7 @@ def _write_ninja_file_to_build_library(path,
|
||||||
host_cflags = ' '.join(host_cflags)
|
host_cflags = ' '.join(host_cflags)
|
||||||
sycl_cflags += _wrap_sycl_host_flags(host_cflags)
|
sycl_cflags += _wrap_sycl_host_flags(host_cflags)
|
||||||
sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS
|
sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS
|
||||||
|
sycl_dlink_post_cflags += _get_sycl_device_flags()
|
||||||
else:
|
else:
|
||||||
sycl_cflags = None
|
sycl_cflags = None
|
||||||
sycl_dlink_post_cflags = None
|
sycl_dlink_post_cflags = None
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user