mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[triton] Warp specialization support in torchinductor (#148503)"
This reverts commit 36183215e8.
Reverted https://github.com/pytorch/pytorch/pull/148503 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/148503#issuecomment-2758590645))
This commit is contained in:
parent
af7719a2fa
commit
efc975feb2
|
|
@ -38,9 +38,10 @@ from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16, TEST_MUL
|
|||
from torch.testing._internal.common_device_type import (
|
||||
flex_attention_supported_platform as supported_platform,
|
||||
)
|
||||
from torch.testing._internal.common_utils import IS_MACOS, IS_FBCODE
|
||||
from torch.testing._internal.common_utils import IS_MACOS
|
||||
from torch.utils._triton import has_triton
|
||||
|
||||
|
||||
# Use this decorator only when hitting Triton bugs on H100
|
||||
running_on_a100_only = skipUnless(
|
||||
(torch.cuda.is_available() and has_triton())
|
||||
|
|
@ -3682,7 +3683,6 @@ class GraphModule(torch.nn.Module):
|
|||
attn_output = mod(q, k, v, mask)
|
||||
self.assertEqual(attn_output.device, torch.device("cuda:1"))
|
||||
|
||||
|
||||
@supported_platform
|
||||
def test_validate_small_embedding_size_error_message(self):
|
||||
# eager support for small embedding size
|
||||
|
|
@ -3708,45 +3708,6 @@ class GraphModule(torch.nn.Module):
|
|||
q, k, v = [torch.randn(2, 2, 128, 16, device="cuda") for _ in range(3)]
|
||||
compiled_fa = torch.compile(flex_attention)
|
||||
|
||||
@unittest.skipIf(not has_triton() and not IS_FBCODE, reason="FBCODE Triton is required for this test")
|
||||
def test_triton_template_warp_specialization(self):
|
||||
|
||||
make_tensor = lambda: torch.rand(4, 16, 4096, 64, device="cuda", dtype=torch.bfloat16)
|
||||
q, k, v = make_tensor(), make_tensor(), make_tensor()
|
||||
flex_compiled = torch.compile(flex_attention, fullgraph=True)
|
||||
|
||||
|
||||
positional_args = (q, k, v)
|
||||
keyword_args = {
|
||||
"kernel_options": {
|
||||
"num_warps": 4,
|
||||
"num_consumer_groups": 0,
|
||||
"num_buffers_warp_spec": 0,
|
||||
}
|
||||
}
|
||||
|
||||
# Check if kernel code contains warp specialization parameters
|
||||
_, kernel_code = run_and_get_code(
|
||||
flex_compiled,
|
||||
*positional_args,
|
||||
**keyword_args,
|
||||
)
|
||||
assert kernel_code is not None, "Failed to retrieve compiled kernel code"
|
||||
assert (
|
||||
"num_consumer_groups" in kernel_code[0]
|
||||
), "num_consumer_groups missing in kernel definition"
|
||||
assert (
|
||||
"num_buffers_warp_spec" in kernel_code[0]
|
||||
), "num_buffers_warp_spec missing in kernel definition"
|
||||
|
||||
# Validate correctness
|
||||
C1 = flex_compiled(q, k, v)
|
||||
C2 = flex_attention(q, k, v)
|
||||
|
||||
assert torch.allclose(
|
||||
C1, C2, atol=1e-2, rtol=1e-2
|
||||
), "Warp specialized kernel result differs from reference"
|
||||
|
||||
|
||||
class TestBlockMask(InductorTestCase):
|
||||
@supported_platform
|
||||
|
|
|
|||
|
|
@ -357,8 +357,6 @@ class TestSelectAlgorithm(TestCase):
|
|||
extra_args=None,
|
||||
num_stages=None,
|
||||
num_warps=None,
|
||||
num_consumer_groups=None,
|
||||
num_buffers_warp_spec=None,
|
||||
input_tensor_meta=None,
|
||||
output_tensor_meta=None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,13 +2,11 @@
|
|||
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from torch._dynamo.testing import rand_strided
|
||||
from torch._inductor.runtime.triton_compat import Config
|
||||
from torch._inductor.utils import clone_preserve_strides
|
||||
from torch.testing._internal.common_utils import IS_FBCODE, IS_LINUX, skipIfXpu
|
||||
from torch.testing._internal.common_utils import IS_LINUX, skipIfXpu
|
||||
from torch.testing._internal.inductor_utils import (
|
||||
GPU_TYPE,
|
||||
HAS_GPU,
|
||||
|
|
@ -36,7 +34,6 @@ from torch._inductor.runtime.triton_helpers import math as tl_math
|
|||
from torch._inductor.runtime.triton_heuristics import (
|
||||
autotune_hints_to_configs,
|
||||
CachingAutotuner,
|
||||
template,
|
||||
triton_config,
|
||||
)
|
||||
from torch._inductor.test_case import run_tests, TestCase
|
||||
|
|
@ -189,30 +186,6 @@ class TestTritonHeuristics(TestCase):
|
|||
self.assertTrue(8 in seen_num_elements_per_warp)
|
||||
|
||||
|
||||
@unittest.skipIf(not IS_FBCODE, "FBCODE Triton is required for this test")
|
||||
def test_template_function_ws(self):
|
||||
triton_meta = {"device": MagicMock()}
|
||||
num_stages = 2
|
||||
num_warps = 4
|
||||
num_consumer_groups = 3
|
||||
num_buffers_warp_spec = 5
|
||||
|
||||
with patch(
|
||||
"torch._inductor.runtime.triton_heuristics.cached_autotune"
|
||||
) as mock_cached_autotune:
|
||||
template(
|
||||
num_stages=num_stages,
|
||||
num_warps=num_warps,
|
||||
triton_meta=triton_meta,
|
||||
num_consumer_groups=num_consumer_groups,
|
||||
num_buffers_warp_spec=num_buffers_warp_spec,
|
||||
)
|
||||
mock_cached_autotune.assert_called_once()
|
||||
configs = mock_cached_autotune.call_args[0][1]
|
||||
self.assertEqual(configs[0].num_consumer_groups, num_consumer_groups)
|
||||
self.assertEqual(configs[0].num_buffers_warp_spec, num_buffers_warp_spec)
|
||||
|
||||
|
||||
class TestArgumentCloneAndRestore(TestCase):
|
||||
# Our tensor is large enough. If a unexpected copy happens, the
|
||||
# peak memory increase should be larger than tolerance and the test
|
||||
|
|
|
|||
|
|
@ -1473,7 +1473,7 @@ class TritonHOPifier:
|
|||
new_var = type(variable)(new_kernel, None, variable.grid)
|
||||
return self.call_triton_kernel(new_var, args, kwargs, tx)
|
||||
|
||||
SPECIAL_CONFIG_NAMES = {"num_warps", "num_stages", "num_ctas", "num_consumer_groups", "num_buffers_warp_spec"}
|
||||
SPECIAL_CONFIG_NAMES = {"num_warps", "num_stages", "num_ctas"}
|
||||
|
||||
# move special config names to configs out of kwargs
|
||||
special_kwargs = {}
|
||||
|
|
|
|||
|
|
@ -638,8 +638,6 @@ class TritonBenchmarkRequest(BenchmarkRequest):
|
|||
module_cache_key: str,
|
||||
num_stages: int,
|
||||
num_warps: int,
|
||||
num_consumer_groups: int = 0,
|
||||
num_buffers_warp_spec: int = 0,
|
||||
matrix_instr_nonkdim: int = 0, # only used for hip to choose the shape of mfma instruction.
|
||||
waves_per_eu: int = 0, # only used for hip to schedule waves per execution unit
|
||||
kpack: int = 0, # ROCm specific gemm paramete
|
||||
|
|
@ -650,8 +648,6 @@ class TritonBenchmarkRequest(BenchmarkRequest):
|
|||
self.module_cache_key = module_cache_key
|
||||
self.num_stages = num_stages
|
||||
self.num_warps = num_warps
|
||||
self.num_consumer_groups = num_consumer_groups
|
||||
self.num_buffers_warp_spec = num_buffers_warp_spec
|
||||
self.matrix_instr_nonkdim = matrix_instr_nonkdim
|
||||
self.waves_per_eu = waves_per_eu
|
||||
self.kpack = kpack
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from ..remote_cache import (
|
|||
RemoteCacheJsonSerde,
|
||||
)
|
||||
from .triton_compat import Config
|
||||
from .. import config as inductor_config
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..remote_cache import Sample
|
||||
|
|
@ -207,11 +207,6 @@ class AutotuneCache:
|
|||
"found_by_coordesc": found_by_coordesc,
|
||||
"time_taken_ms": time_taken_ns // 1000000, # Convert from NS to MS
|
||||
}
|
||||
if inductor_config.is_fbcode():
|
||||
data.update({
|
||||
"num_consumer_groups": getattr(config, 'num_consumer_groups', 0),
|
||||
"num_buffers_warp_spec": getattr(config, 'num_buffers_warp_spec', 0),
|
||||
})
|
||||
|
||||
if local_cache := self.local_cache:
|
||||
cache, key = local_cache
|
||||
|
|
@ -469,22 +464,7 @@ def _load_cached_autotuning(
|
|||
):
|
||||
num_warps = best_config.pop("num_warps")
|
||||
num_stages = best_config.pop("num_stages")
|
||||
|
||||
# Extract common arguments
|
||||
config_args = {
|
||||
"num_warps": num_warps,
|
||||
"num_stages": num_stages,
|
||||
}
|
||||
|
||||
# Conditionally add arguments based on inductor_config.is_fbcode()
|
||||
if inductor_config.is_fbcode():
|
||||
config_args.update({
|
||||
"num_consumer_groups": best_config.pop("num_consumer_groups", 0),
|
||||
"num_buffers_warp_spec": best_config.pop("num_buffers_warp_spec", 0),
|
||||
})
|
||||
|
||||
# Create the triton_config with the appropriate arguments
|
||||
triton_config = Config(best_config, **config_args)
|
||||
triton_config = Config(best_config, num_warps=num_warps, num_stages=num_stages)
|
||||
triton_config.found_by_coordesc = True
|
||||
return triton_config
|
||||
|
||||
|
|
|
|||
|
|
@ -76,7 +76,6 @@ from .triton_compat import (
|
|||
PTXASError,
|
||||
triton,
|
||||
)
|
||||
from .. import config as inductor_config
|
||||
|
||||
|
||||
class NoTritonConfigsError(RuntimeError):
|
||||
|
|
@ -170,9 +169,6 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
|
|||
call_kwargs.update(launcher.config.kwargs)
|
||||
call_kwargs["num_warps"] = launcher.config.num_warps
|
||||
call_kwargs["num_stages"] = launcher.config.num_stages
|
||||
if inductor_config.is_fbcode():
|
||||
call_kwargs["num_consumer_groups"] = getattr(launcher.config, 'num_consumer_groups', 0)
|
||||
call_kwargs["num_buffers_warp_spec"] = getattr(launcher.config, 'num_buffers_warp_spec', 0)
|
||||
args_str = [*call_args]
|
||||
args_str.extend(f"{k}={v}" for k, v in call_kwargs.items())
|
||||
args_str = ", ".join(args_str)
|
||||
|
|
@ -513,9 +509,6 @@ class CachingAutotuner(KernelInterface):
|
|||
compile_meta["constants"][arg_name] = getattr(cfg, arg_name)
|
||||
compile_meta["num_warps"] = cfg.num_warps
|
||||
compile_meta["num_stages"] = cfg.num_stages
|
||||
if inductor_config.is_fbcode():
|
||||
compile_meta["num_consumer_groups"] = getattr(cfg, 'num_consumer_groups', 0)
|
||||
compile_meta["num_buffers_warp_spec"] = getattr(cfg, 'num_buffers_warp_spec', 0)
|
||||
compile_meta["debug"] = self.inductor_meta.get(
|
||||
"assert_indirect_indexing", True
|
||||
) and not self.inductor_meta.get("is_hip", False)
|
||||
|
|
@ -562,11 +555,6 @@ class CachingAutotuner(KernelInterface):
|
|||
"debug": compile_meta["debug"],
|
||||
"sanitize_overflow": False, # turn off additional asserts added for overflow checks
|
||||
}
|
||||
if inductor_config.is_fbcode():
|
||||
options.update({
|
||||
"num_consumer_groups": compile_meta["num_consumer_groups"],
|
||||
"num_buffers_warp_spec": compile_meta["num_buffers_warp_spec"],
|
||||
})
|
||||
if self.device_props.type == "hip":
|
||||
if "waves_per_eu" in compile_meta:
|
||||
options["waves_per_eu"] = compile_meta["waves_per_eu"]
|
||||
|
|
@ -2338,28 +2326,13 @@ def split_scan(
|
|||
)
|
||||
|
||||
|
||||
def template(num_stages, num_warps,
|
||||
triton_meta,
|
||||
num_consumer_groups = 0,
|
||||
num_buffers_warp_spec = 0, filename=None, inductor_meta=None):
|
||||
def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=None):
|
||||
"""
|
||||
Compile a triton template
|
||||
"""
|
||||
# Prepare the base configuration
|
||||
config_args = {
|
||||
"num_stages": num_stages,
|
||||
"num_warps": num_warps,
|
||||
}
|
||||
|
||||
# Conditionally add arguments based on inductor_config.is_fbcode()
|
||||
if inductor_config.is_fbcode():
|
||||
config_args.update({
|
||||
"num_consumer_groups": num_consumer_groups,
|
||||
"num_buffers_warp_spec": num_buffers_warp_spec,
|
||||
})
|
||||
return cached_autotune(
|
||||
None,
|
||||
[triton.Config({}, **config_args)],
|
||||
[triton.Config({}, num_stages=num_stages, num_warps=num_warps)],
|
||||
triton_meta=triton_meta,
|
||||
inductor_meta=inductor_meta,
|
||||
heuristic_type=HeuristicType.TEMPLATE,
|
||||
|
|
@ -2370,7 +2343,7 @@ def template(num_stages, num_warps,
|
|||
def _pop_config_kwargs(config: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract triton.Config options that should become kwargs"""
|
||||
popped = {}
|
||||
for key in ("num_warps", "num_stages", "num_ctas", "maxnreg", "num_consumer_groups", "num_buffers_warp_spec"):
|
||||
for key in ("num_warps", "num_stages", "num_ctas", "maxnreg"):
|
||||
val = config.pop(key, None)
|
||||
if val is not None:
|
||||
popped[key] = val
|
||||
|
|
@ -2378,18 +2351,11 @@ def _pop_config_kwargs(config: dict[str, Any]) -> dict[str, Any]:
|
|||
|
||||
|
||||
def config_to_dict(config: Config) -> dict[str, Any]:
|
||||
config_dict = {
|
||||
return {
|
||||
**config.kwargs,
|
||||
"num_warps": config.num_warps,
|
||||
"num_stages": config.num_stages,
|
||||
|
||||
}
|
||||
if inductor_config.is_fbcode():
|
||||
config_dict.update({
|
||||
"num_consumer_groups": getattr(config, 'num_consumer_groups', 0),
|
||||
"num_buffers_warp_spec": getattr(config, 'num_buffers_warp_spec', 0),
|
||||
})
|
||||
return config_dict
|
||||
|
||||
|
||||
def config_from_dict(config: dict[str, Any]) -> Config:
|
||||
|
|
|
|||
|
|
@ -293,8 +293,6 @@ class TritonTemplateKernel(TritonKernel):
|
|||
grid_fn,
|
||||
meta,
|
||||
call_sizes,
|
||||
num_consumer_groups=0,
|
||||
num_buffers_warp_spec=0,
|
||||
use_jit=False,
|
||||
prefix_args=0,
|
||||
suffix_args=0,
|
||||
|
|
@ -318,8 +316,6 @@ class TritonTemplateKernel(TritonKernel):
|
|||
self.use_jit = use_jit
|
||||
self.num_stages = num_stages
|
||||
self.num_warps = num_warps
|
||||
self.num_consumer_groups = num_consumer_groups
|
||||
self.num_buffers_warp_spec = num_buffers_warp_spec
|
||||
self.grid_fn = grid_fn
|
||||
self.meta = meta
|
||||
self.call_sizes = call_sizes
|
||||
|
|
@ -461,24 +457,12 @@ class TritonTemplateKernel(TritonKernel):
|
|||
if config.profile_bandwidth or config.benchmark_kernel:
|
||||
num_gb = self.estimate_kernel_num_bytes() / 1e9
|
||||
inductor_meta["kernel_num_gb"] = num_gb
|
||||
|
||||
template_args = f"""
|
||||
num_stages={self.num_stages},
|
||||
num_warps={self.num_warps},
|
||||
triton_meta={triton_meta!r},
|
||||
inductor_meta={inductor_meta!r},
|
||||
"""
|
||||
|
||||
# Conditionally add arguments based on iis_fbcode()
|
||||
if config.is_fbcode():
|
||||
template_args += f"""
|
||||
num_consumer_groups={self.num_consumer_groups},
|
||||
num_buffers_warp_spec={self.num_buffers_warp_spec},
|
||||
"""
|
||||
|
||||
return f"""
|
||||
@triton_heuristics.template(
|
||||
{template_args}
|
||||
num_stages={self.num_stages},
|
||||
num_warps={self.num_warps},
|
||||
triton_meta={triton_meta!r},
|
||||
inductor_meta={inductor_meta!r},
|
||||
)
|
||||
@triton.jit
|
||||
"""
|
||||
|
|
@ -1086,8 +1070,6 @@ class TritonTemplate(KernelTemplate):
|
|||
layout,
|
||||
num_stages,
|
||||
num_warps,
|
||||
num_consumer_groups=0,
|
||||
num_buffers_warp_spec=0,
|
||||
prefix_args=0,
|
||||
suffix_args=0,
|
||||
epilogue_fn=identity,
|
||||
|
|
@ -1153,11 +1135,6 @@ class TritonTemplate(KernelTemplate):
|
|||
"epilogue_fn": epilogue_fn,
|
||||
"subgraphs": subgraphs,
|
||||
}
|
||||
if config.is_fbcode():
|
||||
kernel_options.update({
|
||||
"num_consumer_groups": num_consumer_groups,
|
||||
"num_buffers_warp_spec": num_buffers_warp_spec,
|
||||
})
|
||||
|
||||
with (
|
||||
patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_out)),
|
||||
|
|
@ -1179,23 +1156,19 @@ class TritonTemplate(KernelTemplate):
|
|||
return None
|
||||
if self.debug:
|
||||
print("Generated Code:\n", code)
|
||||
extra_parts = [
|
||||
f"{kwarg}={repr(kwargs[kwarg])}" for kwarg in sorted(kwargs.keys())
|
||||
]
|
||||
|
||||
extra_parts.extend([
|
||||
f"num_stages={num_stages}",
|
||||
f"num_warps={num_warps}",
|
||||
])
|
||||
|
||||
# Conditionally add arguments based on inductor_config.is_fbcode()
|
||||
if config.is_fbcode():
|
||||
extra_parts.extend([
|
||||
f"num_consumer_groups={num_consumer_groups}",
|
||||
f"num_buffers_warp_spec={num_buffers_warp_spec}",
|
||||
])
|
||||
|
||||
extra = "-".join(extra_parts) + "-"
|
||||
extra = (
|
||||
"-".join(
|
||||
[
|
||||
*[
|
||||
f"{kwarg}={repr(kwargs[kwarg])}"
|
||||
for kwarg in sorted(kwargs.keys())
|
||||
],
|
||||
f"num_stages={num_stages}",
|
||||
f"num_warps={num_warps}",
|
||||
]
|
||||
)
|
||||
+ "-"
|
||||
)
|
||||
mod = PyCodeCache.load(code, extra)
|
||||
|
||||
input_call_args = tuple(kernel.args.input_buffers.keys())
|
||||
|
|
@ -1251,8 +1224,6 @@ class TritonTemplate(KernelTemplate):
|
|||
extra_args=[*extra_args, *grid],
|
||||
num_stages=num_stages,
|
||||
num_warps=num_warps,
|
||||
num_consumer_groups=num_consumer_groups,
|
||||
num_buffers_warp_spec=num_buffers_warp_spec,
|
||||
matrix_instr_nonkdim=kwargs.get("matrix_instr_nonkdim", 0),
|
||||
waves_per_eu=kwargs.get("waves_per_eu", 0),
|
||||
kpack=kwargs.get("kpack", 2),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user