[inductor] make inductor work with new triton compile interface (#115878)

Two recent triton PRs (https://github.com/openai/triton/pull/2701, https://github.com/openai/triton/pull/2756) change the interface of triton.compile. This PR adds the necessary changes on the inductor side so that it works with both the old and the new compile API.

This PR also simplifies how the compilation call in a subprocess relates to the one in the main process (a sketch of the resulting dispatch follows the list):
- Previously we passed warm_cache_only=True when the compilation happened in a subprocess, but triton never uses that argument at the currently pinned version, so it has been removed.
- Previously we passed compute_capability only when the compilation happened in a subprocess. This PR changes that to always pass compute_capability to triton.compile, whether the compilation happens in the main process or in a subprocess.
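
For illustration, a minimal sketch of the resulting dispatch between the new and old compile interfaces, mirroring the CachingAutotuner changes in the diff below (fn, compile_meta, device_type and cc stand in for the values the autotuner already computes):

import triton

# fn, compile_meta, device_type and cc are placeholders; cc is now always
# filled in, never only for the subprocess path.
try:
    from triton.compiler.compiler import ASTSource
except ImportError:
    ASTSource = None  # older triton has no ASTSource

if ASTSource:
    # New interface: a positional ASTSource plus target/options keyword arguments.
    binary = triton.compile(
        ASTSource(
            fn,
            compile_meta["signature"],
            compile_meta["constants"],
            compile_meta["configs"][0],
        ),
        target=(device_type, cc),
        options={
            "num_warps": compile_meta["num_warps"],
            "num_stages": compile_meta["num_stages"],
            "debug": compile_meta["debug"],
        },
    )
else:
    # Old interface: the kernel fn plus the full compile_meta as keyword arguments.
    binary = triton.compile(fn, **compile_meta)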

Updated:
There are more interface changes on the triton side (a sketch of the corresponding compatibility checks follows the list). E.g.:
- tl.math.{min, max} now require a propagate_nan argument.
- JITFunction.run now requires a warmup argument, which affects the benchmarking phase of matmul max-autotune. On the other hand, JITFunction.run no longer accepts a stream argument; simply not passing it when benchmarking matmul triton kernels works for both old and new versions of triton.
- The triton Autotuner renamed its attributes from 'warmup' to 'num_warmups' and from 'rep' to 'num_reps'. This caused dynamo to fail on triton Autotuner objects, since dynamo's TritonKernelVariable makes assumptions about the attribute names. This shows up in test cases where a model calls the triton Autotuner directly.
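
A minimal sketch of the version-compatibility helpers these changes lean on, following the diff below (illustrative only, not the exact merged code):

import inspect
import triton.language as tl

def get_first_attr(obj, *attrs):
    # Return the first attribute that exists, e.g. 'num_warmups' on newer
    # triton Autotuner objects and 'warmup' on older ones.
    for attr in attrs:
        if hasattr(obj, attr):
            return getattr(obj, attr)
    raise AssertionError(f"{obj} does not have any of the attributes: {attrs}")

def propagate_nan_arg():
    # Newer triton requires propagate_nan for tl.math.{min, max}; emit the
    # default explicitly only when the installed triton expects it.
    if "propagate_nan" in inspect.signature(tl.math.min).parameters:
        return ", tl.PropagateNan.NONE"
    return ""

def warmup_kwarg(run_method):
    # Newer triton adds a warmup argument to JITFunction.run; pass it only
    # when the installed triton accepts it.
    if "warmup" in inspect.signature(run_method).parameters:
        return {"warmup": False}
    return {}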

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115878
Approved by: https://github.com/jansel
Shunting Zhang 2023-12-20 11:44:42 -08:00 committed by PyTorch MergeBot
parent 5d5ef016a6
commit bbded928b3
10 changed files with 162 additions and 34 deletions

View File

@ -2,7 +2,7 @@
import sympy
from torch._inductor.codegen.cpp import cexpr
from torch._inductor.codegen.triton import texpr
from torch._inductor.codegen.triton import texpr, TritonPrinter
from torch._inductor.codegen.wrapper import pexpr
from torch._inductor.sizevars import SizeVarAllocator
@ -291,14 +291,18 @@ class ExprPrinterTests(TorchTestCase):
(sympy.Min, "min"),
(sympy.Max, "max"),
)
extra_arg = TritonPrinter._propagate_nan_arg()
for f, s in cases:
x = sympy.Symbol("x", integer=True)
expr = f(-2, x)
self.assertEqual(texpr(expr), f"tl.math.{s}(-2, x)")
self.assertEqual(texpr(expr), f"tl.math.{s}(-2, x{extra_arg})")
self.assertEqual(cexpr(expr), f"std::{s}(-2L, x)")
expr = f(x, 2 * x, 3 * x)
self.assertEqual(texpr(expr), f"tl.math.{s}(x, tl.math.{s}(2*x, 3*x))")
self.assertEqual(
texpr(expr),
f"tl.math.{s}(x, tl.math.{s}(2*x, 3*x{extra_arg}){extra_arg})",
)
self.assertEqual(cexpr(expr), f"std::{s}({{x, 2L*x, 3L*x}})")

View File

@ -2375,3 +2375,14 @@ def to_fake_tensor(t, fake_mode):
return fake_mode.from_tensor(
t, static_shapes=False, symbolic_context=symbolic_context, source=source
)
def get_first_attr(obj, *attrs):
"""
Return the first available attribute or throw an exception if none is present.
"""
for attr in attrs:
if hasattr(obj, attr):
return getattr(obj, attr)
raise AssertionError(f"{obj} does not have any of the attributes: {attrs}")

View File

@ -10,7 +10,7 @@ from .. import variables
from ..bytecode_transformation import create_call_function, create_rot_n
from ..exc import unimplemented, Unsupported
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
from ..utils import make_cell
from ..utils import get_first_attr, make_cell
from .base import typestr, VariableTracker
@ -654,9 +654,20 @@ class TritonKernelVariable(VariableTracker):
# We only support configs and keys arguments of triton.autotune
# Make sure other arguments are defaulted
defaults = inspect.signature(Autotuner).parameters
# Newer versions of triton rename the attributes from 'warmup' to 'num_warmups'
# and from 'rep' to 'num_reps'. The call to get_first_attr maintains
# backward compatibility with both.
if (
("warmup" in defaults and defaults["warmup"].default != kernel.warmup)
or ("rep" in defaults and defaults["rep"].default != kernel.rep)
(
"warmup" in defaults
and defaults["warmup"].default
!= get_first_attr(kernel, "num_warmups", "warmup")
)
or (
"rep" in defaults
and defaults["rep"].default
!= get_first_attr(kernel, "num_reps", "rep")
)
or (
"prune_configs_by" in defaults
and defaults["prune_configs_by"].default

View File

@ -513,6 +513,15 @@ class TritonBenchmarkRequest(BenchmarkRequest):
)
run_method = getattr(mod, self.kernel_name).run
extra_args = list(self.extra_args)
# Newer versions of triton add a warmup argument to JITFunction.run.
# This code handles backward compatibility.
warmup_arg = {}
import inspect
if "warmup" in inspect.signature(run_method).parameters:
warmup_arg["warmup"] = False
return functools.partial(
run_method,
@ -520,9 +529,9 @@ class TritonBenchmarkRequest(BenchmarkRequest):
output_tensor,
*self.extra_args,
grid=self.grid,
**warmup_arg,
num_stages=self.num_stages,
num_warps=self.num_warps,
stream=torch.cuda.current_stream().cuda_stream,
)
def __str__(self) -> str:

View File

@ -10,6 +10,7 @@ import math
import operator
import os
import textwrap
from functools import lru_cache
from typing import (
Any,
Callable,
@ -31,6 +32,7 @@ import torch._logging
from torch._prims_common import is_integer_dtype
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._triton import has_triton_package
from ..._dynamo.utils import counters
from .. import config, ir, scheduler
@ -91,6 +93,30 @@ class TritonPrinter(PythonPrinter):
q = self.doprint(expr.args[2])
return f"tl.where({c}, {p}, {q})"
@staticmethod
@lru_cache(None)
def _propagate_nan_arg():
"""
Newer triton versions added propagate_nan as a required argument for
tl.math.{min, max}. This method makes inductor work with both old
and new versions of triton.
"""
if not has_triton_package():
# Some tests run in an environment without triton installed and only want
# to check that the generated code is as expected.
return ""
import inspect
import triton.language as tl
if "propagate_nan" in inspect.signature(tl.math.min).parameters:
# tl.PropagateNan.NONE is the default
propagate_nan_arg = ", tl.PropagateNan.NONE"
else:
propagate_nan_arg = ""
return propagate_nan_arg
def _print_Min(self, expr):
nargs = len(expr.args)
if len(expr.args) == 1:
@ -99,7 +125,7 @@ class TritonPrinter(PythonPrinter):
mid = len(expr.args) // 2
a = self._print(sympy.Min(*expr.args[:mid]))
b = self._print(sympy.Min(*expr.args[mid:]))
return f"tl.math.min({a}, {b})"
return f"tl.math.min({a}, {b}{TritonPrinter._propagate_nan_arg()})"
def _print_Max(self, expr):
nargs = len(expr.args)
@ -109,7 +135,8 @@ class TritonPrinter(PythonPrinter):
mid = len(expr.args) // 2
a = self._print(sympy.Max(*expr.args[:mid]))
b = self._print(sympy.Max(*expr.args[mid:]))
return f"tl.math.max({a}, {b})"
return f"tl.math.max({a}, {b}{TritonPrinter._propagate_nan_arg()})"
def _print_Abs(self, expr):
assert len(expr.args) == 1
@ -2072,6 +2099,20 @@ class TritonKernel(Kernel):
"""
)
@staticmethod
@lru_cache(None)
def gen_attr_descriptor_import():
"""
Import AttrsDescriptor if the triton version is new enough to define
this class.
"""
import triton.compiler.compiler
if hasattr(triton.compiler.compiler, "AttrsDescriptor"):
return "from triton.compiler.compiler import AttrsDescriptor"
else:
return ""
def codegen_kernel(self, name=None):
from triton import next_power_of_2
@ -2117,6 +2158,9 @@ class TritonKernel(Kernel):
from torch._inductor import triton_helpers
"""
)
if self.gen_attr_descriptor_import():
code.splice(self.gen_attr_descriptor_import())
if config.benchmark_kernel:
code.splice(self.imports_for_benchmark_kernel())

View File

@ -189,6 +189,8 @@ class ForeachKernel(Kernel):
from torch._inductor import triton_helpers
"""
)
if TritonKernel.gen_attr_descriptor_import():
code.splice(TritonKernel.gen_attr_descriptor_import())
argdefs, _, _ = self.args.python_argdefs()
code.writeline(self.jit_line())
code.writeline(

View File

@ -926,6 +926,10 @@ class WrapperCodeGen(CodeGen):
""",
strip=True,
)
from .triton import TritonKernel
if TritonKernel.gen_attr_descriptor_import():
compile_wrapper.splice(TritonKernel.gen_attr_descriptor_import())
compile_wrapper.newline()
from .common import SizeArg, TensorArg

View File

@ -187,6 +187,7 @@ class TritonTemplateKernel(TritonKernel):
"from torch._inductor.triton_heuristics import template",
"from torch._inductor.utils import instance_descriptor",
"from torch._inductor import triton_helpers",
TritonKernel.gen_attr_descriptor_import(),
"",
self.jit_line(),
f"def {self.kernel_name}({', '.join(arg_defs)}):",

View File

@ -18,7 +18,7 @@ import torch
import torch.autograd.profiler as autograd_profiler
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.utils import dynamo_timed
from torch._dynamo.utils import dynamo_timed, get_first_attr
from torch.utils._triton import has_triton_package
from . import config
@ -44,11 +44,17 @@ if has_triton_package():
from triton import Config
from triton.runtime.autotuner import OutOfResources
from triton.runtime.jit import KernelInterface
try:
from triton.compiler.compiler import ASTSource
except ImportError:
ASTSource = None
else:
Config = object
triton = None
KernelInterface = object
OutOfResources = object
ASTSource = None
_NUM_THREADS_PER_WARP = 32
@ -286,14 +292,44 @@ class CachingAutotuner(KernelInterface):
# Setting device_type="hip" required on ROCm to pass down to triton
compile_meta["device_type"] = "cuda" if torch.version.hip is None else "hip"
device_type = compile_meta["device_type"]
if warm_cache_only_with_cc:
cc = warm_cache_only_with_cc
else:
device_id = compile_meta["device"]
device_interface = get_interface_for_device(device_type)
device = torch.device(device_type, device_id)
cc = device_interface.get_compute_capability(device)
compile_meta["cc"] = cc
if ASTSource:
compile_args = (
ASTSource(
self.fn,
compile_meta["signature"],
compile_meta["constants"],
compile_meta["configs"][0],
),
)
target = (device_type, cc)
options = {
"num_warps": compile_meta["num_warps"],
"num_stages": compile_meta["num_stages"],
"debug": compile_meta["debug"],
}
compile_kwargs = {
"target": target,
"options": options,
}
else:
compile_args = (self.fn,)
compile_kwargs = compile_meta
if warm_cache_only_with_cc:
return (
triton.compile(
self.fn,
warm_cache_only=True,
cc=warm_cache_only_with_cc,
**compile_meta,
),
triton.compile(*compile_args, **compile_kwargs),
None,
)
@ -301,10 +337,8 @@ class CachingAutotuner(KernelInterface):
with torch.cuda.device(compile_meta["device"]):
# need to initialize context
torch.cuda.synchronize(torch.cuda.current_device())
binary = triton.compile(
self.fn,
**compile_meta,
)
binary = triton.compile(*compile_args, **compile_kwargs)
binary._init_handles()
call_args = [
@ -321,6 +355,14 @@ class CachingAutotuner(KernelInterface):
"set_device": torch.cuda.set_device,
"current_device": torch.cuda.current_device,
}
scope["runner"] = get_first_attr(binary, "run", "c_wrapper")
scope["function"] = get_first_attr(binary, "function", "cu_function")
cluster_dims = get_first_attr(binary, "cluster_dims", "clusterDims")
scope["cta_args"] = (
(binary.num_ctas, *cluster_dims) if hasattr(binary, "num_ctas") else ()
)
exec(
f"""
def launcher({', '.join(def_args)}, grid, stream):
@ -329,15 +371,10 @@ class CachingAutotuner(KernelInterface):
else:
grid_0, grid_1, grid_2 = grid
if hasattr(bin, "num_ctas"):
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps,
bin.num_ctas, *bin.clusterDims, bin.shared,
stream, bin.cu_function, None, None, None,
{', '.join(call_args)})
else:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared,
stream, bin.cu_function, None, None, None,
{', '.join(call_args)})
runner(grid_0, grid_1, grid_2, bin.num_warps,
*cta_args, bin.shared,
stream, function, None, None, None,
{', '.join(call_args)})
return bin
""".lstrip(),
scope,

View File

@ -609,11 +609,16 @@ def has_incompatible_cudagraph_ops(gm):
return False
instance_descriptor = collections.namedtuple(
"instance_descriptor",
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
defaults=[tuple(), tuple(), tuple(), tuple()],
)
try:
from triton.compiler.compiler import AttrsDescriptor as instance_descriptor
except ImportError:
# To support older versions of triton that do not have the
# AttrsDescriptor class
instance_descriptor = collections.namedtuple( # type: ignore[no-redef]
"instance_descriptor",
["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
defaults=[tuple(), tuple(), tuple(), tuple()],
)
@functools.lru_cache(None)