diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py
index 3d101519a23..ef4e5b5cb1e 100644
--- a/test/test_cpp_extensions_jit.py
+++ b/test/test_cpp_extensions_jit.py
@@ -3,8 +3,10 @@
 import glob
 import locale
 import os
+import random
 import re
 import shutil
+import string
 import subprocess
 import sys
 import tempfile
@@ -116,11 +118,13 @@ class TestCppExtensionJIT(common.TestCase):
         # 2 * sigmoid(0) = 2 * 0.5 = 1
         self.assertEqual(z, torch.ones_like(z))
 
-    @unittest.skipIf(not (TEST_XPU), "XPU not found")
-    def test_jit_xpu_extension(self):
-        # NOTE: The name of the extension must equal the name of the module.
+    def _test_jit_xpu_extension(self):
+        name = "torch_test_xpu_extension_"
+        # randomizing name for the case when we test building a few extensions
+        # in a row using this function
+        name += "".join(random.sample(string.ascii_letters, 5))
         module = torch.utils.cpp_extension.load(
-            name="torch_test_xpu_extension",
+            name=name,
             sources=[
                 "cpp_extensions/xpu_extension.sycl",
             ],
@@ -136,6 +140,31 @@ class TestCppExtensionJIT(common.TestCase):
         # 2 * sigmoid(0) = 2 * 0.5 = 1
         self.assertEqual(z, torch.ones_like(z))
 
+    @unittest.skipIf(not (TEST_XPU), "XPU not found")
+    def test_jit_xpu_extension(self):
+        # NOTE: this test can be affected by setting TORCH_XPU_ARCH_LIST
+        self._test_jit_xpu_extension()
+
+    @unittest.skipIf(not (TEST_XPU), "XPU not found")
+    def test_jit_xpu_archlists(self):
+        # NOTE: in this test we explicitly test a few different options
+        # for TORCH_XPU_ARCH_LIST. Setting TORCH_XPU_ARCH_LIST in the
+        # environment before the test won't affect it.
+        archlists = [
+            "",  # expecting JIT compilation
+            ",".join(torch.xpu.get_arch_list()),
+        ]
+        old_envvar = os.environ.get("TORCH_XPU_ARCH_LIST", None)
+        try:
+            for al in archlists:
+                os.environ["TORCH_XPU_ARCH_LIST"] = al
+                self._test_jit_xpu_extension()
+        finally:
+            if old_envvar is None:
+                os.environ.pop("TORCH_XPU_ARCH_LIST", None)  # never raise from cleanup
+            else:
+                os.environ["TORCH_XPU_ARCH_LIST"] = old_envvar
+
     @unittest.skipIf(not TEST_MPS, "MPS not found")
     def test_mps_extension(self):
         module = torch.utils.cpp_extension.load(
diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py
index 50601fa9b68..7697b9394a3 100644
--- a/torch/utils/cpp_extension.py
+++ b/torch/utils/cpp_extension.py
@@ -291,16 +291,25 @@ def _get_sycl_arch_list():
 # If arch list returned by _get_sycl_arch_list() is empty, then sycl kernels will be compiled
 # for default spir64 target and avoid device specific compilations entirely. Further, kernels
 # will be JIT compiled at runtime.
+def _get_sycl_target_flags():
+    if _get_sycl_arch_list() != '':
+        return ['-fsycl-targets=spir64_gen,spir64']
+    return []  # empty arch list: no extra targets, nothing to append
+
+def _get_sycl_device_flags():
+    arch_list = _get_sycl_arch_list()
+    if arch_list != '':
+        return [f'-Xs "-device {arch_list}"']
+    return []  # empty arch list: no device flag, nothing to append
+
 _COMMON_SYCL_FLAGS = [
     '-fsycl',
-    '-fsycl-targets=spir64_gen,spir64' if _get_sycl_arch_list() != '' else '',
 ]
 
 _SYCL_DLINK_FLAGS = [
     *_COMMON_SYCL_FLAGS,
     '-fsycl-link',
     '--offload-compress',
-    f'-Xs "-device {_get_sycl_arch_list()}"' if _get_sycl_arch_list() != '' else '',
 ]
 
 JIT_EXTENSION_VERSIONER = ExtensionVersioner()
@@ -812,6 +821,7 @@ class BuildExtension(build_ext):
             sycl_dlink_post_cflags = None
             if with_sycl:
                 sycl_cflags = extra_cc_cflags + common_cflags + _COMMON_SYCL_FLAGS
+                sycl_cflags += _get_sycl_target_flags()
                 if isinstance(extra_postargs, dict):
                     sycl_post_cflags = extra_postargs['sycl']
                 else:
@@ -829,6 +839,7 @@ class BuildExtension(build_ext):
                 sycl_cflags = [shlex.quote(f) for f in sycl_cflags]
                 sycl_cflags += _wrap_sycl_host_flags(host_cflags)
                 sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS
+                sycl_dlink_post_cflags = sycl_dlink_post_cflags + _get_sycl_device_flags()  # new list, not in-place
                 sycl_post_cflags = [shlex.quote(f) for f in sycl_post_cflags]
 
             _write_ninja_file_and_compile_objects(
@@ -2688,6 +2699,7 @@ def _write_ninja_file_to_build_library(path,
 
     if with_sycl:
         sycl_cflags = cflags + _COMMON_SYCL_FLAGS
+        sycl_cflags += _get_sycl_target_flags()
         sycl_cflags += extra_sycl_cflags
         _append_sycl_std_if_no_std_present(sycl_cflags)
         host_cflags = cflags
@@ -2696,6 +2708,7 @@ def _write_ninja_file_to_build_library(path,
         host_cflags = ' '.join(host_cflags)
         sycl_cflags += _wrap_sycl_host_flags(host_cflags)
         sycl_dlink_post_cflags = _SYCL_DLINK_FLAGS
+        sycl_dlink_post_cflags = sycl_dlink_post_cflags + _get_sycl_device_flags()  # new list, not in-place
     else:
         sycl_cflags = None
         sycl_dlink_post_cflags = None