[Intel GPU][pre_compile] Add XPU toolkit version and hardware info in compiled model check. (#162951)

Following #162438, this PR generalizes the original CUDA-only check and adds an XPU check.

Fixes #162939, Fixes #162938, Fixes #163032, Fixes #163045

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162951
Approved by: https://github.com/EikanWang, https://github.com/jansel
xinan.lin 2025-09-18 00:04:22 +00:00 committed by PyTorch MergeBot
parent 26eefd5ae2
commit e93706c2c8
5 changed files with 50 additions and 46 deletions
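
For context, "toolkit version" here is the per-backend build string PyTorch exposes: `torch.version.cuda` on CUDA builds and `torch.version.xpu` on XPU builds. A quick probe, hedged with `getattr` defaults since the `xpu` attributes are absent on older PyTorch builds:

```python
import torch

for device_type in ("cuda", "xpu"):
    backend = getattr(torch, device_type, None)          # torch.cuda / torch.xpu
    toolkit = getattr(torch.version, device_type, None)  # e.g. "12.4" for CUDA
    available = backend is not None and backend.is_available()
    print(f"{device_type}: available={available}, toolkit={toolkit}")
```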

View File

@@ -12,7 +12,7 @@ from typing import Any, Callable, Optional, TYPE_CHECKING

 import torch
 import torch.fx
-from torch._dynamo.graph_utils import _graph_uses_non_cpu
+from torch._dynamo.graph_utils import _graph_device_type
 from torch._dynamo.precompile_context import SystemInfo

 from . import convert_frame
@@ -59,12 +59,12 @@ class CompileArtifacts:
     original_code: types.CodeType
     closure: Optional[tuple[Any, ...]]
     source_info: "SourceInfo"
-    use_cuda: bool
+    device_type: str
     system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)

     def check_compatibility(self) -> None:
         current_system = SystemInfo.current()
-        current_system.check_compatibility(self.system_info, self.use_cuda)
+        current_system.check_compatibility(self.system_info, self.device_type)


 @dataclass
@@ -266,7 +266,7 @@ def aot_compile_fullgraph(
     backend_input.graph_module._backend_id = backend_input.backend_id  # type: ignore[assignment]
     output_graph = dynamo_output.tracer_output.output_graph
     assert output_graph is not None
-    use_cuda = _graph_uses_non_cpu(output_graph.current_tracer.graph)
+    device_type = _graph_device_type(output_graph.current_tracer.graph)
     import_sources = output_graph.import_sources
     with (
         torch._guards.tracing(TracingContext(backend_input.fake_mode)),
@@ -310,7 +310,7 @@ def aot_compile_fullgraph(
         original_code=fn.__code__,
         closure=fn.__closure__,
         source_info=source_info,
-        use_cuda=use_cuda,
+        device_type=device_type,
     )

     aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)

View File

@@ -1264,7 +1264,7 @@ def _compile(
             assert check_fn.guards_state is not None
             package.add_guarded_code(check_fn.guards_state, out_code)
             package.add_inlined_source(output.tracing_context.traced_code)
-            package.update_use_cuda(output.current_tracer.graph)
+            package.update_device_type(output.current_tracer.graph)

         compile_id_str = str(compile_id) if compile_id is not None else "Unknown"
         annotation_str = "Torch-Compiled Region: " + compile_id_str

View File

@@ -79,16 +79,16 @@ def _detect_cycles(
     return "no cycle detected"


-def _graph_uses_non_cpu(graph: Optional[Graph]) -> bool:
+def _graph_device_type(graph: Optional[Graph]) -> str:
     if graph is None:
-        return False
+        return "cpu"

-    def _is_non_cpu(x: Any) -> bool:
+    def _device_type(x: Any) -> str:
         if isinstance(x, torch.device):
-            return x.type != "cpu"
+            return x.type
         if isinstance(x, torch.Tensor):
-            return x.device.type != "cpu"
+            return x.device.type
-        return False
+        return "cpu"

     def _flatten_meta(node: Node, key: str) -> list[Any]:
         if key not in node.meta:
@@ -99,19 +99,18 @@ def _graph_uses_non_cpu(graph: Optional[Graph]) -> bool:
     for node in graph.nodes:
         for key in ("val", "example_value"):
             for obj in _flatten_meta(node, key):
-                if _is_non_cpu(obj):
-                    return True
+                return _device_type(obj)

         # Check for device conversions
         if node.op == "call_method":
-            if node.target == "cuda":
-                return True
-            if node.target == "to" and "cuda" in node.args:
-                return True
+            for gpu in ["cuda", "xpu"]:
+                if node.target == gpu:
+                    return gpu
+                if node.target == "to" and gpu in node.args:
+                    return gpu

         # Check args/kwargs for non-CPU device specs
         flat_args, _ = tree_flatten((node.args, node.kwargs))
         for obj in flat_args:
-            if _is_non_cpu(obj):
-                return True
+            return _device_type(obj)
-    return False
+    return "cpu"
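
A hedged usage sketch of the renamed helper: `torch.fx.symbolic_trace` records `x.cuda()` as a `call_method` node with target `"cuda"`, which the scanner above matches (`_graph_device_type` is a private API, subject to change):

```python
import torch
import torch.fx
from torch._dynamo.graph_utils import _graph_device_type


def f(x):
    # Explicit device conversion; traced as a call_method node with target "cuda".
    return x.cuda() + 1


graph = torch.fx.symbolic_trace(f).graph
print(_graph_device_type(graph))  # "cuda"
```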

View File

@@ -29,7 +29,7 @@ from typing_extensions import Never
 import torch
 import torch._inductor.package
 from torch._dynamo.exc import PackageError
-from torch._dynamo.graph_utils import _graph_uses_non_cpu
+from torch._dynamo.graph_utils import _graph_device_type
 from torch._dynamo.precompile_context import (
     PrecompileCacheArtifact,
     PrecompileContext,
@@ -308,7 +308,7 @@ def _get_code_source(code: types.CodeType) -> tuple[str, str]:
 class _DynamoCacheEntry:
     codes: list[_DynamoCodeCacheEntry]
     source_info: SourceInfo
-    use_cuda: bool
+    device_type: str
     system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)

     @property
@@ -318,7 +318,7 @@ class _DynamoCacheEntry:
     def check_versions(self) -> None:
         """Check if the current system is compatible with the system used to create this cache entry."""
         current_system_info = SystemInfo.current()
-        self.system_info.check_compatibility(current_system_info, self.use_cuda)
+        self.system_info.check_compatibility(current_system_info, self.device_type)


 @CacheArtifactFactory.register
@@ -407,8 +407,8 @@ class CompilePackage:
         self._current_entry: Optional[_DynamoCodeCacheEntry] = None
         self._installed_globals: dict[types.ModuleType, list[str]] = {}

-        # whether cuda is used
-        self._use_cuda = False
+        # device_type that the model was compiled with
+        self._device_type = "cpu"

         # For debugging/testing purpose only.
         self._cached_backends: dict[_BackendId, Any] = {}
@@ -553,8 +553,8 @@ class CompilePackage:
                 continue
             self._source_info.add_code(code)

-    def update_use_cuda(self, graph: Optional[torch.fx.Graph]) -> None:
-        self._use_cuda = _graph_uses_non_cpu(graph)
+    def update_device_type(self, graph: Optional[torch.fx.Graph]) -> None:
+        self._device_type = _graph_device_type(graph)

     def bypass_current_entry(self) -> None:
         assert self._current_entry is not None
@@ -694,7 +694,7 @@ class CompilePackage:
         return _DynamoCacheEntry(
             codes=list(self._codes.values()),
             source_info=self._source_info,
-            use_cuda=self._use_cuda,
+            device_type=self._device_type,
         )

     @staticmethod
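
Tying the package hunks together: during `_compile`, `update_device_type` records the traced graph's device type, and `cache_entry()` persists it in the `_DynamoCacheEntry`. A condensed stand-in for that bookkeeping (`_PackageSketch` is hypothetical, for illustration only):

```python
import torch
import torch.fx
from torch._dynamo.graph_utils import _graph_device_type


class _PackageSketch:
    """Hypothetical mirror of CompilePackage's device_type bookkeeping."""

    def __init__(self) -> None:
        self._device_type = "cpu"  # default until a graph is recorded

    def update_device_type(self, graph) -> None:
        self._device_type = _graph_device_type(graph)


def f(x):
    return x.xpu()  # traced as a call_method node with target "xpu"


pkg = _PackageSketch()
pkg.update_device_type(torch.fx.symbolic_trace(f).graph)
print(pkg._device_type)  # "xpu"
```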

View File

@@ -259,31 +259,36 @@ class SystemInfo:
     python_version: str
     torch_version: str
-    cuda_version: Optional[str]
+    toolkit_version: Optional[str]
     triton_version: Optional[tuple[int, int]]
     gpu_name: Optional[str]

+    CHECK_GPUS = ("cuda", "xpu")
+
     @classmethod
     def current(cls) -> "SystemInfo":
         """Create a SystemInfo instance with current system information."""
-        # Get GPU name if CUDA is available
-        gpu_name = None
-        if torch.cuda.is_available():
+        # Get GPU name if CUDA or XPU is available
+        gpu_name, toolkit_version = None, None
+        for device_type in cls.CHECK_GPUS:
+            if getattr(torch, device_type).is_available():
                 try:
-                    gpu_name = torch.cuda.get_device_name()
+                    gpu_name = getattr(torch, device_type).get_device_name()
+                    toolkit_version = getattr(torch.version, device_type)
+                    break
                 except Exception:
-                    # If we can't get GPU info, leave as None
                     pass

         return cls(
             python_version=platform.python_version(),
             torch_version=torch.__version__,
-            cuda_version=torch.version.cuda,
+            toolkit_version=toolkit_version,
             triton_version=get_triton_version((0, 0)),
             gpu_name=gpu_name,
         )

-    def check_compatibility(self, other: "SystemInfo", use_cuda: bool = False) -> None:
+    def check_compatibility(
+        self, other: "SystemInfo", device_type: str = "cpu"
+    ) -> None:
         """
         Check if this SystemInfo is compatible with another SystemInfo.
         Raises RuntimeError if incompatible.
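
Since `SystemInfo` is a dataclass with a `current()` classmethod, the renamed field is easy to inspect; a quick look, assuming the module path from the import hunk above (private API):

```python
from torch._dynamo.precompile_context import SystemInfo

info = SystemInfo.current()
# On an XPU machine, toolkit_version now carries torch.version.xpu and
# gpu_name the Intel device name; on CUDA machines, the CUDA equivalents.
print(info.toolkit_version, info.gpu_name, info.triton_version)
```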
@@ -297,13 +302,13 @@ class SystemInfo:
             raise RuntimeError(
                 f"Compile package was created with a different PyTorch version: {self.torch_version}"
             )

-        if use_cuda:
-            if not torch.cuda.is_available():
-                raise RuntimeError("CUDA is not available")
-            if self.cuda_version != other.cuda_version:
+        if device_type in self.CHECK_GPUS:
+            if not getattr(torch, device_type).is_available():
+                raise RuntimeError(f"{device_type} is not available")
+            if self.toolkit_version != other.toolkit_version:
                 raise RuntimeError(
-                    f"Compile package was created with a different CUDA version: {self.cuda_version}"
+                    f"Compile package was created with a different toolkit version: {self.toolkit_version}"
                 )

         if (
@@ -314,7 +319,7 @@ class SystemInfo:
                 f"Compile package was created with a different Triton version: {self.triton_version}"
             )

-        # Check GPU name if CUDA was used
+        # Check GPU name if CUDA/XPU was used
         if other.gpu_name is not None and self.gpu_name != other.gpu_name:
             raise RuntimeError(
                 f"Compile package was created with different GPU: "