[Intel GPU][pre_compile] Add XPU toolkit version and hardware info in compiled model check. (#162951)
Following #162438, this PR generalizes the original CUDA-only check and adds an XPU check.
Fixes #162939, Fixes #162938, Fixes #163032, Fixes #163045
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162951
Approved by: https://github.com/EikanWang, https://github.com/jansel
This commit is contained in: parent 26eefd5ae2, commit e93706c2c8
@@ -12,7 +12,7 @@ from typing import Any, Callable, Optional, TYPE_CHECKING
 import torch
 import torch.fx
-from torch._dynamo.graph_utils import _graph_uses_non_cpu
+from torch._dynamo.graph_utils import _graph_device_type
 from torch._dynamo.precompile_context import SystemInfo

 from . import convert_frame

@@ -59,12 +59,12 @@ class CompileArtifacts:
     original_code: types.CodeType
     closure: Optional[tuple[Any, ...]]
     source_info: "SourceInfo"
-    use_cuda: bool
+    device_type: str
     system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)

     def check_compatibility(self) -> None:
         current_system = SystemInfo.current()
-        current_system.check_compatibility(self.system_info, self.use_cuda)
+        current_system.check_compatibility(self.system_info, self.device_type)


 @dataclass

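For orientation, here is a minimal, self-contained sketch of the pattern this dataclass implements, using the SystemInfo API confirmed by the imports above. The `Artifacts` class is a hypothetical stand-in, not the real CompileArtifacts (which additionally stores the original code, closure, and source info shown in the hunk):

```python
import dataclasses

from torch._dynamo.precompile_context import SystemInfo


@dataclasses.dataclass
class Artifacts:  # illustrative stand-in for CompileArtifacts
    device_type: str  # "cpu", "cuda", or "xpu", inferred from the traced graph
    system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)

    def check_compatibility(self) -> None:
        # Compare the loading machine against the snapshot taken at compile time;
        # the device_type string decides whether the GPU-specific checks run.
        SystemInfo.current().check_compatibility(self.system_info, self.device_type)


# A CPU-only artifact loads anywhere with matching Python/torch/Triton versions.
Artifacts(device_type="cpu").check_compatibility()
```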
@@ -266,7 +266,7 @@ def aot_compile_fullgraph(
     backend_input.graph_module._backend_id = backend_input.backend_id  # type: ignore[assignment]
     output_graph = dynamo_output.tracer_output.output_graph
     assert output_graph is not None
-    use_cuda = _graph_uses_non_cpu(output_graph.current_tracer.graph)
+    device_type = _graph_device_type(output_graph.current_tracer.graph)
     import_sources = output_graph.import_sources
     with (
         torch._guards.tracing(TracingContext(backend_input.fake_mode)),

@@ -310,7 +310,7 @@ def aot_compile_fullgraph(
         original_code=fn.__code__,
         closure=fn.__closure__,
         source_info=source_info,
-        use_cuda=use_cuda,
+        device_type=device_type,
     )
     aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)

@@ -1264,7 +1264,7 @@ def _compile(
         assert check_fn.guards_state is not None
         package.add_guarded_code(check_fn.guards_state, out_code)
         package.add_inlined_source(output.tracing_context.traced_code)
-        package.update_use_cuda(output.current_tracer.graph)
+        package.update_device_type(output.current_tracer.graph)

     compile_id_str = str(compile_id) if compile_id is not None else "Unknown"
     annotation_str = "Torch-Compiled Region: " + compile_id_str

@@ -79,16 +79,16 @@ def _detect_cycles(
     return "no cycle detected"


-def _graph_uses_non_cpu(graph: Optional[Graph]) -> bool:
+def _graph_device_type(graph: Optional[Graph]) -> str:
     if graph is None:
-        return False
+        return "cpu"

-    def _is_non_cpu(x: Any) -> bool:
+    def _device_type(x: Any) -> str:
         if isinstance(x, torch.device):
-            return x.type != "cpu"
+            return x.type
         if isinstance(x, torch.Tensor):
-            return x.device.type != "cpu"
-        return False
+            return x.device.type
+        return "cpu"

     def _flatten_meta(node: Node, key: str) -> list[Any]:
         if key not in node.meta:

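As a quick illustration of the per-object mapping above (the full graph scan follows in the next hunk), here is a standalone sketch; `device_type_of` is a hypothetical name and the logic mirrors the new `_device_type` helper:

```python
import torch
from typing import Any


def device_type_of(x: Any) -> str:  # hypothetical name; mirrors _device_type above
    if isinstance(x, torch.device):
        return x.type
    if isinstance(x, torch.Tensor):
        return x.device.type
    return "cpu"  # anything that is not a tensor or device counts as CPU


print(device_type_of(torch.device("xpu")))  # "xpu" (building the device needs no XPU runtime)
print(device_type_of(torch.ones(2)))        # "cpu"
print(device_type_of("xpu"))                # "cpu" - a plain string is not a device spec
```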
@@ -99,19 +99,18 @@ def _graph_uses_non_cpu(graph: Optional[Graph]) -> bool:
     for node in graph.nodes:
         for key in ("val", "example_value"):
             for obj in _flatten_meta(node, key):
-                if _is_non_cpu(obj):
-                    return True
+                return _device_type(obj)

         # Check for device conversions
         if node.op == "call_method":
-            if node.target == "cuda":
-                return True
-            if node.target == "to" and "cuda" in node.args:
-                return True
+            for gpu in ["cuda", "xpu"]:
+                if node.target == gpu:
+                    return gpu
+                if node.target == "to" and gpu in node.args:
+                    return gpu

         # Check args/kwargs for non-CPU device specs
         flat_args, _ = tree_flatten((node.args, node.kwargs))
         for obj in flat_args:
-            if _is_non_cpu(obj):
-                return True
-    return False
+            return _device_type(obj)
+    return "cpu"

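A rough usage sketch of the renamed helper, assuming a PyTorch build that already contains this change (the import path is the one used by the hunks above). An explicit `.to("xpu")` in the traced function is picked up by the call_method branch:

```python
import torch
import torch.fx

from torch._dynamo.graph_utils import _graph_device_type  # added by this commit


def f(x):
    # the device conversion shows up as a call_method node with "xpu" in node.args
    return x.to("xpu") + 1


gm = torch.fx.symbolic_trace(f)
print(_graph_device_type(gm.graph))  # expected: "xpu"
print(_graph_device_type(None))      # "cpu" - no graph means no GPU use
```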
@@ -29,7 +29,7 @@ from typing_extensions import Never
 import torch
 import torch._inductor.package
 from torch._dynamo.exc import PackageError
-from torch._dynamo.graph_utils import _graph_uses_non_cpu
+from torch._dynamo.graph_utils import _graph_device_type
 from torch._dynamo.precompile_context import (
     PrecompileCacheArtifact,
     PrecompileContext,

@@ -308,7 +308,7 @@ def _get_code_source(code: types.CodeType) -> tuple[str, str]:
 class _DynamoCacheEntry:
     codes: list[_DynamoCodeCacheEntry]
     source_info: SourceInfo
-    use_cuda: bool
+    device_type: str
     system_info: SystemInfo = dataclasses.field(default_factory=SystemInfo.current)

     @property

@@ -318,7 +318,7 @@ class _DynamoCacheEntry:
     def check_versions(self) -> None:
         """Check if the current system is compatible with the system used to create this cache entry."""
         current_system_info = SystemInfo.current()
-        self.system_info.check_compatibility(current_system_info, self.use_cuda)
+        self.system_info.check_compatibility(current_system_info, self.device_type)


 @CacheArtifactFactory.register

@@ -407,8 +407,8 @@ class CompilePackage:
         self._current_entry: Optional[_DynamoCodeCacheEntry] = None
         self._installed_globals: dict[types.ModuleType, list[str]] = {}
-        # whether cuda is used
-        self._use_cuda = False
+        # device_type that model compiled with.
+        self._device_type = "cpu"

         # For debugging/testing purpose only.
         self._cached_backends: dict[_BackendId, Any] = {}

@@ -553,8 +553,8 @@ class CompilePackage:
                 continue
             self._source_info.add_code(code)

-    def update_use_cuda(self, graph: Optional[torch.fx.Graph]) -> None:
-        self._use_cuda = _graph_uses_non_cpu(graph)
+    def update_device_type(self, graph: Optional[torch.fx.Graph]) -> None:
+        self._device_type = _graph_device_type(graph)

     def bypass_current_entry(self) -> None:
         assert self._current_entry is not None

@@ -694,7 +694,7 @@ class CompilePackage:
         return _DynamoCacheEntry(
             codes=list(self._codes.values()),
             source_info=self._source_info,
-            use_cuda=self._use_cuda,
+            device_type=self._device_type,
         )

     @staticmethod

@@ -259,31 +259,36 @@ class SystemInfo:
     python_version: str
     torch_version: str
-    cuda_version: Optional[str]
+    toolkit_version: Optional[str]
     triton_version: Optional[tuple[int, int]]
     gpu_name: Optional[str]
+    CHECK_GPUS = ("cuda", "xpu")

     @classmethod
     def current(cls) -> "SystemInfo":
         """Create a SystemInfo instance with current system information."""
-        # Get GPU name if CUDA is available
-        gpu_name = None
-        if torch.cuda.is_available():
-            try:
-                gpu_name = torch.cuda.get_device_name()
-            except Exception:
-                # If we can't get GPU info, leave as None
-                pass
+        # Get GPU name if CUDA or XPU is available
+        gpu_name, toolkit_version = None, None
+        for device_type in cls.CHECK_GPUS:
+            if getattr(torch, device_type).is_available():
+                try:
+                    gpu_name = getattr(torch, device_type).get_device_name()
+                    toolkit_version = getattr(torch.version, device_type)
+                    break
+                except Exception:
+                    pass

         return cls(
             python_version=platform.python_version(),
             torch_version=torch.__version__,
-            cuda_version=torch.version.cuda,
+            toolkit_version=toolkit_version,
             triton_version=get_triton_version((0, 0)),
             gpu_name=gpu_name,
         )

-    def check_compatibility(self, other: "SystemInfo", use_cuda: bool = False) -> None:
+    def check_compatibility(
+        self, other: "SystemInfo", device_type: str = "cpu"
+    ) -> None:
         """
         Check if this SystemInfo is compatible with another SystemInfo.
         Raises RuntimeError if incompatible.

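The probing loop above relies only on public torch attributes. Below is a standalone sketch of the same pattern; it assumes torch.cuda/torch.xpu expose is_available() and get_device_name() and that torch.version carries a matching toolkit string (cuda or xpu), which may be absent or None on other builds:

```python
import torch

gpu_name, toolkit_version = None, None
for device_type in ("cuda", "xpu"):
    backend = getattr(torch, device_type, None)
    if backend is not None and backend.is_available():
        try:
            gpu_name = backend.get_device_name()
            toolkit_version = getattr(torch.version, device_type, None)
            break
        except Exception:
            pass  # leave both as None if the runtime cannot be queried

print(gpu_name, toolkit_version)  # (None, None) on a CPU-only machine
```

For CUDA this reads torch.version.cuda (the CUDA toolkit version); for XPU it reads the corresponding torch.version.xpu entry, which is what the new toolkit_version field stores.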
@@ -297,13 +302,13 @@ class SystemInfo:
             raise RuntimeError(
                 f"Compile package was created with a different PyTorch version: {self.torch_version}"
             )
-
-        if use_cuda:
-            if not torch.cuda.is_available():
-                raise RuntimeError("CUDA is not available")
-            if self.cuda_version != other.cuda_version:
+        if device_type in self.CHECK_GPUS:
+            if not getattr(torch, device_type).is_available():
+                raise RuntimeError(f"{device_type} is not available")
+
+            if self.toolkit_version != other.toolkit_version:
                 raise RuntimeError(
-                    f"Compile package was created with a different CUDA version: {self.cuda_version}"
+                    f"Compile package was created with a different toolkit version: {self.toolkit_version}"
                 )

         if (

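A hedged, test-style sketch of the GPU branch above: clone the current SystemInfo but pretend the package was built against a different toolkit, mirroring the direction used by _DynamoCacheEntry.check_versions. Field names follow the diff; the "0.0" toolkit string is deliberately bogus. On a machine without CUDA the availability check fires first, otherwise the toolkit mismatch is reported:

```python
import torch

from torch._dynamo.precompile_context import SystemInfo

current = SystemInfo.current()
saved = SystemInfo(
    python_version=current.python_version,
    torch_version=current.torch_version,
    toolkit_version="0.0",               # bogus toolkit recorded at package-build time
    triton_version=current.triton_version,
    gpu_name=current.gpu_name,
)

try:
    # stored snapshot checked against the loading machine, as in check_versions
    saved.check_compatibility(current, device_type="cuda")
except RuntimeError as e:
    print(f"precompiled package rejected: {e}")
```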
@@ -314,7 +319,7 @@ class SystemInfo:
                 f"Compile package was created with a different Triton version: {self.triton_version}"
             )

-        # Check GPU name if CUDA was used
+        # Check GPU name if CUDA/XPU was used
         if other.gpu_name is not None and self.gpu_name != other.gpu_name:
             raise RuntimeError(
                 f"Compile package was created with different GPU: "