xpu: get xpu arch flags at runtime in cpp_extensions (#152192)

This commit moves query for xpu arch flags to runtime when building SYCL extensions which allows to adjust `TORCH_XPU_ARCH_LIST` at python script level. That's handy for example in ci test which gives a try few variants of the list.

CC: @malfet, @jingxu10, @EikanWang, @guangyey

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152192
Approved by: https://github.com/guangyey, https://github.com/gujinghui, https://github.com/albanD
This commit is contained in:
Dmitry Rogozhkin 2025-05-09 05:43:47 +00:00 committed by PyTorch MergeBot
parent 9fa07340fd
commit aca2c99a65
2 changed files with 48 additions and 6 deletions

View File

@ -3,8 +3,10 @@
import glob import glob
import locale import locale
import os import os
import random
import re import re
import shutil import shutil
import string
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
@ -116,11 +118,13 @@ class TestCppExtensionJIT(common.TestCase):
# 2 * sigmoid(0) = 2 * 0.5 = 1 # 2 * sigmoid(0) = 2 * 0.5 = 1
self.assertEqual(z, torch.ones_like(z)) self.assertEqual(z, torch.ones_like(z))
@unittest.skipIf(not (TEST_XPU), "XPU not found") def _test_jit_xpu_extension(self):
def test_jit_xpu_extension(self): name = "torch_test_xpu_extension_"
# NOTE: The name of the extension must equal the name of the module. # randomizing name for the case when we test building few extensions
# in a row using this function
name += "".join(random.sample(string.ascii_letters, 5))
module = torch.utils.cpp_extension.load( module = torch.utils.cpp_extension.load(
name="torch_test_xpu_extension", name=name,
sources=[ sources=[
"cpp_extensions/xpu_extension.sycl", "cpp_extensions/xpu_extension.sycl",
], ],
@ -136,6 +140,31 @@ class TestCppExtensionJIT(common.TestCase):
# 2 * sigmoid(0) = 2 * 0.5 = 1 # 2 * sigmoid(0) = 2 * 0.5 = 1
self.assertEqual(z, torch.ones_like(z)) self.assertEqual(z, torch.ones_like(z))
@unittest.skipIf(not (TEST_XPU), "XPU not found")
def test_jit_xpu_extension(self):
# NOTE: this test can be affected by setting TORCH_XPU_ARCH_LIST
self._test_jit_xpu_extension()
@unittest.skipIf(not (TEST_XPU), "XPU not found")
def test_jit_xpu_archlists(self):
# NOTE: in this test we explicitly test few different options
# for TORCH_XPU_ARCH_LIST. Setting TORCH_XPU_ARCH_LIST in the
# environment before the test won't affect it.
archlists = [
"", # expecting JIT compilation
",".join(torch.xpu.get_arch_list()),
]
old_envvar = os.environ.get("TORCH_XPU_ARCH_LIST", None)
try:
for al in archlists:
os.environ["TORCH_XPU_ARCH_LIST"] = al
self._test_jit_xpu_extension()
finally:
if old_envvar is None:
os.environ.pop("TORCH_XPU_ARCH_LIST")
else:
os.environ["TORCH_XPU_ARCH_LIST"] = old_envvar
@unittest.skipIf(not TEST_MPS, "MPS not found") @unittest.skipIf(not TEST_MPS, "MPS not found")
def test_mps_extension(self): def test_mps_extension(self):
module = torch.utils.cpp_extension.load( module = torch.utils.cpp_extension.load(

View File

@ -291,16 +291,25 @@ def _get_sycl_arch_list():
# If arch list returned by _get_sycl_arch_list() is empty, then sycl kernels will be compiled # If arch list returned by _get_sycl_arch_list() is empty, then sycl kernels will be compiled
# for default spir64 target and avoid device specific compilations entirely. Further, kernels # for default spir64 target and avoid device specific compilations entirely. Further, kernels
# will be JIT compiled at runtime. # will be JIT compiled at runtime.
def _get_sycl_target_flags():
if _get_sycl_arch_list() != '':
return ['-fsycl-targets=spir64_gen,spir64']
return ['']
def _get_sycl_device_flags():
arch_list = _get_sycl_arch_list()
if arch_list != '':
return [f'-Xs "-device {arch_list}"']
return ['']
_COMMON_SYCL_FLAGS = [ _COMMON_SYCL_FLAGS = [
'-fsycl', '-fsycl',
'-fsycl-targets=spir64_gen,spir64' if _get_sycl_arch_list() != '' else '',
] ]
_SYCL_DLINK_FLAGS = [ _SYCL_DLINK_FLAGS = [
*_COMMON_SYCL_FLAGS, *_COMMON_SYCL_FLAGS,
'-fsycl-link', '-fsycl-link',
'--offload-compress', '--offload-compress',
f'-Xs "-device {_get_sycl_arch_list()}"' if _get_sycl_arch_list() != '' else '',
] ]
JIT_EXTENSION_VERSIONER = ExtensionVersioner() JIT_EXTENSION_VERSIONER = ExtensionVersioner()
@ -812,6 +821,7 @@ class BuildExtension(build_ext):
sycl_dlink_post_cflags = None sycl_dlink_post_cflags = None
if with_sycl: if with_sycl:
sycl_cflags = extra_cc_cflags + common_cflags + _COMMON_SYCL_FLAGS sycl_cflags = extra_cc_cflags + common_cflags + _COMMON_SYCL_FLAGS
sycl_cflags += _get_sycl_target_flags()
if isinstance(extra_postargs, dict): if isinstance(extra_postargs, dict):
sycl_post_cflags = extra_postargs['sycl'] sycl_post_cflags = extra_postargs['sycl']
else: else:
@ -829,6 +839,7 @@ class BuildExtension(build_ext):
sycl_cflags = [shlex.quote(f) for f in sycl_cflags] sycl_cflags = [shlex.quote(f) for f in sycl_cflags]
sycl_cflags += _wrap_sycl_host_flags(host_cflags) sycl_cflags += _wrap_sycl_host_flags(host_cflags)
sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS
sycl_dlink_post_cflags += _get_sycl_device_flags()
sycl_post_cflags = [shlex.quote(f) for f in sycl_post_cflags] sycl_post_cflags = [shlex.quote(f) for f in sycl_post_cflags]
_write_ninja_file_and_compile_objects( _write_ninja_file_and_compile_objects(
@ -2688,6 +2699,7 @@ def _write_ninja_file_to_build_library(path,
if with_sycl: if with_sycl:
sycl_cflags = cflags + _COMMON_SYCL_FLAGS sycl_cflags = cflags + _COMMON_SYCL_FLAGS
sycl_cflags += _get_sycl_target_flags()
sycl_cflags += extra_sycl_cflags sycl_cflags += extra_sycl_cflags
_append_sycl_std_if_no_std_present(sycl_cflags) _append_sycl_std_if_no_std_present(sycl_cflags)
host_cflags = cflags host_cflags = cflags
@ -2696,6 +2708,7 @@ def _write_ninja_file_to_build_library(path,
host_cflags = ' '.join(host_cflags) host_cflags = ' '.join(host_cflags)
sycl_cflags += _wrap_sycl_host_flags(host_cflags) sycl_cflags += _wrap_sycl_host_flags(host_cflags)
sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS
sycl_dlink_post_cflags += _get_sycl_device_flags()
else: else:
sycl_cflags = None sycl_cflags = None
sycl_dlink_post_cflags = None sycl_dlink_post_cflags = None