Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Example generated code for `(x+y).sum()`:
```py
@triton.jit
def triton_unk_fused_add_sum_0(in_ptr0, in_ptr1, out_ptr0, ws_ptr, semaphores_ptr, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr, RSPLIT : tl.constexpr):
    xnumel = 1
    rnumel = 1048576
    rsplit_id = tl.program_id(0)
    num_rblocks = (rnumel + RBLOCK - 1) // RBLOCK
    rsplit_chunk = (num_rblocks + RSPLIT - 1) // RSPLIT * RBLOCK
    rsplit_start = rsplit_chunk * rsplit_id
    rsplit_end = rsplit_chunk * (rsplit_id + 1)
    xoffset = tl.program_id(1) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp4 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(rsplit_start, rsplit_end, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp2 = tmp0 + tmp1
        tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
        tmp5 = _tmp4 + tmp3
        _tmp4 = tl.where(rmask, tmp5, _tmp4)
    tmp4 = tl.sum(_tmp4, 1)[:, None]
    if RSPLIT > 1:
        tmp4_ws = (ws_ptr + 0).to(tl.pointer_type(tl.float32))
        tl.store(tmp4_ws + (xindex * RSPLIT + rsplit_id), tmp4, None)
    if RSPLIT > 1:
        triton_helpers.gpu_barrier(semaphores_ptr + (2 * tl.program_id(1) + 0), RSPLIT, True)
    if RSPLIT > 1:
        tmp4_peers = tl.load(tmp4_ws + (xindex * RSPLIT + tl.arange(0, RSPLIT)[None,:]), None, eviction_policy='evict_first')
        tmp4 = tl.sum(tmp4_peers, 1)[:, None]
    if rsplit_id == (0 % RSPLIT):
        tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp4, None)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137756
Approved by: https://github.com/eellison
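
In the generated kernel above, each program handles one contiguous slice of the reduction (indexed by `rsplit_id`), accumulates a partial sum, publishes it to the `ws_ptr` workspace, synchronizes through `gpu_barrier`, and then re-reduces the peers' partials, with the `rsplit_id == 0` program writing the final result. A kernel of this shape can be produced by compiling the reduction with `torch.compile`; a minimal sketch (the `cooperative_reductions` config name is an assumption about how the feature is toggled, verify it against `torch._inductor.config`):

```py
import torch
import torch._inductor.config as inductor_config

# Assumption: cooperative reductions are gated behind this Inductor config flag.
inductor_config.triton.cooperative_reductions = True

x = torch.randn(2**20, device="cuda")  # 1048576 elements, matching rnumel above
y = torch.randn(2**20, device="cuda")

compiled_sum = torch.compile(lambda a, b: (a + b).sum())
out = compiled_sum(x, y)  # may lower to a fused add+sum kernel like the one above
```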
# mypy: allow-untyped-defs
import collections
import typing
from enum import auto, Enum
from typing import Dict, List, Optional, Union


# NOTE: if these fail asserts submit a PR to increase them
TRITON_MAX_BLOCK = {
    "X": 4096,
    "Y": 1024,
    "Z": 1024,
    "R": 4096 * 16,  # * 16 is multi-kernel only
}
TRITON_MAX_RSPLIT = 64
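
# Editor-added note (not part of the original module): these constants act as
# upper bounds on the block sizes Inductor's heuristics may choose (the NOTE
# above refers to asserts against them). A hypothetical candidate would be
# clamped before use, e.g.:
#
#     xblock = min(candidate_xblock, TRITON_MAX_BLOCK["X"])
#     rsplit = min(candidate_rsplit, TRITON_MAX_RSPLIT)
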
class ReductionHint(Enum):
    INNER = 0
    OUTER = 1
    OUTER_TINY = 2
    DEFAULT = 3


class TileHint(Enum):
    SQUARE = 0
    DEFAULT = 1


def _is_triton_available():
    try:
        import triton  # noqa: F401

        return True
    except ImportError:
        return False


# Define `AttrsDescriptorWrapper` function with clear conditional handling
if _is_triton_available():
    try:
        # Newer Triton exposes AttrsDescriptor under triton.backends.compiler
        from triton.backends.compiler import AttrsDescriptor

        def AttrsDescriptorWrapper(
            divisible_by_16=None,
            equal_to_1=None,
        ):
            # Prepare the arguments for AttrsDescriptor
            kwargs = {
                "tt.divisibility": divisible_by_16,
                "tt.equal_to": equal_to_1,
            }

            # Instantiate AttrsDescriptor with the prepared arguments
            res = AttrsDescriptor.from_dict(kwargs)
            assert res.property_values["tt.divisibility"] == 16
            assert res.property_values["tt.equal_to"] == 1
            return res

    except ImportError:
        # Fall back to the older AttrsDescriptor location
        from triton.compiler.compiler import AttrsDescriptor

        def AttrsDescriptorWrapper(
            divisible_by_16=None,
            equal_to_1=None,
        ):
            # Prepare the arguments for AttrsDescriptor
            kwargs = {
                "divisible_by_16": divisible_by_16,
                "equal_to_1": equal_to_1,
            }

            # Instantiate AttrsDescriptor with the prepared arguments
            return AttrsDescriptor(**kwargs)

else:
    # Define a namedtuple as a fallback when AttrsDescriptor is not available
    AttrsDescriptorWrapper = collections.namedtuple(  # type: ignore[no-redef, name-match]
        "AttrsDescriptor",
        ["divisible_by_16", "equal_to_1"],
        defaults=[(), ()],
    )
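
# Editor-added usage sketch (not part of the original module): Inductor calls
# the wrapper with tuples of argument indices, e.g. marking the first two
# pointer arguments as 16-byte divisible and the fourth argument as the
# constant 1 (the indices here are illustrative):
#
#     AttrsDescriptorWrapper(divisible_by_16=(0, 1), equal_to_1=(3,))
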
_NUM_THREADS_PER_WARP = 32


class HeuristicType(Enum):
    PERSISTENT_REDUCTION = auto()
    POINTWISE = auto()
    REDUCTION = auto()
    SPLIT_SCAN = auto()
    TEMPLATE = auto()
    USER_AUTOTUNE = auto()


class AutotuneHint(Enum):
    ONE_ELEMENT_PER_THREAD = 0

    # Triton codegen tries to codegen a set of AutotuneHints.
    # Enum.__repr__ looks like "<AutotuneHint.ELEMENTS_PER_WARP_32: 0>",
    # which isn't valid python.
    # Enum.__str__ will just return "AutotuneHint.ELEMENTS_PER_WARP_32".
    __repr__ = Enum.__str__


class DeviceProperties(typing.NamedTuple):
    """Copy device properties into a data structure not requiring torch to be imported"""

    type: str  # type: ignore[assignment]
    index: int  # type: ignore[assignment]
    cc: int
    major: Optional[int] = None
    regs_per_multiprocessor: Optional[int] = None
    max_threads_per_multi_processor: Optional[int] = None
    multi_processor_count: Optional[int] = None
    warp_size: Optional[int] = None

    @classmethod
    def create(cls, device):
        import torch
        from torch._dynamo.device_interface import get_interface_for_device

        device_type = device.type

        if torch.version.hip and device_type == "cuda":
            device_type = "hip"

        device_interface = get_interface_for_device(device)
        if device_type in ["cuda", "hip", "xpu"]:
            props = device_interface.get_device_properties(device)
            return cls(
                type=device_type,
                index=device.index,
                cc=device_interface.get_compute_capability(device),
                major=props.major if hasattr(props, "major") else None,
                regs_per_multiprocessor=props.regs_per_multiprocessor
                if hasattr(props, "regs_per_multiprocessor")
                else None,
                max_threads_per_multi_processor=props.max_threads_per_multi_processor
                if hasattr(props, "max_threads_per_multi_processor")
                else None,
                multi_processor_count=props.multi_processor_count
                if hasattr(props, "multi_processor_count")
                else None,
                warp_size=props.warp_size if hasattr(props, "warp_size") else 32,
            )
        return cls(
            type=device_type,
            index=device.index,
            cc=device_interface.get_compute_capability(device),
        )
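
# Editor-added usage sketch (not part of the original module): snapshotting the
# current CUDA device (on ROCm builds `type` would be "hip" instead):
#
#     >>> import torch
#     >>> props = DeviceProperties.create(torch.device("cuda", 0))
#     >>> props.type, props.index
#     ('cuda', 0)
#
# The resulting NamedTuple can then be consumed without importing torch, per the
# class docstring.
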
class HalideInputSpec(typing.NamedTuple):
    ctype: str
    name: str
    shape: Optional[List[str]] = None
    stride: Optional[List[str]] = None
    offset: Optional[str] = None
    alias_of: Optional[str] = None

    def bindings_type(self):
        if self.ctype in ("half*", "bfloat16*"):
            return "uint16_t*"  # half not defined
        return self.ctype

    def halide_type(self):
        if self.ctype == "half*":
            return "halide_type_t(halide_type_float, 16)"  # half not defined
        if self.ctype == "bfloat16*":
            return "halide_type_t(halide_type_bfloat, 16)"  # half not defined
        return f"halide_type_of<{self.ctype.replace('*', '')}>()"

    def is_scalar(self):
        return self.shape is None

    def is_buffer(self):
        return self.shape is not None
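
# Editor-added usage sketch (not part of the original module): a buffer argument
# spec with illustrative values:
#
#     >>> spec = HalideInputSpec("float*", "in_ptr0", shape=["1024"], stride=["1"], offset="0")
#     >>> spec.halide_type()
#     'halide_type_of<float>()'
#     >>> spec.is_buffer()
#     True
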
class HalideMeta(typing.NamedTuple):
    argtypes: List[HalideInputSpec]
    target: str
    scheduler: Optional[str] = None
    scheduler_flags: Optional[Dict[str, Union[int, str]]] = None
    cuda_device: Optional[int] = None

    def args(self):
        """Command line args to pass to halide generator"""
        args = [f"target={self.target}"]
        if self.scheduler:
            args.append(f"autoscheduler={self.scheduler}")
        if self.scheduler_flags:
            assert self.scheduler
            for k, v in self.scheduler_flags.items():
                args.append(f"autoscheduler.{k}={v}")
        return args

    def is_cuda(self):
        return self.cuda_device is not None
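
# Editor-added usage sketch (not part of the original module): generator flags
# for an autoscheduled CUDA build (the scheduler name and flag are illustrative):
#
#     >>> meta = HalideMeta(
#     ...     argtypes=[],
#     ...     target="host-cuda",
#     ...     scheduler="Anderson2021",
#     ...     scheduler_flags={"parallelism": 82},
#     ...     cuda_device=0,
#     ... )
#     >>> meta.args()
#     ['target=host-cuda', 'autoscheduler=Anderson2021', 'autoscheduler.parallelism=82']
#     >>> meta.is_cuda()
#     True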