This adds Dynamo tracing support for the host-side Triton TMA API (see the `create_2d_tma_descriptor` calls on the host in the [Triton tutorial](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html#sphx-glr-getting-started-tutorials-09-persistent-matmul-py)). A few notes:

- Here we assume the availability of the host-side TMA API added to upstream Triton in https://github.com/triton-lang/triton/pull/4498. As of the time of writing, this is not part of the PT2 OSS Triton pin (although it has been back-ported internally). The OSS Triton pin update is expected in December 2024.
- To capture the chain of calls `t.data_ptr() --> create_{1d,2d}_tma_descriptor(ptr, ...) --> kernel[grid](tma_desc, ...)`, we add three new variable trackers: `DataPtrVariable`, `CreateTMADescriptorVariable` (for the function), and `TMADescriptorVariable` (for the TMA descriptor object). This maintains the path from the Triton kernel back to the Tensor from which the TMA descriptor was created. A hedged sketch of such a call chain is shown after these notes.
- The newly introduced variables have `reconstruct` methods, used in case of graph breaks.
- The `tma_descriptor_metadata` extracted from the captured `create_{1d,2d}_tma_descriptor` calls is propagated through the HOPs in Dynamo and AOTAutograd to be used by the downstream compiler (e.g., Inductor). See the unit tests for what the captured HOP arguments look like.
- In the Dynamo-captured fx graph, we replace the TMA descriptor arguments of the Triton kernel with the underlying Tensors, so that input/output relationships can be tracked in terms of Tensors.
- In the Triton kernel mutation analysis pass (in AOTAutograd), we use the `tt.experimental_descriptor_store` TTIR op to detect mutations of the underlying tensors via TMA descriptors, so that downstream AOTAutograd can perform functionalization as required.
- JIT Inductor and AOT Inductor support will be implemented in follow-up PRs.

Differential Revision: [D64404928](https://our.internmc.facebook.com/intern/diff/D64404928)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137677
Approved by: https://github.com/zou3519
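Below is a minimal sketch of the kind of host-side TMA call chain this PR teaches Dynamo to capture, written against the experimental `triton.tools.experimental_descriptor` API referenced above. It assumes a Triton build that ships the host-side TMA helpers (i.e., newer than the then-current OSS pin) and an SM90+ CUDA GPU; the names `add_one_tma_kernel` and `f` are illustrative and not taken from this PR.

```python
import torch
import triton
import triton.language as tl
from triton.tools.experimental_descriptor import create_1d_tma_descriptor


@triton.jit
def add_one_tma_kernel(in_desc_ptr, out_desc_ptr, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    off = pid * BLOCK
    # Device-side TMA load/store through the host-created descriptors.
    x = tl._experimental_descriptor_load(in_desc_ptr, [off], [BLOCK], tl.float32)
    tl._experimental_descriptor_store(out_desc_ptr, x + 1.0, [off])


def f(x):
    out = torch.empty_like(x)
    BLOCK = 128
    # The chain Dynamo captures:
    # t.data_ptr() --> create_1d_tma_descriptor(ptr, ...) --> kernel[grid](tma_desc, ...)
    in_desc = create_1d_tma_descriptor(x.data_ptr(), x.numel(), BLOCK, x.element_size())
    out_desc = create_1d_tma_descriptor(out.data_ptr(), out.numel(), BLOCK, out.element_size())
    grid = (triton.cdiv(x.numel(), BLOCK),)
    add_one_tma_kernel[grid](in_desc, out_desc, BLOCK=BLOCK)
    return out


x = torch.randn(1024, device="cuda")
y = torch.compile(f, fullgraph=True)(x)  # descriptor creation and kernel launch are traced
```

In the captured graph, the TMA descriptor arguments are replaced by the underlying Tensors (`x` and `out` here), and the `tma_descriptor_metadata` recorded from the `create_1d_tma_descriptor` calls is passed along to the downstream compiler.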
113 lines
3.0 KiB
Python
# mypy: allow-untyped-defs
import functools
import hashlib


@functools.lru_cache(None)
def has_triton_package() -> bool:
    # True if the `triton` package is importable and exposes a compiler key.
    try:
        from triton.compiler.compiler import triton_key

        return triton_key is not None
    except ImportError:
        return False
    except RuntimeError:
        return False


@functools.lru_cache(None)
def has_triton_tma():
    # True if the host-side TMA API is usable: requires CUDA (not ROCm) on an
    # SM90+ GPU and a Triton build that ships the experimental descriptor helpers.
    if has_triton_package():
        import torch

        if (
            torch.cuda.is_available()
            and torch.cuda.get_device_capability() >= (9, 0)
            and not torch.version.hip
        ):
            try:
                from triton.tools.experimental_descriptor import (  # noqa: F401
                    create_1d_tma_descriptor,
                    create_2d_tma_descriptor,
                )

                return True
            except ImportError:
                pass

    return False


@functools.lru_cache(None)
def has_triton() -> bool:
    # True if Triton is importable and at least one supported device backend
    # (CUDA with compute capability >= 7.0, XPU, or the CPU backend) is available.
    if not has_triton_package():
        return False

    from torch._dynamo.device_interface import get_interface_for_device

    def cuda_extra_check(device_interface):
        return device_interface.Worker.get_device_properties().major >= 7

    def cpu_extra_check(device_interface):
        import triton.backends

        return "cpu" in triton.backends.backends

    def _return_true(device_interface):
        return True

    triton_supported_devices = {
        "cuda": cuda_extra_check,
        "xpu": _return_true,
        "cpu": cpu_extra_check,
    }

    def is_device_compatible_with_triton():
        for device, extra_check in triton_supported_devices.items():
            device_interface = get_interface_for_device(device)
            if device_interface.is_available() and extra_check(device_interface):
                return True
        return False

    return is_device_compatible_with_triton()


@functools.lru_cache(None)
def triton_backend():
    # Backend object for the Triton target that is currently active.
    from triton.compiler.compiler import make_backend
    from triton.runtime.driver import driver

    target = driver.active.get_current_target()
    return make_backend(target)


@functools.lru_cache(None)
def triton_hash_with_backend():
    # Hash combining the Triton version key with the active backend's hash.
    from triton.compiler.compiler import triton_key

    backend = triton_backend()
    key = f"{triton_key()}-{backend.hash()}"

    # Hash is upper case so that it can't contain any Python keywords.
    return hashlib.sha256(key.encode("utf-8")).hexdigest().upper()


def dtype_to_string(dtype):
    # Map a Triton dtype to the fully qualified `triton.language.*` name,
    # e.g. fp32 -> "triton.language.float32".
    if dtype.name.startswith("fp"):
        suffix = "float" + dtype.name[2:]
    elif dtype.name.startswith("bf"):
        suffix = "bfloat" + dtype.name[2:]
    else:
        suffix = dtype.name
    return "triton.language." + suffix


def patch_triton_dtype_repr():
    import triton

    # Hack to get triton dtype repr to produce an evaluatable expression:
    # triton.language.float32 emits triton.language.fp32, which does not exist.
    # REMOVE when https://github.com/openai/triton/pull/3342 lands
    triton.language.dtype.__repr__ = lambda self: dtype_to_string(self)
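A brief, hypothetical usage sketch of the capability helpers defined above (assuming this module is importable as `torch.utils._triton`; `pick_backend` is an illustrative placeholder, not part of the file):

```python
from torch.utils._triton import has_triton, has_triton_tma, triton_hash_with_backend


def pick_backend():
    # Gate Triton-dependent code paths on availability of the package,
    # a compatible device, and (optionally) host-side TMA support.
    if not has_triton():
        return "eager"
    if has_triton_tma():
        # A TMA-based kernel path can be taken (CUDA SM90+, new-enough Triton).
        return "triton-tma"
    return "triton"


# The backend-aware hash can serve as part of a compilation cache key.
cache_key = triton_hash_with_backend() if has_triton() else None
```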