This diff/PR includes the changes to support native Inductor integration for MTIA. The goal is to support `torch.compile(backend="inductor")` for MTIA: Inductor should generate code (Triton kernels plus Python wrapper code) similar to what it generates for CUDA, and the Triton kernels can be launched eagerly (a minimal usage sketch follows this description).

The changes include:
- Add MTIA device interfaces used by Dynamo and Inductor, including APIs on device, stream, event, etc.
- Add required `torch.mtia` APIs, like `is_bf16_supported`, `memory_allocated`, `set_stream_by_id`, etc.
- MTIA-specific codegen logic, for example, loading the MTIA dynamic library.
- Other necessary changes to integrate with Inductor codegen, following other devices like CUDA and XPU.
- Integrate with the [empty_strided_mtia](https://www.internalfb.com/code/fbsource/[0d017d3a4a1bdff7253f9c66a9f38e77bd62166b]/fbcode/caffe2/aten/src/ATen/native/mtia/EmptyTensor.cpp?lines=49%2C63%2C71%2C74%2C78) API added for the new MTIA ATen backend.
- A change in the Inductor runtime to avoid re-initializing MTIADriver.
- BUCK changes to include ATen-mtia in Inductor and to use the -USE_MTIA preprocessor flag.
- Update `test_mnist_e2e.py` to cover native Inductor as a backend, using the `--use_native_inductor` flag.
- Add a personal script (`scripts/anwang/run_native_inductor_script.py`) for testing purposes.

Note:
- This approach (option 3) aims to provide a PyTorch-native approach to Inductor integration for MTIA, minimizing the onboarding overhead. The downside is that it doesn't leverage MTIA-specific graph optimizations and is limited by eager kernel-launch overhead.
- MTIA will support another approach (option 2), based on WrapperFxCodegen, to provide the best performance. The fundamental changes in this diff (device interfaces, stream/event APIs, etc.) should be reusable for option 2, especially since WrapperFxCodegen inherits from PythonWrapperCodegen.

Internal:
References:
- [post for context](https://fb.workplace.com/groups/mtiasw/permalink/1718377262384606/)
- [Inductor integration discussion (options 1/2/3)](https://docs.google.com/document/d/1p6363OXtVIRv1hPoaKlRSK3j-iir3QIbDd5bjyqCNig/edit?tab=t.0#heading=h.7s4ns6wcnhmb)
- [Project design doc (option 3)](https://docs.google.com/document/d/1jXUmhgoV9WvkMf-bcY3Od_kK9K_RDOdgHdt1LoQ5Tc4/edit?tab=t.0#heading=h.y43gwdqlv46w)
- [early prototyping diff](https://www.internalfb.com/diff/D75110196)
- [MPS integration PR](https://github.com/pytorch/pytorch/pull/153959)
- [empty_strided_xpu PR](https://github.com/pytorch/pytorch/pull/126678)

Differential Revision: [D78458745](https://our.internmc.facebook.com/intern/diff/D78458745/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158526
Approved by: https://github.com/blaine-rister, https://github.com/jansel, https://github.com/eellison
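A minimal sketch of the user-facing flow this enables, assuming an MTIA-enabled PyTorch build with at least one `mtia` device available (the toy function and shapes below are arbitrary placeholders, not part of the PR):

```python
import torch

def fn(x, y):
    return torch.relu(x @ y) + y

# Native Inductor path: Dynamo traces fn, Inductor emits Triton kernels plus a
# Python wrapper, and the kernels are launched eagerly on the MTIA device.
compiled = torch.compile(fn, backend="inductor")

x = torch.randn(64, 64, device="mtia")
y = torch.randn(64, 64, device="mtia")
out = compiled(x, y)
```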
224 lines | 7.0 KiB | Python
# mypy: allow-untyped-defs
from __future__ import annotations

import collections
import functools
import typing
from enum import auto, Enum
from typing import Optional, Union

from torch.utils._triton import has_triton_package


# The following maximums only apply to runtime autotuning; when using
# FixedTritonConfig one may see larger values.
# NOTE: if these fail asserts, submit a PR to increase them.
TRITON_MAX_BLOCK = {
    "X": 4096,
    "Y": 1024,
    "Z": 1024,
    "R0_": 4096 * 16,  # * 16 is multi-kernel only
    "R1_": 2048 * 16,  # * 16 is multi-kernel only
}
TRITON_MAX_RSPLIT = 64
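# A rough sketch (hypothetical variable names, for illustration only) of the
# kind of check autotuning code applies against these caps:
#     for prefix, max_block in TRITON_MAX_BLOCK.items():
#         block = config_kwargs.get(f"{prefix}BLOCK", 1)
#         assert block <= max_block, f"increase TRITON_MAX_BLOCK[{prefix!r}]"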


class ReductionHint(Enum):
    INNER = 0
    OUTER = 1
    OUTER_TINY = 2
    DEFAULT = 3


class TileHint(Enum):
    SQUARE = 0
    DEFAULT = 1


# Define `AttrsDescriptorWrapper` function with clear conditional handling
if has_triton_package():
    import triton
    import triton.backends.compiler
    import triton.compiler.compiler

    if hasattr(triton.backends.compiler, "AttrsDescriptor"):
        # Triton 3.2.0 - the second implementation
        from triton.backends.compiler import AttrsDescriptor

        def AttrsDescriptorWrapper(
            divisible_by_16=None,
            equal_to_1=None,
        ):
            # Prepare the arguments for AttrsDescriptor
            kwargs = {
                "tt.divisibility": divisible_by_16,
                "tt.equal_to": equal_to_1,
            }

            # Instantiate AttrsDescriptor with the prepared arguments
            res = AttrsDescriptor.from_dict(
                {"arg_properties": kwargs, "cls": AttrsDescriptor.__name__}
            )
            assert res.property_values["tt.divisibility"] == 16
            assert res.property_values["tt.equal_to"] == 1
            return res

    elif hasattr(triton.compiler.compiler, "AttrsDescriptor"):
        # Triton 3.0.0 - the original implementation
        from triton.compiler.compiler import AttrsDescriptor

        def AttrsDescriptorWrapper(
            divisible_by_16=None,
            equal_to_1=None,
        ):
            # Prepare the arguments for AttrsDescriptor
            kwargs = {
                "divisible_by_16": divisible_by_16,
                "equal_to_1": equal_to_1,
            }

            # Instantiate AttrsDescriptor with the prepared arguments
            return AttrsDescriptor(**kwargs)

    else:
        # Triton in 2025:
        # note: there's also a range of triton commits not currently supported
        # from ~Dec 9, 2024 to Jan 1 2025, in which AttrsDescriptors are still
        # used, but the contents are different.

        def AttrsDescriptorWrapper(
            divisible_by_16=None,
            equal_to_1=None,
        ):
            return {(x,): [["tt.divisibility", 16]] for x in divisible_by_16}

else:
    # Define a namedtuple as a fallback when AttrsDescriptor is not available
    AttrsDescriptorWrapper = collections.namedtuple(  # type: ignore[no-redef, name-match]
        "AttrsDescriptor",
        ["divisible_by_16", "equal_to_1"],
        defaults=[(), ()],
    )
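# Illustrative call (hypothetical index values): Inductor passes tuples of
# kernel-argument indices, e.g.
#     AttrsDescriptorWrapper(divisible_by_16=(0, 1, 2), equal_to_1=(5,))
# and the result is handed to Triton in whichever shape the detected Triton
# version expects (AttrsDescriptor instance, dict, or namedtuple fallback).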


_NUM_THREADS_PER_WARP = 32


class HeuristicType(Enum):
    PERSISTENT_REDUCTION = auto()
    POINTWISE = auto()
    REDUCTION = auto()
    SPLIT_SCAN = auto()
    TEMPLATE = auto()
    USER_AUTOTUNE = auto()
    FIXED = auto()


class AutotuneHint(Enum):
    ONE_ELEMENT_PER_THREAD = 0

    # Triton codegen tries to codegen a set of AutotuneHints.
    # Enum.__repr__ looks like "<AutotuneHint.ELEMENTS_PER_WARP_32: 0>",
    # which isn't valid python.
    # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32".
    __repr__ = Enum.__str__


class DeviceProperties(typing.NamedTuple):
    """Copy device properties into a data structure not requiring torch to be imported"""

    type: str  # type: ignore[assignment]
    index: int  # type: ignore[assignment]
    multi_processor_count: int
    cc: int
    major: Optional[int] = None
    regs_per_multiprocessor: Optional[int] = None
    max_threads_per_multi_processor: Optional[int] = None
    warp_size: Optional[int] = None

    @classmethod
    @functools.cache
    def create(cls, device) -> DeviceProperties:
        import torch
        from torch._dynamo.device_interface import get_interface_for_device

        device_type = device.type

        if torch.version.hip and device_type == "cuda":
            device_type = "hip"

        device_interface = get_interface_for_device(device)
        props = device_interface.get_device_properties(device)
        try:
            multi_processor_count = props.multi_processor_count
        except AttributeError:
            if device_type == "xpu":
                multi_processor_count = props.gpu_subslice_count
            elif device_type == "mps":
                # TODO: Fetch the actual value from ioreg
                multi_processor_count = 8
            elif device_type == "mtia":
                multi_processor_count = 64
            else:
                raise
        return cls(
            type=device_type,
            index=device.index,
            multi_processor_count=multi_processor_count,
            cc=device_interface.get_compute_capability(device),
            major=getattr(props, "major", None),
            regs_per_multiprocessor=getattr(props, "regs_per_multiprocessor", None),
            max_threads_per_multi_processor=getattr(
                props, "max_threads_per_multi_processor", None
            ),
            warp_size=getattr(props, "warp_size", 32 if device_type != "cpu" else None),
        )
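# Illustrative usage (assumes a CUDA device is present; any device type with a
# registered device interface works the same way):
#     import torch
#     props = DeviceProperties.create(torch.device("cuda", 0))
#     props.type, props.index, props.cc  # e.g. ("cuda", 0, <compute capability>)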


class HalideInputSpec(typing.NamedTuple):
    ctype: str
    name: str
    shape: Optional[list[str]] = None
    stride: Optional[list[str]] = None
    offset: Optional[str] = None
    alias_of: Optional[str] = None

    def bindings_type(self) -> str:
        if self.ctype in ("at::Half*", "at::BFloat16*"):
            return "uint16_t*"  # half not defined
        return self.ctype

    def halide_type(self) -> str:
        if self.ctype == "at::Half*":
            return "halide_type_t(halide_type_float, 16)"  # half not defined
        if self.ctype == "at::BFloat16*":
            return "halide_type_t(halide_type_bfloat, 16)"  # half not defined
        return f"halide_type_of<{self.ctype.replace('*', '')}>()"

    def is_scalar(self) -> bool:
        return self.shape is None

    def is_buffer(self) -> bool:
        return self.shape is not None
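# Illustrative reading of the two methods above (values taken from the code):
# a spec with ctype "at::BFloat16*" binds as "uint16_t*" and reports
# "halide_type_t(halide_type_bfloat, 16)" as its Halide type, while e.g.
# "float*" falls through to "halide_type_of<float>()".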


class HalideMeta(typing.NamedTuple):
    argtypes: list[HalideInputSpec]
    target: str
    scheduler: Optional[str] = None
    scheduler_flags: Optional[dict[str, Union[int, str]]] = None
    cuda_device: Optional[int] = None

    def args(self) -> list[str]:
        """Command line args to pass to halide generator"""
        args = [f"target={self.target}"]
        if self.scheduler:
            args.append(f"autoscheduler={self.scheduler}")
        if self.scheduler_flags:
            assert self.scheduler
            for k, v in self.scheduler_flags.items():
                args.append(f"autoscheduler.{k}={v}")
        return args

    def is_cuda(self) -> bool:
        return self.cuda_device is not None
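# Illustrative example (hypothetical target/scheduler values): for
#     HalideMeta(argtypes=[], target="host-cuda", scheduler="Anderson2021",
#                scheduler_flags={"parallelism": 82})
# .args() yields
#     ["target=host-cuda", "autoscheduler=Anderson2021", "autoscheduler.parallelism=82"]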