pytorch/torch/utils/_triton.py
Yu, Guangye e9c9b1ed59 [Inductor] Generalize inductor triton backend device agnostic (#109486)
# Motivation
@jansel As discussed before, we planned to generalize some CUDA-specific code. This makes inductor friendlier to third-party backends, letting them reuse as much inductor code as possible.

# Solution
To implement this, we introduce a device runtime abstraction. Each device's runtime is wrapped in a `DeviceInterface`, and `register_interface_for_device` registers each kind of device with inductor. `get_interface_for_device` then fetches the corresponding runtime for a given device type. Usage looks like this:
```python
from torch._dynamo.device_interface import get_interface_for_device

device_interface = get_interface_for_device("xpu")
device_interface.is_available()  # check whether XPU is available
device_interface.device_count()  # check how many XPU devices are available
```
`DeviceInterface` is a simple abstraction that lets third-party backends implementing CUDA-like semantics integrate with inductor. It also spares those backends from monkey-patching utility functions that are hard-coded for CUDA, such as `decode_device`.

# Additional Context
The main code changes:
- Make AsyncCompile device-agnostic so that other backends can leverage it
- Make some utility functions device-agnostic to avoid monkey patches
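The second change can be illustrated with a hypothetical `decode_device`-style helper. A CUDA-only version would resolve a bare device string such as `"cuda"` to a concrete index by calling a CUDA-specific API; the device-agnostic idea is to look up the runtime by device type instead. This sketch uses plain strings and a made-up `_RUNTIMES` table rather than `torch.device` and the real interface registry, so it shows only the shape of the change, not PyTorch's actual `decode_device`.

```python
from typing import Dict, Optional, Tuple


class _FakeRuntime:
    """Hypothetical runtime exposing only a current_device() query."""

    def __init__(self, current: int) -> None:
        self._current = current

    def current_device(self) -> int:
        return self._current


# Hypothetical table: device type -> runtime. A real backend would register itself instead.
_RUNTIMES: Dict[str, _FakeRuntime] = {
    "cuda": _FakeRuntime(0),
    "xpu": _FakeRuntime(1),
}


def decode_device(device: str) -> Tuple[str, Optional[int]]:
    """Resolve 'type' or 'type:index' into (type, index) without hard-coding CUDA."""
    dev_type, _, index = device.partition(":")
    if index:
        # An explicit index wins.
        return dev_type, int(index)
    if dev_type == "cpu":
        # CPU has no per-device index here.
        return dev_type, None
    # Ask the runtime registered for this device type, not a CUDA-only API.
    return dev_type, _RUNTIMES[dev_type].current_device()
```

For example, `decode_device("xpu")` resolves through the `xpu` runtime, while `decode_device("cuda:3")` keeps the explicit index.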

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109486
Approved by: https://github.com/jansel, https://github.com/jgong5, https://github.com/EikanWang
2023-09-24 07:49:20 +00:00


```python
import functools

from torch._dynamo.device_interface import get_interface_for_device


@functools.lru_cache(None)
def has_triton_package() -> bool:
    # True if the `triton` package can be imported; cached after the first call.
    try:
        import triton

        return triton is not None
    except ImportError:
        return False


@functools.lru_cache(None)
def has_triton() -> bool:
    def is_cuda_compatible_with_triton():
        # Query CUDA through the generic device interface rather than torch.cuda.
        device_interface = get_interface_for_device("cuda")
        return (
            device_interface.is_available()
            # Require compute capability 7.0 or newer.
            and device_interface.Worker.get_device_properties().major >= 7
        )

    return is_cuda_compatible_with_triton() and has_triton_package()
```
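Both `has_triton_package()` and `has_triton()` are wrapped in `functools.lru_cache(None)`, so the import attempt and the device query run at most once per process and later calls return the cached boolean. If a test needs to re-run the check (for example after mocking the device interface), the standard `cache_clear()` method that `lru_cache` attaches to the wrapper resets it. The self-contained sketch below demonstrates this with a hypothetical probe, `has_fake_package`, that imports a module name assumed not to exist instead of `triton`.

```python
import functools

probe_calls = 0


@functools.lru_cache(None)
def has_fake_package() -> bool:
    """Hypothetical probe mirroring the try/except import pattern of has_triton_package."""
    global probe_calls
    probe_calls += 1
    try:
        import a_package_that_should_not_exist_xyz  # noqa: F401 (hypothetical module name)

        return True
    except ImportError:
        return False


has_fake_package()
has_fake_package()
print(probe_calls)  # prints: 1 (the second call is served from the cache)

has_fake_package.cache_clear()  # drop the cached result
has_fake_package()
print(probe_calls)  # prints: 2 (the probe ran again after cache_clear())
```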