xpu: get xpu arch flags at runtime in cpp_extensions (#152192)

This commit moves the query for XPU arch flags to runtime when building SYCL extensions, which makes it possible to adjust `TORCH_XPU_ARCH_LIST` at the Python script level. That's handy, for example, in a CI test that tries a few variants of the list.

CC: @malfet, @jingxu10, @EikanWang, @guangyey

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152192
Approved by: https://github.com/guangyey, https://github.com/gujinghui, https://github.com/albanD
This commit is contained in:
Dmitry Rogozhkin 2025-05-09 05:43:47 +00:00 committed by PyTorch MergeBot
parent 9fa07340fd
commit aca2c99a65
2 changed files with 48 additions and 6 deletions

View File

@ -3,8 +3,10 @@
import glob
import locale
import os
import random
import re
import shutil
import string
import subprocess
import sys
import tempfile
@ -116,11 +118,13 @@ class TestCppExtensionJIT(common.TestCase):
# 2 * sigmoid(0) = 2 * 0.5 = 1
self.assertEqual(z, torch.ones_like(z))
@unittest.skipIf(not (TEST_XPU), "XPU not found")
def test_jit_xpu_extension(self):
# NOTE: The name of the extension must equal the name of the module.
def _test_jit_xpu_extension(self):
name = "torch_test_xpu_extension_"
# randomizing name for the case when we test building few extensions
# in a row using this function
name += "".join(random.sample(string.ascii_letters, 5))
module = torch.utils.cpp_extension.load(
name="torch_test_xpu_extension",
name=name,
sources=[
"cpp_extensions/xpu_extension.sycl",
],
@ -136,6 +140,31 @@ class TestCppExtensionJIT(common.TestCase):
# 2 * sigmoid(0) = 2 * 0.5 = 1
self.assertEqual(z, torch.ones_like(z))
@unittest.skipIf(not TEST_XPU, "XPU not found")
def test_jit_xpu_extension(self):
    # Single build of the XPU extension through the shared helper.
    # NOTE: TORCH_XPU_ARCH_LIST is left untouched here, so whatever value is
    # set in the environment can affect this test.
    self._test_jit_xpu_extension()
@unittest.skipIf(not TEST_XPU, "XPU not found")
def test_jit_xpu_archlists(self):
    # Builds the XPU extension once per TORCH_XPU_ARCH_LIST value listed
    # below. The variable is overwritten before every build, so a value
    # exported before the test run does not affect this test; the previous
    # state of the variable is restored once all builds are done.
    env_key = "TORCH_XPU_ARCH_LIST"
    candidates = [
        "",  # expecting JIT compilation
        ",".join(torch.xpu.get_arch_list()),
    ]
    saved = os.environ.get(env_key, None)
    try:
        for candidate in candidates:
            os.environ[env_key] = candidate
            self._test_jit_xpu_extension()
    finally:
        # Put the environment back exactly as it was found: re-set the old
        # value, or drop the variable if it was not defined before.
        if saved is not None:
            os.environ[env_key] = saved
        else:
            os.environ.pop(env_key)
@unittest.skipIf(not TEST_MPS, "MPS not found")
def test_mps_extension(self):
module = torch.utils.cpp_extension.load(

View File

@ -291,16 +291,25 @@ def _get_sycl_arch_list():
# If the arch list returned by _get_sycl_arch_list() is empty, SYCL kernels are compiled
# for the default spir64 target and device-specific compilation is skipped entirely. In
# that case the kernels are JIT compiled at runtime.
def _get_sycl_target_flags():
    """Return the SYCL target-selection flag as a one-element list.

    The arch list is queried on every call. When it is non-empty the list
    holds ``-fsycl-targets=spir64_gen,spir64``; otherwise it holds a single
    empty string, i.e. no explicit target is requested.
    """
    has_arch_list = _get_sycl_arch_list() != ''
    return ['-fsycl-targets=spir64_gen,spir64'] if has_arch_list else ['']
def _get_sycl_device_flags():
    """Return the SYCL device-selection flag as a one-element list.

    The arch list is queried on every call. When it is non-empty the list
    holds an ``-Xs "-device <arch list>"`` option built from it; otherwise
    it holds a single empty string.
    """
    archs = _get_sycl_arch_list()
    if archs == '':
        return ['']
    return [f'-Xs "-device {archs}"']
# Flags that callers concatenate onto their compiler flags when building
# `sycl_cflags`; also spread into _SYCL_DLINK_FLAGS below.
_COMMON_SYCL_FLAGS = [
'-fsycl',
# NOTE(review): this entry calls _get_sycl_arch_list() while the list is being
# built, yet _get_sycl_target_flags() yields the same flag on demand and the
# callers append its result as well. This looks like a pre-change line kept
# in a diff-style view -- confirm that only one of the two is in the real file.
'-fsycl-targets=spir64_gen,spir64' if _get_sycl_arch_list() != '' else '',
]
# Flags for the SYCL device-link step: the common flags plus the link-specific
# options listed here.
# NOTE(review): callers bind this list to `sycl_dlink_post_cflags` and then
# extend that name with `+=`, which appends to this very list object in place,
# so the device flags pile up here across builds. Callers should work on a copy.
_SYCL_DLINK_FLAGS = [
*_COMMON_SYCL_FLAGS,
'-fsycl-link',
'--offload-compress',
# NOTE(review): same concern as above -- the device option is computed while
# the list is being built and duplicates what _get_sycl_device_flags() returns
# at call time; confirm whether this line still exists in the real file.
f'-Xs "-device {_get_sycl_arch_list()}"' if _get_sycl_arch_list() != '' else '',
]
# Shared ExtensionVersioner instance; presumably tracks versions of JIT-built
# extensions -- confirm against the ExtensionVersioner definition.
JIT_EXTENSION_VERSIONER = ExtensionVersioner()
@ -812,6 +821,7 @@ class BuildExtension(build_ext):
sycl_dlink_post_cflags = None
if with_sycl:
sycl_cflags = extra_cc_cflags + common_cflags + _COMMON_SYCL_FLAGS
sycl_cflags += _get_sycl_target_flags()
if isinstance(extra_postargs, dict):
sycl_post_cflags = extra_postargs['sycl']
else:
@ -829,6 +839,7 @@ class BuildExtension(build_ext):
sycl_cflags = [shlex.quote(f) for f in sycl_cflags]
sycl_cflags += _wrap_sycl_host_flags(host_cflags)
sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS
sycl_dlink_post_cflags += _get_sycl_device_flags()
sycl_post_cflags = [shlex.quote(f) for f in sycl_post_cflags]
_write_ninja_file_and_compile_objects(
@ -2688,6 +2699,7 @@ def _write_ninja_file_to_build_library(path,
if with_sycl:
sycl_cflags = cflags + _COMMON_SYCL_FLAGS
sycl_cflags += _get_sycl_target_flags()
sycl_cflags += extra_sycl_cflags
_append_sycl_std_if_no_std_present(sycl_cflags)
host_cflags = cflags
@ -2696,6 +2708,7 @@ def _write_ninja_file_to_build_library(path,
host_cflags = ' '.join(host_cflags)
sycl_cflags += _wrap_sycl_host_flags(host_cflags)
sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS
sycl_dlink_post_cflags += _get_sycl_device_flags()
else:
sycl_cflags = None
sycl_dlink_post_cflags = None