Revert "[triton] Warp specialization support in torchinductor (#148503)"

This reverts commit 36183215e8.

Reverted https://github.com/pytorch/pytorch/pull/148503 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/148503#issuecomment-2758590645))
PyTorch MergeBot 2025-03-27 16:06:42 +00:00
parent af7719a2fa
commit efc975feb2
8 changed files with 27 additions and 182 deletions

View File

@@ -38,9 +38,10 @@ from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16, TEST_MUL
from torch.testing._internal.common_device_type import (
flex_attention_supported_platform as supported_platform,
)
-from torch.testing._internal.common_utils import IS_MACOS, IS_FBCODE
+from torch.testing._internal.common_utils import IS_MACOS
from torch.utils._triton import has_triton
# Use this decorator only when hitting Triton bugs on H100
running_on_a100_only = skipUnless(
(torch.cuda.is_available() and has_triton())
@@ -3682,7 +3683,6 @@ class GraphModule(torch.nn.Module):
attn_output = mod(q, k, v, mask)
self.assertEqual(attn_output.device, torch.device("cuda:1"))
@supported_platform
def test_validate_small_embedding_size_error_message(self):
# eager support for small embedding size
@@ -3708,45 +3708,6 @@ class GraphModule(torch.nn.Module):
q, k, v = [torch.randn(2, 2, 128, 16, device="cuda") for _ in range(3)]
compiled_fa = torch.compile(flex_attention)
@unittest.skipIf(not has_triton() and not IS_FBCODE, reason="FBCODE Triton is required for this test")
def test_triton_template_warp_specialization(self):
make_tensor = lambda: torch.rand(4, 16, 4096, 64, device="cuda", dtype=torch.bfloat16)
q, k, v = make_tensor(), make_tensor(), make_tensor()
flex_compiled = torch.compile(flex_attention, fullgraph=True)
positional_args = (q, k, v)
keyword_args = {
"kernel_options": {
"num_warps": 4,
"num_consumer_groups": 0,
"num_buffers_warp_spec": 0,
}
}
# Check if kernel code contains warp specialization parameters
_, kernel_code = run_and_get_code(
flex_compiled,
*positional_args,
**keyword_args,
)
assert kernel_code is not None, "Failed to retrieve compiled kernel code"
assert (
"num_consumer_groups" in kernel_code[0]
), "num_consumer_groups missing in kernel definition"
assert (
"num_buffers_warp_spec" in kernel_code[0]
), "num_buffers_warp_spec missing in kernel definition"
# Validate correctness
C1 = flex_compiled(q, k, v)
C2 = flex_attention(q, k, v)
assert torch.allclose(
C1, C2, atol=1e-2, rtol=1e-2
), "Warp specialized kernel result differs from reference"
class TestBlockMask(InductorTestCase):
@supported_platform
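For orientation on what this revert removes: the deleted test above exercised warp specialization end to end by forwarding the two new knobs through flex_attention's kernel_options. Below is a minimal usage sketch adapted from that test, not an excerpt of the diff; it needs a CUDA device with bfloat16 support, and whether a given Triton build actually honors num_consumer_groups and num_buffers_warp_spec is an assumption.

import torch
from torch.nn.attention.flex_attention import flex_attention

def make_tensor():
    # Same shape and dtype as the removed test: (batch, heads, seq_len, head_dim) in bfloat16.
    return torch.rand(4, 16, 4096, 64, device="cuda", dtype=torch.bfloat16)

q, k, v = make_tensor(), make_tensor(), make_tensor()
flex_compiled = torch.compile(flex_attention, fullgraph=True)

# kernel_options carried the (now-reverted) warp-specialization knobs alongside num_warps.
out = flex_compiled(
    q, k, v,
    kernel_options={
        "num_warps": 4,
        "num_consumer_groups": 0,    # reverted knob: number of consumer warp groups
        "num_buffers_warp_spec": 0,  # reverted knob: buffer count for warp specialization
    },
)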

View File

@@ -357,8 +357,6 @@ class TestSelectAlgorithm(TestCase):
extra_args=None,
num_stages=None,
num_warps=None,
num_consumer_groups=None,
num_buffers_warp_spec=None,
input_tensor_meta=None,
output_tensor_meta=None,
)

View File

@@ -2,13 +2,11 @@
import sys
import unittest
from unittest.mock import MagicMock, patch
import torch
from torch._dynamo.testing import rand_strided
from torch._inductor.runtime.triton_compat import Config
from torch._inductor.utils import clone_preserve_strides
-from torch.testing._internal.common_utils import IS_FBCODE, IS_LINUX, skipIfXpu
+from torch.testing._internal.common_utils import IS_LINUX, skipIfXpu
from torch.testing._internal.inductor_utils import (
GPU_TYPE,
HAS_GPU,
@@ -36,7 +34,6 @@ from torch._inductor.runtime.triton_helpers import math as tl_math
from torch._inductor.runtime.triton_heuristics import (
autotune_hints_to_configs,
CachingAutotuner,
template,
triton_config,
)
from torch._inductor.test_case import run_tests, TestCase
@@ -189,30 +186,6 @@ class TestTritonHeuristics(TestCase):
self.assertTrue(8 in seen_num_elements_per_warp)
@unittest.skipIf(not IS_FBCODE, "FBCODE Triton is required for this test")
def test_template_function_ws(self):
triton_meta = {"device": MagicMock()}
num_stages = 2
num_warps = 4
num_consumer_groups = 3
num_buffers_warp_spec = 5
with patch(
"torch._inductor.runtime.triton_heuristics.cached_autotune"
) as mock_cached_autotune:
template(
num_stages=num_stages,
num_warps=num_warps,
triton_meta=triton_meta,
num_consumer_groups=num_consumer_groups,
num_buffers_warp_spec=num_buffers_warp_spec,
)
mock_cached_autotune.assert_called_once()
configs = mock_cached_autotune.call_args[0][1]
self.assertEqual(configs[0].num_consumer_groups, num_consumer_groups)
self.assertEqual(configs[0].num_buffers_warp_spec, num_buffers_warp_spec)
class TestArgumentCloneAndRestore(TestCase):
# Our tensor is large enough. If an unexpected copy happens, the
# peak memory increase should be larger than tolerance and the test

View File

@@ -1473,7 +1473,7 @@ class TritonHOPifier:
new_var = type(variable)(new_kernel, None, variable.grid)
return self.call_triton_kernel(new_var, args, kwargs, tx)
-SPECIAL_CONFIG_NAMES = {"num_warps", "num_stages", "num_ctas", "num_consumer_groups", "num_buffers_warp_spec"}
+SPECIAL_CONFIG_NAMES = {"num_warps", "num_stages", "num_ctas"}
# move special config names to configs out of kwargs
special_kwargs = {}

View File

@@ -638,8 +638,6 @@ class TritonBenchmarkRequest(BenchmarkRequest):
module_cache_key: str,
num_stages: int,
num_warps: int,
num_consumer_groups: int = 0,
num_buffers_warp_spec: int = 0,
matrix_instr_nonkdim: int = 0, # only used for hip to choose the shape of mfma instruction.
waves_per_eu: int = 0, # only used for hip to schedule waves per execution unit
kpack: int = 0, # ROCm specific gemm parameter
@@ -650,8 +648,6 @@ class TritonBenchmarkRequest(BenchmarkRequest):
self.module_cache_key = module_cache_key
self.num_stages = num_stages
self.num_warps = num_warps
self.num_consumer_groups = num_consumer_groups
self.num_buffers_warp_spec = num_buffers_warp_spec
self.matrix_instr_nonkdim = matrix_instr_nonkdim
self.waves_per_eu = waves_per_eu
self.kpack = kpack

View File

@@ -21,7 +21,7 @@ from ..remote_cache import (
RemoteCacheJsonSerde,
)
from .triton_compat import Config
from .. import config as inductor_config
if TYPE_CHECKING:
from ..remote_cache import Sample
@@ -207,11 +207,6 @@ class AutotuneCache:
"found_by_coordesc": found_by_coordesc,
"time_taken_ms": time_taken_ns // 1000000, # Convert from NS to MS
}
if inductor_config.is_fbcode():
data.update({
"num_consumer_groups": getattr(config, 'num_consumer_groups', 0),
"num_buffers_warp_spec": getattr(config, 'num_buffers_warp_spec', 0),
})
if local_cache := self.local_cache:
cache, key = local_cache
@@ -469,22 +464,7 @@ def _load_cached_autotuning(
):
num_warps = best_config.pop("num_warps")
num_stages = best_config.pop("num_stages")
-# Extract common arguments
-config_args = {
-"num_warps": num_warps,
-"num_stages": num_stages,
-}
-# Conditionally add arguments based on inductor_config.is_fbcode()
-if inductor_config.is_fbcode():
-config_args.update({
-"num_consumer_groups": best_config.pop("num_consumer_groups", 0),
-"num_buffers_warp_spec": best_config.pop("num_buffers_warp_spec", 0),
-})
-# Create the triton_config with the appropriate arguments
-triton_config = Config(best_config, **config_args)
+triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages)
triton_config.found_by_coordesc = True
return triton_config
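To make the autotune-cache changes above concrete: the reverted code wrote the two warp-spec fields into the cache payload with getattr(..., 0) defaults and popped them back out, again with defaults, when rebuilding the best config. The sketch below shows that round trip on plain objects and dicts; the warp_spec_supported flag is a stand-in for the inductor_config.is_fbcode() gate, and both helper names are invented for illustration.

def dump_config_fields(config, warp_spec_supported: bool) -> dict:
    # Write side: mirrors the data.update(...) removed from AutotuneCache above.
    data = {"num_warps": config.num_warps, "num_stages": config.num_stages}
    if warp_spec_supported:
        # Missing attributes default to 0, as in the reverted getattr(..., 0) calls.
        data["num_consumer_groups"] = getattr(config, "num_consumer_groups", 0)
        data["num_buffers_warp_spec"] = getattr(config, "num_buffers_warp_spec", 0)
    return data

def load_config_kwargs(best_config: dict, warp_spec_supported: bool) -> dict:
    # Read side: mirrors the config_args built in _load_cached_autotuning above.
    kwargs = {
        "num_warps": best_config.pop("num_warps"),
        "num_stages": best_config.pop("num_stages"),
    }
    if warp_spec_supported:
        # pop with a default so cache entries written before the feature still load.
        kwargs["num_consumer_groups"] = best_config.pop("num_consumer_groups", 0)
        kwargs["num_buffers_warp_spec"] = best_config.pop("num_buffers_warp_spec", 0)
    return kwargs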

View File

@@ -76,7 +76,6 @@ from .triton_compat import (
PTXASError,
triton,
)
from .. import config as inductor_config
class NoTritonConfigsError(RuntimeError):
@@ -170,9 +169,6 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
call_kwargs.update(launcher.config.kwargs)
call_kwargs["num_warps"] = launcher.config.num_warps
call_kwargs["num_stages"] = launcher.config.num_stages
if inductor_config.is_fbcode():
call_kwargs["num_consumer_groups"] = getattr(launcher.config, 'num_consumer_groups', 0)
call_kwargs["num_buffers_warp_spec"] = getattr(launcher.config, 'num_buffers_warp_spec', 0)
args_str = [*call_args]
args_str.extend(f"{k}={v}" for k, v in call_kwargs.items())
args_str = ", ".join(args_str)
@@ -513,9 +509,6 @@ class CachingAutotuner(KernelInterface):
compile_meta["constants"][arg_name] = getattr(cfg, arg_name)
compile_meta["num_warps"] = cfg.num_warps
compile_meta["num_stages"] = cfg.num_stages
if inductor_config.is_fbcode():
compile_meta["num_consumer_groups"] = getattr(cfg, 'num_consumer_groups', 0)
compile_meta["num_buffers_warp_spec"] = getattr(cfg, 'num_buffers_warp_spec', 0)
compile_meta["debug"] = self.inductor_meta.get(
"assert_indirect_indexing", True
) and not self.inductor_meta.get("is_hip", False)
@@ -562,11 +555,6 @@ class CachingAutotuner(KernelInterface):
"debug": compile_meta["debug"],
"sanitize_overflow": False, # turn off additional asserts added for overflow checks
}
if inductor_config.is_fbcode():
options.update({
"num_consumer_groups": compile_meta["num_consumer_groups"],
"num_buffers_warp_spec": compile_meta["num_buffers_warp_spec"],
})
if self.device_props.type == "hip":
if "waves_per_eu" in compile_meta:
options["waves_per_eu"] = compile_meta["waves_per_eu"]
@@ -2338,28 +2326,13 @@ def split_scan(
)
-def template(num_stages, num_warps,
-triton_meta,
-num_consumer_groups = 0,
-num_buffers_warp_spec = 0, filename=None, inductor_meta=None):
+def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=None):
"""
Compile a triton template
"""
-# Prepare the base configuration
-config_args = {
-"num_stages": num_stages,
-"num_warps": num_warps,
-}
-# Conditionally add arguments based on inductor_config.is_fbcode()
-if inductor_config.is_fbcode():
-config_args.update({
-"num_consumer_groups": num_consumer_groups,
-"num_buffers_warp_spec": num_buffers_warp_spec,
-})
return cached_autotune(
None,
-[triton.Config({}, **config_args)],
+[triton.Config({}, num_stages=num_stages, num_warps=num_warps)],
triton_meta=triton_meta,
inductor_meta=inductor_meta,
heuristic_type=HeuristicType.TEMPLATE,
@@ -2370,7 +2343,7 @@ def template(num_stages, num_warps,
def _pop_config_kwargs(config: dict[str, Any]) -> dict[str, Any]:
"""Extract triton.Config options that should become kwargs"""
popped = {}
for key in ("num_warps", "num_stages", "num_ctas", "maxnreg", "num_consumer_groups", "num_buffers_warp_spec"):
for key in ("num_warps", "num_stages", "num_ctas", "maxnreg"):
val = config.pop(key, None)
if val is not None:
popped[key] = val
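The _pop_config_kwargs change just above narrows which keys get lifted out of a config dict and handed to triton.Config as keyword arguments; after the revert, the warp-spec keys are left behind like any ordinary kernel kwarg. A small illustration of the post-revert behavior using plain dicts (the example values are made up):

def pop_config_kwargs(config: dict) -> dict:
    # Post-revert key set: num_consumer_groups / num_buffers_warp_spec are no longer special.
    popped = {}
    for key in ("num_warps", "num_stages", "num_ctas", "maxnreg"):
        val = config.pop(key, None)
        if val is not None:
            popped[key] = val
    return popped

cfg = {"BLOCK_M": 64, "num_warps": 8, "num_stages": 3, "num_consumer_groups": 2}
print(pop_config_kwargs(cfg))  # {'num_warps': 8, 'num_stages': 3}
print(cfg)                     # {'BLOCK_M': 64, 'num_consumer_groups': 2}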
@@ -2378,18 +2351,11 @@ def _pop_config_kwargs(config: dict[str, Any]) -> dict[str, Any]:
def config_to_dict(config: Config) -> dict[str, Any]:
-config_dict = {
+return {
**config.kwargs,
"num_warps": config.num_warps,
"num_stages": config.num_stages,
}
-if inductor_config.is_fbcode():
-config_dict.update({
-"num_consumer_groups": getattr(config, 'num_consumer_groups', 0),
-"num_buffers_warp_spec": getattr(config, 'num_buffers_warp_spec', 0),
-})
-return config_dict
def config_from_dict(config: dict[str, Any]) -> Config:

View File

@@ -293,8 +293,6 @@ class TritonTemplateKernel(TritonKernel):
grid_fn,
meta,
call_sizes,
num_consumer_groups=0,
num_buffers_warp_spec=0,
use_jit=False,
prefix_args=0,
suffix_args=0,
@@ -318,8 +316,6 @@ class TritonTemplateKernel(TritonKernel):
self.use_jit = use_jit
self.num_stages = num_stages
self.num_warps = num_warps
self.num_consumer_groups = num_consumer_groups
self.num_buffers_warp_spec = num_buffers_warp_spec
self.grid_fn = grid_fn
self.meta = meta
self.call_sizes = call_sizes
@@ -461,24 +457,12 @@ class TritonTemplateKernel(TritonKernel):
if config.profile_bandwidth or config.benchmark_kernel:
num_gb = self.estimate_kernel_num_bytes() / 1e9
inductor_meta["kernel_num_gb"] = num_gb
-template_args = f"""
+return f"""
@triton_heuristics.template(
num_stages={self.num_stages},
num_warps={self.num_warps},
triton_meta={triton_meta!r},
inductor_meta={inductor_meta!r},
"""
# Conditionally add arguments based on iis_fbcode()
if config.is_fbcode():
template_args += f"""
num_consumer_groups={self.num_consumer_groups},
num_buffers_warp_spec={self.num_buffers_warp_spec},
"""
return f"""
@triton_heuristics.template(
{template_args}
)
@triton.jit
"""
@@ -1086,8 +1070,6 @@ class TritonTemplate(KernelTemplate):
layout,
num_stages,
num_warps,
num_consumer_groups=0,
num_buffers_warp_spec=0,
prefix_args=0,
suffix_args=0,
epilogue_fn=identity,
@@ -1153,11 +1135,6 @@ class TritonTemplate(KernelTemplate):
"epilogue_fn": epilogue_fn,
"subgraphs": subgraphs,
}
if config.is_fbcode():
kernel_options.update({
"num_consumer_groups": num_consumer_groups,
"num_buffers_warp_spec": num_buffers_warp_spec,
})
with (
patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_out)),
@@ -1179,23 +1156,19 @@
return None
if self.debug:
print("Generated Code:\n", code)
-extra_parts = [
-f"{kwarg}={repr(kwargs[kwarg])}" for kwarg in sorted(kwargs.keys())
-]
-extra_parts.extend([
+extra = (
+"-".join(
+[
+*[
+f"{kwarg}={repr(kwargs[kwarg])}"
+for kwarg in sorted(kwargs.keys())
+],
f"num_stages={num_stages}",
f"num_warps={num_warps}",
-])
-# Conditionally add arguments based on inductor_config.is_fbcode()
-if config.is_fbcode():
-extra_parts.extend([
-f"num_consumer_groups={num_consumer_groups}",
-f"num_buffers_warp_spec={num_buffers_warp_spec}",
-])
-extra = "-".join(extra_parts) + "-"
+]
+)
++ "-"
+)
mod = PyCodeCache.load(code, extra)
input_call_args = tuple(kernel.args.input_buffers.keys())
@@ -1251,8 +1224,6 @@ class TritonTemplate(KernelTemplate):
extra_args=[*extra_args, *grid],
num_stages=num_stages,
num_warps=num_warps,
num_consumer_groups=num_consumer_groups,
num_buffers_warp_spec=num_buffers_warp_spec,
matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
waves_per_eu=kwargs.get("waves_per_eu", 0),
kpack=kwargs.get("kpack", 2),
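A closing note on the "extra" string rebuilt two hunks up: it is passed to PyCodeCache.load together with the generated source, so template modules generated with different meta-parameters land under distinct cache keys. A worked example of the post-revert construction with made-up kwargs:

kwargs = {"BLOCK_M": 64, "BLOCK_N": 64}
num_stages, num_warps = 3, 4
extra = (
    "-".join(
        [
            *[f"{kwarg}={repr(kwargs[kwarg])}" for kwarg in sorted(kwargs.keys())],
            f"num_stages={num_stages}",
            f"num_warps={num_warps}",
        ]
    )
    + "-"
)
print(extra)  # BLOCK_M=64-BLOCK_N=64-num_stages=3-num_warps=4-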