This diff/PR includes the changes to support native Inductor integration for MTIA. The goal is to support `torch.compile(backend="inductor")` for MTIA. Inductor should generate code (Triton kernels + Python wrapper code) similar to CUDA, and the Triton kernels can be launched eagerly.

The changes include:
- Add MTIA device interfaces used by Dynamo and Inductor, including APIs on device, stream, event, etc.
- Add required torch.mtia APIs, like is_bf16_supported, memory_allocated, set_stream_by_id, etc.
- MTIA-specific codegen logic, for example, loading the MTIA dynamic_library.
- Other necessary changes to integrate with Inductor codegen, following other devices like CUDA and XPU.
- Integrate with the [empty_strided_mtia](https://www.internalfb.com/code/fbsource/[0d017d3a4a1bdff7253f9c66a9f38e77bd62166b]/fbcode/caffe2/aten/src/ATen/native/mtia/EmptyTensor.cpp?lines=49%2C63%2C71%2C74%2C78) API that we've added for the new MTIA ATen backend.
- A change in the Inductor runtime to avoid re-initializing MTIADriver.
- BUCK changes to include ATen-mtia in Inductor, and to use the -USE_MTIA preprocessor flag.
- Update `test_mnist_e2e.py` to cover native Inductor as a backend, using the `--use_native_inductor` flag.
- Add a personal script (`scripts/anwang/run_native_inductor_script.py`) for testing purposes.

Note:
- This approach (option 3) aims to provide a PyTorch-native path for Inductor integration on MTIA, minimizing onboarding overhead. The downside is that it doesn't leverage MTIA-specific graph optimizations and incurs eager kernel-launch overhead.
- MTIA will support another approach (option 2), based on WrapperFxCodegen, to provide the best performance. We should be able to reuse the fundamental changes of this diff for option 2, like the device interfaces and stream/event APIs, especially since WrapperFxCodegen inherits from PythonWrapperCodegen.

Internal:

References:
- [post for context](https://fb.workplace.com/groups/mtiasw/permalink/1718377262384606/)
- [Inductor integration discussion (option 1/2/3)](https://docs.google.com/document/d/1p6363OXtVIRv1hPoaKlRSK3j-iir3QIbDd5bjyqCNig/edit?tab=t.0#heading=h.7s4ns6wcnhmb)
- [Project design doc (option 3)](https://docs.google.com/document/d/1jXUmhgoV9WvkMf-bcY3Od_kK9K_RDOdgHdt1LoQ5Tc4/edit?tab=t.0#heading=h.y43gwdqlv46w)
- [early prototyping diff](https://www.internalfb.com/diff/D75110196)
- [MPS integration PR](https://github.com/pytorch/pytorch/pull/153959)
- [empty_strided_xpu PR](https://github.com/pytorch/pytorch/pull/126678)

Differential Revision: [D78458745](https://our.internmc.facebook.com/intern/diff/D78458745/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158526
Approved by: https://github.com/blaine-rister, https://github.com/jansel, https://github.com/eellison
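For illustration, a minimal sketch of the intended entry point. This is not code from the PR: it assumes a PyTorch build with MTIA support and an available device, and the `torch.mtia.is_available()` check and `"mtia"` device string are inferred from the torch.mtia APIs mentioned in the description above.

```python
import torch


def fn(x, y):
    return torch.nn.functional.relu(x @ y)


# Assumes an MTIA-enabled build; guard on device availability before compiling.
if torch.mtia.is_available():
    x = torch.randn(64, 64, device="mtia")
    y = torch.randn(64, 64, device="mtia")
    # Native Inductor backend (option 3): generates Triton kernels plus a
    # Python wrapper, with kernels launched eagerly.
    compiled = torch.compile(fn, backend="inductor")
    out = compiled(x, y)
```

The Python file below is the Triton capability-detection helpers; note that its supported-device table includes an `"mtia"` entry, which is how `has_triton()` reports Triton support on MTIA.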
import functools
import hashlib

from typing import Any


@functools.cache
def has_triton_package() -> bool:
    try:
        import triton  # noqa: F401

        return True
    except ImportError:
        return False


@functools.cache
def _device_supports_tma() -> bool:
    import torch

    return (
        torch.cuda.is_available()
        and torch.cuda.get_device_capability() >= (9, 0)
        and not torch.version.hip
    )


@functools.cache
def has_triton_experimental_host_tma() -> bool:
    if has_triton_package():
        if _device_supports_tma():
            try:
                from triton.tools.experimental_descriptor import (  # noqa: F401
                    create_1d_tma_descriptor,
                    create_2d_tma_descriptor,
                )

                return True
            except ImportError:
                pass

    return False


@functools.cache
def has_triton_tensor_descriptor_host_tma() -> bool:
    if has_triton_package():
        if _device_supports_tma():
            try:
                from triton.tools.tensor_descriptor import (  # noqa: F401
                    TensorDescriptor,
                )

                return True
            except ImportError:
                pass

    return False


@functools.cache
def has_triton_tma() -> bool:
    return has_triton_tensor_descriptor_host_tma() or has_triton_experimental_host_tma()


@functools.cache
def has_triton_tma_device() -> bool:
    if has_triton_package():
        import torch

        if (
            torch.cuda.is_available()
            and torch.cuda.get_device_capability() >= (9, 0)
            and not torch.version.hip
        ):
            # old API
            try:
                from triton.language.extra.cuda import (  # noqa: F401
                    experimental_device_tensormap_create1d,
                    experimental_device_tensormap_create2d,
                )

                return True
            except ImportError:
                pass

            # new API
            try:
                from triton.language import make_tensor_descriptor  # noqa: F401

                return True
            except ImportError:
                pass

    return False


@functools.lru_cache(None)
def has_triton_stable_tma_api() -> bool:
    if has_triton_package():
        import torch

        if (
            torch.cuda.is_available()
            and torch.cuda.get_device_capability() >= (9, 0)
            and not torch.version.hip
        ):
            try:
                from triton.language import make_tensor_descriptor  # noqa: F401

                return True
            except ImportError:
                pass
    return False


@functools.cache
def has_triton() -> bool:
    if not has_triton_package():
        return False

    from torch._dynamo.device_interface import get_interface_for_device

    def cuda_extra_check(device_interface: Any) -> bool:
        return device_interface.Worker.get_device_properties().major >= 7

    def cpu_extra_check(device_interface: Any) -> bool:
        import triton.backends

        return "cpu" in triton.backends.backends

    def _return_true(device_interface: Any) -> bool:
        return True

    triton_supported_devices = {
        "cuda": cuda_extra_check,
        "xpu": _return_true,
        "cpu": cpu_extra_check,
        "mtia": _return_true,
    }

    def is_device_compatible_with_triton() -> bool:
        for device, extra_check in triton_supported_devices.items():
            device_interface = get_interface_for_device(device)
            if device_interface.is_available() and extra_check(device_interface):
                return True
        return False

    return is_device_compatible_with_triton()


@functools.cache
def triton_backend() -> Any:
    from triton.compiler.compiler import make_backend
    from triton.runtime.driver import driver

    target = driver.active.get_current_target()
    return make_backend(target)


@functools.cache
def triton_hash_with_backend() -> str:
    from torch._inductor.runtime.triton_compat import triton_key

    backend = triton_backend()
    key = f"{triton_key()}-{backend.hash()}"

    # Hash is upper case so that it can't contain any Python keywords.
    return hashlib.sha256(key.encode("utf-8")).hexdigest().upper()