[inductor] make inductor work with new triton compile interface (#115878)
Two recent Triton PRs (https://github.com/openai/triton/pull/2701, https://github.com/openai/triton/pull/2756) changed the interface of triton.compile. This PR adds the changes needed on the inductor side to work with both the old and the new compile API. It also simplifies the compilation calls made in the subprocess and in the main process:
- Previously we passed warm_cache_only=True when compilation happened in a subprocess, but Triton never uses that argument at the commit we currently pin, so it is removed.
- Previously we passed compute_capability only when compilation happened in a subprocess; now compute_capability is always passed to triton.compile, whether compilation runs in the main process or a subprocess.

Update: there are further interface changes on the Triton side, e.g.:
- tl.math.{min, max} now require a propagate_nan argument.
- JITFunction.run now requires a warmup argument; this affects the benchmarking phase of matmul max-autotune. JITFunction.run also no longer accepts a stream argument, and simply not passing it when benchmarking the matmul Triton kernel works with both old and new Triton.
- The Triton Autotuner renamed its attributes from 'warmup' to 'num_warmups' and from 'rep' to 'num_reps'. This broke dynamo's handling of Triton Autotuner objects, because dynamo's TritonKernelVariable makes assumptions about those attribute names; some test cases call the Triton Autotuner directly from a model.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115878
Approved by: https://github.com/jansel
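The compatibility strategy used throughout the change is feature detection rather than version checks: inspect the Triton callable's signature (or the object's attributes) at runtime and adapt the generated code or the call site. Below is a minimal sketch of the signature probe behind TritonPrinter._propagate_nan_arg in the diff; it assumes a Triton installation is importable, and the final print line is purely illustrative of the emitted code string.

import inspect

import triton.language as tl

# Newer Triton exposes a propagate_nan parameter on tl.math.{min, max};
# older pins do not. Probe the signature instead of checking a version string.
if "propagate_nan" in inspect.signature(tl.math.min).parameters:
    extra_arg = ", tl.PropagateNan.NONE"  # NONE keeps the previous NaN behavior
else:
    extra_arg = ""

# The Triton printer appends extra_arg when emitting min/max calls, e.g.:
print(f"tl.math.min(a, b{extra_arg})")

The same pattern drives the warmup check in TritonBenchmarkRequest and the AttrsDescriptor probe in the Triton codegen, as shown in the hunks below.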
This commit is contained in:
parent 5d5ef016a6
commit bbded928b3

@@ -2,7 +2,7 @@
import sympy

from torch._inductor.codegen.cpp import cexpr
from torch._inductor.codegen.triton import texpr
from torch._inductor.codegen.triton import texpr, TritonPrinter
from torch._inductor.codegen.wrapper import pexpr

from torch._inductor.sizevars import SizeVarAllocator

@@ -291,14 +291,18 @@ class ExprPrinterTests(TorchTestCase):
(sympy.Min, "min"),
(sympy.Max, "max"),
)
extra_arg = TritonPrinter._propagate_nan_arg()
for f, s in cases:
    x = sympy.Symbol("x", integer=True)
    expr = f(-2, x)
    self.assertEqual(texpr(expr), f"tl.math.{s}(-2, x)")
    self.assertEqual(texpr(expr), f"tl.math.{s}(-2, x{extra_arg})")
    self.assertEqual(cexpr(expr), f"std::{s}(-2L, x)")

    expr = f(x, 2 * x, 3 * x)
    self.assertEqual(texpr(expr), f"tl.math.{s}(x, tl.math.{s}(2*x, 3*x))")
    self.assertEqual(
        texpr(expr),
        f"tl.math.{s}(x, tl.math.{s}(2*x, 3*x{extra_arg}){extra_arg})",
    )
    self.assertEqual(cexpr(expr), f"std::{s}({{x, 2L*x, 3L*x}})")

@@ -2375,3 +2375,14 @@ def to_fake_tensor(t, fake_mode):
    return fake_mode.from_tensor(
        t, static_shapes=False, symbolic_context=symbolic_context, source=source
    )


def get_first_attr(obj, *attrs):
    """
    Return the first available attribute or throw an exception if none is present.
    """
    for attr in attrs:
        if hasattr(obj, attr):
            return getattr(obj, attr)

    raise AssertionError(f"{obj} does not has any of the attributes: {attrs}")
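As a side note, here is a small usage sketch of the get_first_attr helper added in the hunk above (it lives in torch._dynamo.utils). The two stand-in classes are hypothetical and only mirror the Autotuner attribute rename described in the commit message.

from torch._dynamo.utils import get_first_attr  # available after this PR


class OldAutotuner:  # hypothetical stand-in for a pre-rename Triton Autotuner
    warmup = 25
    rep = 100


class NewAutotuner:  # hypothetical stand-in for a post-rename Triton Autotuner
    num_warmups = 25
    num_reps = 100


# get_first_attr returns the first attribute that exists on the object,
# so the same call works against either Triton version.
assert get_first_attr(OldAutotuner(), "num_warmups", "warmup") == 25
assert get_first_attr(NewAutotuner(), "num_warmups", "warmup") == 25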

@@ -10,7 +10,7 @@ from .. import variables
from ..bytecode_transformation import create_call_function, create_rot_n
from ..exc import unimplemented, Unsupported
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
from ..utils import make_cell
from ..utils import get_first_attr, make_cell
from .base import typestr, VariableTracker

@@ -654,9 +654,20 @@ class TritonKernelVariable(VariableTracker):
# We only support configs and keys arguments of triton.autotune
# Make sure other arguments are defaulted
defaults = inspect.signature(Autotuner).parameters

# Newer version of triton change attribute name from warmup to num_warmup and rep to num_rep.
# The call to get_first_attr is to maintain backward-compatibility.
if (
    ("warmup" in defaults and defaults["warmup"].default != kernel.warmup)
    or ("rep" in defaults and defaults["rep"].default != kernel.rep)
    (
        "warmup" in defaults
        and defaults["warmup"].default
        != get_first_attr(kernel, "num_warmups", "warmup")
    )
    or (
        "rep" in defaults
        and defaults["rep"].default
        != get_first_attr(kernel, "num_reps", "rep")
    )
    or (
        "prune_configs_by" in defaults
        and defaults["prune_configs_by"].default

@@ -513,6 +513,15 @@ class TritonBenchmarkRequest(BenchmarkRequest):
)

run_method = getattr(mod, self.kernel_name).run
extra_args = list(self.extra_args)

# Newer version of triton add warmup argument to JITFunction.run.
# This code handles backward-compatibility.
warmup_arg = {}
import inspect

if "warmup" in inspect.signature(run_method).parameters:
    warmup_arg["warmup"] = False

return functools.partial(
    run_method,

@@ -520,9 +529,9 @@ class TritonBenchmarkRequest(BenchmarkRequest):
output_tensor,
*self.extra_args,
grid=self.grid,
**warmup_arg,
num_stages=self.num_stages,
num_warps=self.num_warps,
stream=torch.cuda.current_stream().cuda_stream,
)

def __str__(self) -> str:

@@ -10,6 +10,7 @@ import math
import operator
import os
import textwrap
from functools import lru_cache
from typing import (
    Any,
    Callable,

@@ -31,6 +32,7 @@ import torch._logging
from torch._prims_common import is_integer_dtype
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils._triton import has_triton_package

from ..._dynamo.utils import counters
from .. import config, ir, scheduler

@@ -91,6 +93,30 @@ class TritonPrinter(PythonPrinter):
q = self.doprint(expr.args[2])
return f"tl.where({c}, {p}, {q})"

@staticmethod
@lru_cache(None)
def _propagate_nan_arg():
    """
    Newer triton version added propagate_nan as required argument for
    tl.math.{min, max}. This method make inductor work with both old
    and new version of triton.
    """

    if not has_triton_package():
        # some tests run under environment without triton installed want to
        # check that the generated code is as expected.
        return ""
    import inspect

    import triton.language as tl

    if "propagate_nan" in inspect.signature(tl.math.min).parameters:
        # tl.PropagateNan.NONE is the default
        propagate_nan_arg = ", tl.PropagateNan.NONE"
    else:
        propagate_nan_arg = ""
    return propagate_nan_arg

def _print_Min(self, expr):
    nargs = len(expr.args)
    if len(expr.args) == 1:

@@ -99,7 +125,7 @@ class TritonPrinter(PythonPrinter):
mid = len(expr.args) // 2
a = self._print(sympy.Min(*expr.args[:mid]))
b = self._print(sympy.Min(*expr.args[mid:]))
return f"tl.math.min({a}, {b})"
return f"tl.math.min({a}, {b}{TritonPrinter._propagate_nan_arg()})"

def _print_Max(self, expr):
    nargs = len(expr.args)

@@ -109,7 +135,8 @@ class TritonPrinter(PythonPrinter):
mid = len(expr.args) // 2
a = self._print(sympy.Max(*expr.args[:mid]))
b = self._print(sympy.Max(*expr.args[mid:]))
return f"tl.math.max({a}, {b})"

return f"tl.math.max({a}, {b}{TritonPrinter._propagate_nan_arg()})"

def _print_Abs(self, expr):
    assert len(expr.args) == 1

@@ -2072,6 +2099,20 @@ class TritonKernel(Kernel):
"""
)

@staticmethod
@lru_cache(None)
def gen_attr_descriptor_import():
    """
    import AttrsDescriptor if the triton version is new enough to have this
    class defined.
    """
    import triton.compiler.compiler

    if hasattr(triton.compiler.compiler, "AttrsDescriptor"):
        return "from triton.compiler.compiler import AttrsDescriptor"
    else:
        return ""

def codegen_kernel(self, name=None):
    from triton import next_power_of_2

@@ -2117,6 +2158,9 @@ class TritonKernel(Kernel):
from torch._inductor import triton_helpers
"""
)
if self.gen_attr_descriptor_import():
    code.splice(self.gen_attr_descriptor_import())

if config.benchmark_kernel:
    code.splice(self.imports_for_benchmark_kernel())

@@ -189,6 +189,8 @@ class ForeachKernel(Kernel):
from torch._inductor import triton_helpers
"""
)
if TritonKernel.gen_attr_descriptor_import():
    code.splice(TritonKernel.gen_attr_descriptor_import())
argdefs, _, _ = self.args.python_argdefs()
code.writeline(self.jit_line())
code.writeline(

@@ -926,6 +926,10 @@ class WrapperCodeGen(CodeGen):
""",
strip=True,
)
from .triton import TritonKernel

if TritonKernel.gen_attr_descriptor_import():
    compile_wrapper.splice(TritonKernel.gen_attr_descriptor_import())
compile_wrapper.newline()

from .common import SizeArg, TensorArg

@@ -187,6 +187,7 @@ class TritonTemplateKernel(TritonKernel):
"from torch._inductor.triton_heuristics import template",
"from torch._inductor.utils import instance_descriptor",
"from torch._inductor import triton_helpers",
TritonKernel.gen_attr_descriptor_import(),
"",
self.jit_line(),
f"def {self.kernel_name}({', '.join(arg_defs)}):",

@@ -18,7 +18,7 @@ import torch

import torch.autograd.profiler as autograd_profiler
from torch._dynamo.device_interface import get_interface_for_device
from torch._dynamo.utils import dynamo_timed
from torch._dynamo.utils import dynamo_timed, get_first_attr
from torch.utils._triton import has_triton_package

from . import config

@@ -44,11 +44,17 @@ if has_triton_package():
    from triton import Config
    from triton.runtime.autotuner import OutOfResources
    from triton.runtime.jit import KernelInterface

    try:
        from triton.compiler.compiler import ASTSource
    except ImportError:
        ASTSource = None
else:
    Config = object
    triton = None
    KernelInterface = object
    OutOfResources = object
    ASTSource = None


_NUM_THREADS_PER_WARP = 32

@@ -286,14 +292,44 @@ class CachingAutotuner(KernelInterface):
# Setting device_type="hip" required on ROCm to pass down to triton
compile_meta["device_type"] = "cuda" if torch.version.hip is None else "hip"

device_type = compile_meta["device_type"]
if warm_cache_only_with_cc:
    cc = warm_cache_only_with_cc
else:
    device_id = compile_meta["device"]
    device_interface = get_interface_for_device(device_type)
    device = torch.device(device_type, device_id)
    cc = device_interface.get_compute_capability(device)

compile_meta["cc"] = cc

if ASTSource:
    compile_args = (
        ASTSource(
            self.fn,
            compile_meta["signature"],
            compile_meta["constants"],
            compile_meta["configs"][0],
        ),
    )

    target = (device_type, cc)
    options = {
        "num_warps": compile_meta["num_warps"],
        "num_stages": compile_meta["num_stages"],
        "debug": compile_meta["debug"],
    }
    compile_kwargs = {
        "target": target,
        "options": options,
    }
else:
    compile_args = (self.fn,)
    compile_kwargs = compile_meta

if warm_cache_only_with_cc:
    return (
        triton.compile(
            self.fn,
            warm_cache_only=True,
            cc=warm_cache_only_with_cc,
            **compile_meta,
        ),
        triton.compile(*compile_args, **compile_kwargs),
        None,
    )

@@ -301,10 +337,8 @@ class CachingAutotuner(KernelInterface):
with torch.cuda.device(compile_meta["device"]):
    # need to initialize context
    torch.cuda.synchronize(torch.cuda.current_device())
    binary = triton.compile(
        self.fn,
        **compile_meta,
    )

    binary = triton.compile(*compile_args, **compile_kwargs)
    binary._init_handles()

call_args = [

@@ -321,6 +355,14 @@ class CachingAutotuner(KernelInterface):
"set_device": torch.cuda.set_device,
"current_device": torch.cuda.current_device,
}

scope["runner"] = get_first_attr(binary, "run", "c_wrapper")
scope["function"] = get_first_attr(binary, "function", "cu_function")
cluster_dims = get_first_attr(binary, "cluster_dims", "clusterDims")
scope["cta_args"] = (
    (binary.num_ctas, *cluster_dims) if hasattr(binary, "num_ctas") else ()
)

exec(
    f"""
def launcher({', '.join(def_args)}, grid, stream):

@@ -329,15 +371,10 @@ class CachingAutotuner(KernelInterface):
else:
    grid_0, grid_1, grid_2 = grid

if hasattr(bin, "num_ctas"):
    bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps,
        bin.num_ctas, *bin.clusterDims, bin.shared,
        stream, bin.cu_function, None, None, None,
        {', '.join(call_args)})
else:
    bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared,
        stream, bin.cu_function, None, None, None,
        {', '.join(call_args)})
runner(grid_0, grid_1, grid_2, bin.num_warps,
    *cta_args, bin.shared,
    stream, function, None, None, None,
    {', '.join(call_args)})
return bin
""".lstrip(),
scope,

@@ -609,11 +609,16 @@ def has_incompatible_cudagraph_ops(gm):
    return False


instance_descriptor = collections.namedtuple(
    "instance_descriptor",
    ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
    defaults=[tuple(), tuple(), tuple(), tuple()],
)
try:
    from triton.compiler.compiler import AttrsDescriptor as instance_descriptor
except ImportError:
    # To support older version of triton which does not have AttrsDescriptor
    # class
    instance_descriptor = collections.namedtuple(  # type: ignore[no-redef]
        "instance_descriptor",
        ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"],
        defaults=[tuple(), tuple(), tuple(), tuple()],
    )


@functools.lru_cache(None)