pytorch/torch/_inductor/runtime/hints.py
Jason Ansel 2b937e4e6d [inductor] Cooperative reductions (#137756)
Example generated code for `(x+y).sum()`:
```py
@triton.jit
def triton_unk_fused_add_sum_0(in_ptr0, in_ptr1, out_ptr0, ws_ptr, semaphores_ptr, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr, RSPLIT : tl.constexpr):
    xnumel = 1
    rnumel = 1048576
    rsplit_id = tl.program_id(0)
    num_rblocks = (rnumel + RBLOCK - 1) // RBLOCK
    rsplit_chunk = (num_rblocks + RSPLIT - 1) // RSPLIT * RBLOCK
    rsplit_start = rsplit_chunk * rsplit_id
    rsplit_end = rsplit_chunk * (rsplit_id + 1)
    xoffset = tl.program_id(1) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    rbase = tl.arange(0, RBLOCK)[None, :]
    _tmp4 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(rsplit_start, rsplit_end, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp1 = tl.load(in_ptr1 + (r0), rmask, eviction_policy='evict_first', other=0.0)
        tmp2 = tmp0 + tmp1
        tmp3 = tl.broadcast_to(tmp2, [XBLOCK, RBLOCK])
        tmp5 = _tmp4 + tmp3
        _tmp4 = tl.where(rmask, tmp5, _tmp4)
    tmp4 = tl.sum(_tmp4, 1)[:, None]
    if RSPLIT > 1:
        tmp4_ws = (ws_ptr + 0).to(tl.pointer_type(tl.float32))
        tl.store(tmp4_ws + (xindex * RSPLIT + rsplit_id), tmp4, None)
    if RSPLIT > 1:
        triton_helpers.gpu_barrier(semaphores_ptr + (2 * tl.program_id(1) + 0), RSPLIT, True)
    if RSPLIT > 1:
        tmp4_peers = tl.load(tmp4_ws + (xindex * RSPLIT + tl.arange(0, RSPLIT)[None,:]), None, eviction_policy='evict_first')
        tmp4 = tl.sum(tmp4_peers, 1)[:, None]
    if rsplit_id == (0 % RSPLIT):
        tl.store(out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp4, None)
```
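For context, a kernel like the one above is what inductor emits for a large fused add-plus-reduction. A minimal sketch of exercising that path is below; the config knob name (`torch._inductor.config.triton.cooperative_reductions`) and the input size are assumptions for illustration, not something this diff pins down:

```py
import torch
import torch._inductor.config as inductor_config

# Assumed opt-in flag for cooperative reductions; verify the exact name in
# torch/_inductor/config.py for your build.
inductor_config.triton.cooperative_reductions = True

@torch.compile
def fused_add_sum(x, y):
    # A single large reduction over the fused (x + y) elementwise result,
    # matching the `(x+y).sum()` example above.
    return (x + y).sum()

x = torch.randn(1024 * 1024, device="cuda")
y = torch.randn(1024 * 1024, device="cuda")
print(fused_add_sum(x, y))
```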

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137756
Approved by: https://github.com/eellison
2024-10-29 00:45:53 +00:00


# mypy: allow-untyped-defs
import collections
import typing
from enum import auto, Enum
from typing import Dict, List, Optional, Union

# NOTE: if asserts against these limits fail, submit a PR to increase them
TRITON_MAX_BLOCK = {
    "X": 4096,
    "Y": 1024,
    "Z": 1024,
    "R": 4096 * 16,  # * 16 is multi-kernel only
}
TRITON_MAX_RSPLIT = 64
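
# Example (sketch, hypothetical names): tuning heuristics elsewhere in inductor
# are expected to keep candidate sizes within these caps, e.g.:
#
#     assert candidate_xblock <= TRITON_MAX_BLOCK["X"]
#     assert candidate_rsplit <= TRITON_MAX_RSPLIT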


class ReductionHint(Enum):
    INNER = 0
    OUTER = 1
    OUTER_TINY = 2
    DEFAULT = 3


class TileHint(Enum):
    SQUARE = 0
    DEFAULT = 1


def _is_triton_available():
    try:
        import triton  # noqa: F401

        return True
    except ImportError:
        return False


# Define `AttrsDescriptorWrapper` function with clear conditional handling
if _is_triton_available():
    try:
        from triton.backends.compiler import AttrsDescriptor

        def AttrsDescriptorWrapper(
            divisible_by_16=None,
            equal_to_1=None,
        ):
            # Prepare the arguments for AttrsDescriptor
            kwargs = {
                "tt.divisibility": divisible_by_16,
                "tt.equal_to": equal_to_1,
            }

            # Instantiate AttrsDescriptor with the prepared arguments
            res = AttrsDescriptor.from_dict(kwargs)
            assert res.property_values["tt.divisibility"] == 16
            assert res.property_values["tt.equal_to"] == 1
            return res

    except ImportError:
        from triton.compiler.compiler import AttrsDescriptor

        def AttrsDescriptorWrapper(
            divisible_by_16=None,
            equal_to_1=None,
        ):
            # Prepare the arguments for AttrsDescriptor
            kwargs = {
                "divisible_by_16": divisible_by_16,
                "equal_to_1": equal_to_1,
            }

            # Instantiate AttrsDescriptor with the prepared arguments
            return AttrsDescriptor(**kwargs)

else:
    # Define a namedtuple as a fallback when AttrsDescriptor is not available
    AttrsDescriptorWrapper = collections.namedtuple(  # type: ignore[no-redef, name-match]
        "AttrsDescriptor",
        ["divisible_by_16", "equal_to_1"],
        defaults=[(), ()],
    )
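
# Example (sketch): callers typically pass tuples of kernel-argument indices;
# the indices below are made up for illustration.
#
#     desc = AttrsDescriptorWrapper(divisible_by_16=(0, 1, 2), equal_to_1=(5,))
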
_NUM_THREADS_PER_WARP = 32


class HeuristicType(Enum):
    PERSISTENT_REDUCTION = auto()
    POINTWISE = auto()
    REDUCTION = auto()
    SPLIT_SCAN = auto()
    TEMPLATE = auto()
    USER_AUTOTUNE = auto()


class AutotuneHint(Enum):
    ONE_ELEMENT_PER_THREAD = 0

    # Triton codegen tries to codegen a set of AutotuneHints.
    # Enum.__repr__ looks like "<AutotuneHint.ONE_ELEMENT_PER_THREAD: 0>",
    # which isn't valid python.
    # Enum.__str__ will just return "AutotuneHint.ONE_ELEMENT_PER_THREAD".
    __repr__ = Enum.__str__
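
# Example: with the override above, repr() of a hint round-trips as valid
# Python source, which is what the generated kernel code needs:
#
#     repr(AutotuneHint.ONE_ELEMENT_PER_THREAD)  # "AutotuneHint.ONE_ELEMENT_PER_THREAD"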


class DeviceProperties(typing.NamedTuple):
    """Copy device properties into a data structure not requiring torch to be imported"""

    type: str  # type: ignore[assignment]
    index: int  # type: ignore[assignment]
    cc: int
    major: Optional[int] = None
    regs_per_multiprocessor: Optional[int] = None
    max_threads_per_multi_processor: Optional[int] = None
    multi_processor_count: Optional[int] = None
    warp_size: Optional[int] = None

    @classmethod
    def create(cls, device):
        import torch
        from torch._dynamo.device_interface import get_interface_for_device

        device_type = device.type
        if torch.version.hip and device_type == "cuda":
            device_type = "hip"
        device_interface = get_interface_for_device(device)
        if device_type in ["cuda", "hip", "xpu"]:
            props = device_interface.get_device_properties(device)
            return cls(
                type=device_type,
                index=device.index,
                cc=device_interface.get_compute_capability(device),
                major=props.major if hasattr(props, "major") else None,
                regs_per_multiprocessor=props.regs_per_multiprocessor
                if hasattr(props, "regs_per_multiprocessor")
                else None,
                max_threads_per_multi_processor=props.max_threads_per_multi_processor
                if hasattr(props, "max_threads_per_multi_processor")
                else None,
                multi_processor_count=props.multi_processor_count
                if hasattr(props, "multi_processor_count")
                else None,
                warp_size=props.warp_size if hasattr(props, "warp_size") else 32,
            )
        return cls(
            type=device_type,
            index=device.index,
            cc=device_interface.get_compute_capability(device),
        )
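
# Example (sketch): a snapshot can be taken once where torch is available and
# then shipped to code that must not import torch; a CUDA device is assumed:
#
#     props = DeviceProperties.create(torch.device("cuda", 0))
#     print(props.type, props.cc, props.multi_processor_count)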


class HalideInputSpec(typing.NamedTuple):
    ctype: str
    name: str
    shape: Optional[List[str]] = None
    stride: Optional[List[str]] = None
    offset: Optional[str] = None
    alias_of: Optional[str] = None

    def bindings_type(self):
        if self.ctype in ("half*", "bfloat16*"):
            return "uint16_t*"  # half not defined
        return self.ctype

    def halide_type(self):
        if self.ctype == "half*":
            return "halide_type_t(halide_type_float, 16)"  # half not defined
        if self.ctype == "bfloat16*":
            return "halide_type_t(halide_type_bfloat, 16)"  # bfloat16 not defined
        return f"halide_type_of<{self.ctype.replace('*', '')}>()"

    def is_scalar(self):
        return self.shape is None

    def is_buffer(self):
        return self.shape is not None
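
# Example (sketch, made-up values): a half-precision buffer argument maps to a
# uint16_t* in the C bindings but keeps a 16-bit float Halide type:
#
#     spec = HalideInputSpec(ctype="half*", name="in_ptr0", shape=["1024"])
#     spec.bindings_type()  # "uint16_t*"
#     spec.halide_type()    # "halide_type_t(halide_type_float, 16)"
#     spec.is_buffer()      # True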


class HalideMeta(typing.NamedTuple):
    argtypes: List[HalideInputSpec]
    target: str
    scheduler: Optional[str] = None
    scheduler_flags: Optional[Dict[str, Union[int, str]]] = None
    cuda_device: Optional[int] = None

    def args(self):
        """Command line args to pass to halide generator"""
        args = [f"target={self.target}"]
        if self.scheduler:
            args.append(f"autoscheduler={self.scheduler}")
        if self.scheduler_flags:
            assert self.scheduler
            for k, v in self.scheduler_flags.items():
                args.append(f"autoscheduler.{k}={v}")
        return args

    def is_cuda(self):
        return self.cuda_device is not None
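
# Example (sketch, made-up values): args() flattens the target, autoscheduler,
# and any scheduler flags into generator command-line arguments:
#
#     meta = HalideMeta(
#         argtypes=[],
#         target="host-cuda",
#         scheduler="Anderson2021",
#         scheduler_flags={"parallelism": 82},
#         cuda_device=0,
#     )
#     meta.args()     # ["target=host-cuda", "autoscheduler=Anderson2021",
#                     #  "autoscheduler.parallelism=82"]
#     meta.is_cuda()  # True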