diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 3d427fd7dd0..0b82dfda835 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -1543,35 +1543,87 @@ def use_triton_template(
     )
 
 
-def use_triton_tma_template(*matrices: IRNode) -> bool:
+def use_triton_tma_template(*matrices: IRNode, add_guards: bool = False) -> bool:
+    """
+    Return True iff *all* supplied tensors satisfy the CUDA-12.9 TMA constraints
+    that Triton relies on today.
+    * https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html
+
+    A tensor is accepted when:
+        * 2 ≤ rank ≤ 5
+        * dtype ∈ {FP16, BF16, FP8-E4M3FN}
+        * Every logical size ≥ 2
+        * Base pointer 16-byte aligned
+        * All "outer" dims have 16-byte aligned strides
+        * The "inner" dim has stride 1 (contiguous)
+        * For FP8 tensors, inner dim ≥ 32
+    """
     from torch.utils._triton import has_triton_tma_device
 
     from .virtualized import V
 
+    def _aligned(expr_bytes: Union[int, sympy.Expr]) -> bool:
+        return V.graph.sizevars.statically_known_multiple_of(expr_bytes, TMA_ALIGNMENT)
+
     def _is_tma_compatible(x: IRNode) -> bool:
-        if len(x.get_size()) != 2:
+        sizes = x.get_size()
+        strides = x.get_stride()
+        rank = len(sizes)
+        dtype = x.get_dtype()
+        itemsize = dtype.itemsize
+
+        # 2 ≤ rank ≤ 5
+        if rank < 2 or rank > 5:
             return False
 
-        dtype = x.get_dtype()
+        # dtype ∈ {FP16, BF16, FP8-E4M3FN}
         if dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn):
             return False
 
-        layout = x.get_layout()
-        transposed = layout.is_transposed()
-        if not (layout.is_contiguous() or transposed):
+        # Base pointer 16-byte aligned
+        if x.get_name() in V.graph.unaligned_buffers:
             return False
 
-        inner_dim = layout.size[1]
-        if transposed:
-            inner_dim = layout.size[0]
+        if add_guards:
+            sizes_i = V.graph.sizevars.guard_int_seq(sizes)
+            strides_i = V.graph.sizevars.guard_int_seq(strides)
+        else:
+            sizes_i = [V.graph.sizevars.symbolic_hint(s) for s in sizes]
+            strides_i = [V.graph.sizevars.symbolic_hint(st) for st in strides]
 
-        if dtype == torch.float8_e4m3fn and V.graph.sizevars.statically_known_lt(
+        # Every logical size ≥ 2
+        if any(not V.graph.sizevars.statically_known_geq(s, 2) for s in sizes_i):
+            return False
+
+        # Find the single contiguous ("inner") dim
+        inner = [
+            i
+            for i, st in enumerate(strides_i)
+            if V.graph.sizevars.statically_known_equals(st, 1)
+        ]
+        if len(inner) != 1:
+            return False
+        inner_idx = inner[0]
+
+        # All "outer" dims must have 16-byte aligned strides
+        for i, st in enumerate(strides_i):
+            if i == inner_idx:
+                continue
+            if not _aligned(st * itemsize):
+                return False
+
+        # Inner dim byte width must still be a multiple of 16 B
+        inner_dim = sizes_i[inner_idx]
+        if not _aligned(inner_dim * itemsize):
+            return False
+
+        # FP8 special case: inner ≥ 32
+        if dtype == torch.float8_e4m3fn and not V.graph.sizevars.statically_known_geq(
            inner_dim, 32
        ):
            return False
 
-        inner_bytes = inner_dim * dtype.itemsize
-        return V.graph.sizevars.statically_known_multiple_of(inner_bytes, TMA_ALIGNMENT)
+        return True
 
     return (
         config.triton.enable_persistent_tma_matmul
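
For intuition, below is a hedged, standalone sketch of the same acceptance rules applied to plain `torch.Tensor` objects. It is not part of the patch and does not exercise the inductor IR path: the helper name `_tensor_is_tma_compatible` is hypothetical, and the `TMA_ALIGNMENT = 16` literal is assumed to mirror the module-level constant the patch uses.

```python
# Illustrative only: eager-mode re-statement of the TMA compatibility checks.
# Names below are hypothetical and do not exist in torch/_inductor.
import torch

TMA_ALIGNMENT = 16  # bytes; assumed to match torch/_inductor's constant


def _tensor_is_tma_compatible(t: torch.Tensor) -> bool:
    sizes, strides = t.shape, t.stride()
    itemsize = t.element_size()

    # 2 <= rank <= 5
    if not (2 <= t.dim() <= 5):
        return False
    # dtype in {FP16, BF16, FP8-E4M3FN}
    if t.dtype not in (torch.float16, torch.bfloat16, torch.float8_e4m3fn):
        return False
    # Base pointer 16-byte aligned
    if t.data_ptr() % TMA_ALIGNMENT != 0:
        return False
    # Every logical size >= 2
    if any(s < 2 for s in sizes):
        return False
    # Exactly one contiguous ("inner") dim with stride 1
    inner = [i for i, st in enumerate(strides) if st == 1]
    if len(inner) != 1:
        return False
    inner_idx = inner[0]
    # All "outer" strides 16-byte aligned
    if any(
        (st * itemsize) % TMA_ALIGNMENT != 0
        for i, st in enumerate(strides)
        if i != inner_idx
    ):
        return False
    # Inner dim byte width a multiple of 16 B
    if (sizes[inner_idx] * itemsize) % TMA_ALIGNMENT != 0:
        return False
    # FP8 special case: inner dim >= 32
    if t.dtype == torch.float8_e4m3fn and sizes[inner_idx] < 32:
        return False
    return True


# A contiguous 128x64 fp16 tensor passes; a 128x3 fp16 tensor fails because its
# inner dim is only 6 bytes wide, so neither the outer stride nor the inner dim
# is a multiple of 16 bytes.
print(_tensor_is_tma_compatible(torch.empty(128, 64, dtype=torch.float16)))  # True
print(_tensor_is_tma_compatible(torch.empty(128, 3, dtype=torch.float16)))   # False
```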