# Motivation
@jansel As discussed before, we want to generalize some CUDA-specific code. This makes inductor friendlier to third-party backends, so that they can leverage the inductor code as much as possible.
# Solution
To implement this, we introduce a device runtime abstraction. The device-specific runtime calls are wrapped inside `DeviceInterface`, and `register_interface_for_device` registers each device type with inductor. `get_interface_for_device` then fetches the corresponding runtime for a given device type. Usage looks like this:
```python
device_interface = get_interface_for_device("xpu")
device_interface.is_available()   # check whether XPU is available
device_interface.device_count()   # check how many XPU devices are available
```
`DeviceInterface` is a simple abstraction that enables third-party backends implementing CUDA-like semantics to integrate with inductor. It also keeps third-party backends from monkey-patching utility functions, such as `decode_device`, that were hard-coded for CUDA.
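For illustration, here is a minimal sketch of how a backend might plug into this abstraction. The `XPUInterface` class and its trivial method bodies are placeholders, and the exact set of methods `DeviceInterface` expects may vary between PyTorch versions; a real backend would forward to its own runtime:
```python
from torch._dynamo.device_interface import (
    DeviceInterface,
    get_interface_for_device,
    register_interface_for_device,
)


# Hypothetical backend interface: the class name and the trivial bodies are
# placeholders; a real backend would forward to its vendor runtime.
class XPUInterface(DeviceInterface):
    @staticmethod
    def is_available() -> bool:
        return True  # e.g. probe the vendor runtime here

    @staticmethod
    def device_count() -> int:
        return 1  # e.g. query the vendor runtime here


# Register the interface under its device type, then fetch it back.
register_interface_for_device("xpu", XPUInterface)
device_interface = get_interface_for_device("xpu")
assert device_interface.is_available()
assert device_interface.device_count() == 1
```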
# Additional Context
The main code changes:
- Make `AsyncCompile` device-agnostic so that other backends can leverage it.
- Make some utility functions device-agnostic to avoid monkey patches (see the sketch after this list).
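As an illustration of the second point, here is a hedged sketch of the kind of utility that used to hard-code CUDA. The body below is not inductor's actual `decode_device`, and `current_device()` is assumed to be part of the interface; the point is that the device index is resolved through the registered `DeviceInterface` rather than through `torch.cuda.*` directly:
```python
import torch

from torch._dynamo.device_interface import get_interface_for_device


def decode_device(device) -> torch.device:
    # Sketch only: inductor's real decode_device may differ in detail.
    if device is None:
        return torch.device("cpu")
    if isinstance(device, str):
        device = torch.device(device)
    if device.type != "cpu" and device.index is None:
        # Fill in the missing device index via the device-agnostic interface.
        device_interface = get_interface_for_device(device.type)
        return torch.device(device.type, index=device_interface.current_device())
    return device
```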
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109486
Approved by: https://github.com/jansel, https://github.com/jgong5, https://github.com/EikanWang
```python
import functools

from torch._dynamo.device_interface import get_interface_for_device


@functools.lru_cache(None)
def has_triton_package() -> bool:
    # True if the triton package can be imported.
    try:
        import triton

        return triton is not None
    except ImportError:
        return False


@functools.lru_cache(None)
def has_triton() -> bool:
    # Triton requires a CUDA-capable device with compute capability >= 7.0,
    # queried here through the device interface rather than torch.cuda.
    def is_cuda_compatible_with_triton():
        device_interface = get_interface_for_device("cuda")
        return (
            device_interface.is_available()
            and device_interface.Worker.get_device_properties().major >= 7
        )

    return is_cuda_compatible_with_triton() and has_triton_package()
```
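For context, a small usage sketch of the helper above; the calling code here is illustrative, not an actual inductor code path:
```python
# Gate Triton-dependent work on availability: has_triton() is True only when
# the triton package imports and the CUDA interface reports capability >= 7.0.
if has_triton():
    backend = "triton"    # safe to generate and compile Triton kernels
else:
    backend = "fallback"  # take a non-Triton path instead
```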