mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
xpu: get xpu arch flags at runtime in cpp_extensions (#152192)
This commit moves the query for XPU arch flags to runtime when building SYCL extensions, which allows `TORCH_XPU_ARCH_LIST` to be adjusted at the Python script level. That's handy, for example, in a CI test that tries a few variants of the list. CC: @malfet, @jingxu10, @EikanWang, @guangyey Pull Request resolved: https://github.com/pytorch/pytorch/pull/152192 Approved by: https://github.com/guangyey, https://github.com/gujinghui, https://github.com/albanD
This commit is contained in:
parent
9fa07340fd
commit
aca2c99a65
|
|
@ -3,8 +3,10 @@
|
|||
import glob
|
||||
import locale
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import shutil
|
||||
import string
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
|
|
@ -116,11 +118,13 @@ class TestCppExtensionJIT(common.TestCase):
|
|||
# 2 * sigmoid(0) = 2 * 0.5 = 1
|
||||
self.assertEqual(z, torch.ones_like(z))
|
||||
|
||||
@unittest.skipIf(not (TEST_XPU), "XPU not found")
|
||||
def test_jit_xpu_extension(self):
|
||||
# NOTE: The name of the extension must equal the name of the module.
|
||||
def _test_jit_xpu_extension(self):
|
||||
name = "torch_test_xpu_extension_"
|
||||
# randomizing name for the case when we test building few extensions
|
||||
# in a row using this function
|
||||
name += "".join(random.sample(string.ascii_letters, 5))
|
||||
module = torch.utils.cpp_extension.load(
|
||||
name="torch_test_xpu_extension",
|
||||
name=name,
|
||||
sources=[
|
||||
"cpp_extensions/xpu_extension.sycl",
|
||||
],
|
||||
|
|
@ -136,6 +140,31 @@ class TestCppExtensionJIT(common.TestCase):
|
|||
# 2 * sigmoid(0) = 2 * 0.5 = 1
|
||||
self.assertEqual(z, torch.ones_like(z))
|
||||
|
||||
@unittest.skipUnless(TEST_XPU, "XPU not found")
def test_jit_xpu_extension(self):
    # NOTE: unlike test_jit_xpu_archlists, this test does not override
    # TORCH_XPU_ARCH_LIST, so a value set in the environment can affect it.
    self._test_jit_xpu_extension()
|
||||
|
||||
@unittest.skipIf(not (TEST_XPU), "XPU not found")
def test_jit_xpu_archlists(self):
    # NOTE: this test explicitly exercises a few different values of
    # TORCH_XPU_ARCH_LIST. Setting TORCH_XPU_ARCH_LIST in the environment
    # before the test won't affect it.
    archlists = [
        "",  # expecting JIT compilation
        ",".join(torch.xpu.get_arch_list()),
    ]
    old_envvar = os.environ.get("TORCH_XPU_ARCH_LIST", None)
    try:
        for al in archlists:
            # subTest labels a failure with the arch list that triggered it.
            with self.subTest(archlist=al):
                os.environ["TORCH_XPU_ARCH_LIST"] = al
                self._test_jit_xpu_extension()
    finally:
        # Restore the variable to its state before the test. pop() is given
        # a default so that cleanup can never raise KeyError and mask a
        # failure coming from the loop above.
        if old_envvar is None:
            os.environ.pop("TORCH_XPU_ARCH_LIST", None)
        else:
            os.environ["TORCH_XPU_ARCH_LIST"] = old_envvar
|
||||
|
||||
@unittest.skipIf(not TEST_MPS, "MPS not found")
|
||||
def test_mps_extension(self):
|
||||
module = torch.utils.cpp_extension.load(
|
||||
|
|
|
|||
|
|
@ -291,16 +291,25 @@ def _get_sycl_arch_list():
|
|||
# If arch list returned by _get_sycl_arch_list() is empty, then sycl kernels will be compiled
|
||||
# for default spir64 target and avoid device specific compilations entirely. Further, kernels
|
||||
# will be JIT compiled at runtime.
|
||||
def _get_sycl_target_flags():
    """Return the ``-fsycl-targets`` flag for the current SYCL arch list.

    ``_get_sycl_arch_list()`` is queried on every call, so the result follows
    whatever that function reports at the time of the call.

    Returns:
        ``['-fsycl-targets=spir64_gen,spir64']`` when the arch list is
        non-empty, otherwise an empty list.
    """
    if _get_sycl_arch_list() != '':
        return ['-fsycl-targets=spir64_gen,spir64']
    # Return no flag at all rather than [''] so that callers extending their
    # flag list with the result never add an empty-string argument.
    return []
|
||||
|
||||
def _get_sycl_device_flags():
    """Return the ``-Xs "-device ..."`` flag for the current SYCL arch list.

    ``_get_sycl_arch_list()`` is queried on every call, so the result follows
    whatever that function reports at the time of the call.

    Returns:
        A one-element list holding ``-Xs "-device <arch list>"`` when the
        arch list is non-empty, otherwise an empty list.
    """
    arch_list = _get_sycl_arch_list()
    if arch_list != '':
        return [f'-Xs "-device {arch_list}"']
    # Return no flag at all rather than [''] so that callers extending their
    # flag list with the result never add an empty-string argument.
    return []
|
||||
|
||||
# Flags passed to every SYCL compile step. The arch-dependent
# `-fsycl-targets=...` flag is intentionally not listed here: it is produced
# at call time by _get_sycl_target_flags() instead of being frozen when this
# module is imported.
_COMMON_SYCL_FLAGS = [
    '-fsycl',
]

# Flags for the SYCL device-link step. The arch-dependent `-Xs "-device ..."`
# flag is produced at call time by _get_sycl_device_flags().
# NOTE: this is a shared module-level list. Copy it before extending, e.g.
# `flags = list(_SYCL_DLINK_FLAGS)`; `flags = _SYCL_DLINK_FLAGS` followed by
# `flags += ...` mutates this list in place.
_SYCL_DLINK_FLAGS = [
    *_COMMON_SYCL_FLAGS,
    '-fsycl-link',
    '--offload-compress',
]
|
||||
|
||||
JIT_EXTENSION_VERSIONER = ExtensionVersioner()
|
||||
|
|
@ -812,6 +821,7 @@ class BuildExtension(build_ext):
|
|||
sycl_dlink_post_cflags = None
|
||||
if with_sycl:
|
||||
sycl_cflags = extra_cc_cflags + common_cflags + _COMMON_SYCL_FLAGS
|
||||
sycl_cflags += _get_sycl_target_flags()
|
||||
if isinstance(extra_postargs, dict):
|
||||
sycl_post_cflags = extra_postargs['sycl']
|
||||
else:
|
||||
|
|
@ -829,6 +839,7 @@ class BuildExtension(build_ext):
|
|||
sycl_cflags = [shlex.quote(f) for f in sycl_cflags]
|
||||
sycl_cflags += _wrap_sycl_host_flags(host_cflags)
|
||||
sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS
|
||||
sycl_dlink_post_cflags += _get_sycl_device_flags()
|
||||
sycl_post_cflags = [shlex.quote(f) for f in sycl_post_cflags]
|
||||
|
||||
_write_ninja_file_and_compile_objects(
|
||||
|
|
@ -2688,6 +2699,7 @@ def _write_ninja_file_to_build_library(path,
|
|||
|
||||
if with_sycl:
|
||||
sycl_cflags = cflags + _COMMON_SYCL_FLAGS
|
||||
sycl_cflags += _get_sycl_target_flags()
|
||||
sycl_cflags += extra_sycl_cflags
|
||||
_append_sycl_std_if_no_std_present(sycl_cflags)
|
||||
host_cflags = cflags
|
||||
|
|
@ -2696,6 +2708,7 @@ def _write_ninja_file_to_build_library(path,
|
|||
host_cflags = ' '.join(host_cflags)
|
||||
sycl_cflags += _wrap_sycl_host_flags(host_cflags)
|
||||
sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS
|
||||
sycl_dlink_post_cflags += _get_sycl_device_flags()
|
||||
else:
|
||||
sycl_cflags = None
|
||||
sycl_dlink_post_cflags = None
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user