[Inductor][CPP] Enable Grouped GEMM Template (#143796)

**Summary**
Enable the CPP Grouped GEMM fusion, lowering, and Grouped GEMM template, following the RFC: https://github.com/pytorch/pytorch/issues/144012

- Supports a flexible number of GEMMs
- Shares the activation across GEMMs
  - The Grouped GEMM template itself supports independent activations; however, the pattern matcher requires an anchor node, which is the activation shared across GEMMs
- Each GEMM can have a unique weight, but all weights must have the same size
- Each GEMM can have a unique bias or no bias
  - The current PR does not yet support biases; this will be addressed in a follow-up epilogue-fusion PR
- Each GEMM has its own epilogues
  - Epilogue fusion is not yet supported in this PR and will be enabled in an upcoming follow-up epilogue-fusion PR

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear_invalid
python -u -m pytest -s -v test/inductor/test_cpu_cpp_wrapper.py -k test_grouped_linear
```

**Example**
Here is the example and the generated code:
```
import torch

batch_size = 4
in_features = 512
out_features = 1024
dtype = torch.bfloat16
bias = False


class M(torch.nn.Module):
    def __init__(self, bias):
        super().__init__()
        self.linear0 = torch.nn.Linear(in_features, out_features, bias=False)
        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)

    def forward(self, x):
        return self.linear0(x), self.linear1(x)


if __name__ == "__main__":
    with torch.no_grad():
        input = torch.randn(batch_size, in_features, dtype=dtype)
        m = M(bias=bias).to(dtype=dtype).eval()
        cm = torch.compile(m)
        act_res = cm(input)
```
Generated code: https://gist.github.com/leslie-fang-intel/ed2e8d23aeb3586eb504feeace692e16#file-grouped-gemm-generated-code-py

**Next Step**
- Support epilogue fusion

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143796
Approved by: https://github.com/jgong5, https://github.com/jansel
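For readers who want to try the feature, below is a minimal, self-contained sketch of how the new template gets switched on, based on the config option and post-grad pass added in this PR (`cpp.enable_grouped_gemm_template`, CPP max-autotune, oneDNN available; the tests additionally enable freezing). The exact flag combination is illustrative rather than an official recipe.

```
import torch
import torch._inductor.config as inductor_config

# Flags taken from this PR's config and post-grad changes; the grouped GEMM
# path is only considered when the CPP backend participates in max-autotune.
inductor_config.freezing = True
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "CPP,ATEN"
inductor_config.cpp.enable_grouped_gemm_template = True


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Two GEMMs sharing one activation, same weight sizes, no bias:
        # the shape the grouped GEMM fusion pass looks for.
        self.linear0 = torch.nn.Linear(512, 1024, bias=False)
        self.linear1 = torch.nn.Linear(512, 1024, bias=False)

    def forward(self, x):
        return self.linear0(x), self.linear1(x)


with torch.no_grad():
    m = M().to(dtype=torch.bfloat16).eval()
    x = torch.randn(4, 512, dtype=torch.bfloat16)
    out0, out1 = torch.compile(m)(x)
```

If the module does not meet the grouping conditions, each linear falls back to its own GEMM template, which is exactly what `test_grouped_linear_invalid` in the diff below asserts.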
Parent: 35b46a75f1
Commit: 25de671ea8
@@ -207,7 +207,12 @@ if RUN_CPU:
         *[
             BaseTest(func, "", test_cpu_select_algorithm.TestSelectAlgorithmCPU())
             for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU())
-            if func.startswith("test_linear_with_pointwise")
+            if func.startswith(
+                (
+                    "test_linear_with_pointwise",
+                    "test_grouped_linear",
+                )
+            )
         ],
         BaseTest("test_polar"),
         BaseTest(
@@ -1683,6 +1683,105 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
             self.assertEqual(actual, expected, atol=atol, rtol=rtol)
         self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
 
+    @inductor_config.patch({"freezing": True})
+    @inductor_config.patch({"cpp.enable_grouped_gemm_template": True})
+    @patches
+    @torch.no_grad
+    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
+    @parametrize("batch_size", (16,))
+    @parametrize("in_features", (52,))
+    @parametrize("out_features", (32,))
+    @parametrize("gemm_num", (2, 3))
+    def test_grouped_linear_invalid(
+        self,
+        batch_size,
+        in_features,
+        out_features,
+        gemm_num,
+    ):
+        class M(torch.nn.Module):
+            def __init__(self, in_feature, out_feature, gemm_num):
+                super().__init__()
+                self.linears = [
+                    torch.nn.Linear(in_feature, out_feature + gemm_idx, bias=False)
+                    for gemm_idx in range(gemm_num)
+                ]
+
+            def forward(self, x):
+                return [linear(x) for linear in self.linears]
+
+        # each linear has a different number of out features, thus an invalid grouped gemm
+        dtypes = []
+        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            dtypes.append(torch.bfloat16)
+        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
+            dtypes.append(torch.float16)
+        for dtype in dtypes:
+            torch._dynamo.reset()
+            torch._inductor.metrics.reset()
+            counters.clear()
+            mod = M(in_features, out_features, gemm_num).eval()
+            v = torch.randn(batch_size, in_features).to(dtype)
+            with verify(dtype) as (atol, rtol), torch.autocast(
+                device_type="cpu", dtype=dtype
+            ), torch.no_grad():
+                self.common(mod, (v,), atol=atol, rtol=rtol)
+                # gemm_num independent templates instead of a grouped gemm template
+                self.assertEqual(
+                    counters["inductor"]["select_algorithm_autotune"], gemm_num
+                )
+                self.assertEqual(counters["inductor"]["cpp_grouped_gemm_template"], 0)
+
+    @inductor_config.patch({"freezing": True})
+    @inductor_config.patch({"cpp.enable_grouped_gemm_template": True})
+    @patches
+    @torch.no_grad
+    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
+    @parametrize("batch_size", (16,))
+    @parametrize("in_features", (52,))
+    @parametrize("out_features", (32,))
+    @parametrize("input_3d", (False, True))
+    @parametrize("gemm_num", (2, 3))
+    def test_grouped_linear(
+        self,
+        batch_size,
+        in_features,
+        out_features,
+        input_3d,
+        gemm_num,
+    ):
+        class M(torch.nn.Module):
+            def __init__(self, in_feature, out_feature, gemm_num):
+                super().__init__()
+                self.linears = [
+                    torch.nn.Linear(in_feature, out_feature, bias=False)
+                    for _ in range(gemm_num)
+                ]
+
+            def forward(self, x):
+                return [linear(x) for linear in self.linears]
+
+        dtypes = []
+        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            dtypes.append(torch.bfloat16)
+        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
+            dtypes.append(torch.float16)
+        for dtype in dtypes:
+            if dtype == torch.float16 and input_3d:
+                # reduce the number of tests
+                continue
+            torch._dynamo.reset()
+            torch._inductor.metrics.reset()
+            counters.clear()
+            mod = M(in_features, out_features, gemm_num).eval()
+            B = (2, batch_size) if input_3d else (batch_size,)
+            v = torch.randn(*B, in_features).to(dtype)
+            with verify(dtype) as (atol, rtol), torch.autocast(
+                device_type="cpu", dtype=dtype
+            ), torch.no_grad():
+                self.common(mod, (v,), atol=atol, rtol=rtol)
+                self.assertEqual(counters["inductor"]["cpp_grouped_gemm_template"], 1)
+
     @inductor_config.patch({"freezing": False})
     @patches
     @torch.no_grad
@@ -2031,6 +2130,7 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase):
     test_quantized_linear_amx_dynamic_shapes = (
         TestSelectAlgorithm.test_quantized_linear_amx
     )
+    test_grouped_linear_dynamic_shapes = TestSelectAlgorithm.test_grouped_linear
     test_linear_k_slicing_dynamic_shapes = TestSelectAlgorithm.test_linear_k_slicing
     test_linear_cache_blocking_dynamic_shapes = (
         TestSelectAlgorithm.test_linear_cache_blocking
@@ -500,7 +500,13 @@ class BenchmarkRequest:
         self.input_tensor_meta = input_tensor_meta
 
         if isinstance(output_tensor_meta, (tuple, list)):
-            assert len(output_tensor_meta) == 1
+            if len(output_tensor_meta) > 1:
+                # For Grouped GEMM, every output shares the same meta
+                assert all(
+                    getattr(output_tensor_meta[0], attr) == getattr(x, attr)
+                    for x in output_tensor_meta
+                    for attr in ["device", "dtype", "sizes", "strides", "offset"]
+                )
             output_tensor_meta = output_tensor_meta[0]
         self.output_tensor_meta = output_tensor_meta
 
@@ -26,6 +26,7 @@ from ..loop_body import LoopBody
 from ..scheduler import (
     BaseSchedulerNode,
     BaseScheduling,
+    ExternKernelSchedulerNode,
     ForeachKernelSchedulerNode,
     FusedSchedulerNode,
     Scheduler,
@@ -4905,7 +4906,24 @@ class CppScheduling(BaseScheduling):
                 epilogue_nodes=epilogue_ir_nodes,
             )
             with kernel:
-                for node in [template_node, *epilogue_nodes]:
+                if isinstance(template_node.node, ir.CppTemplateBuffer) and isinstance(
+                    template_node.node.layout, ir.MultiOutputLayout
+                ):
+                    # For Grouped GEMM, allocate buffers for each GEMM
+                    assert (
+                        len(template_node.outputs) == 1
+                    ), "Grouped GEMM has 1 output template buffer"
+                    for user in template_node.outputs[0].users:
+                        assert isinstance(
+                            user.node, ExternKernelSchedulerNode
+                        ), "Grouped GEMM should be with ExternKernelSchedulerNode"
+                        assert isinstance(
+                            user.node.node, ir.MultiOutput
+                        ), "Grouped GEMM has multi users with MultiOutput"
+                        user.node.mark_run()
+                else:
+                    template_node.mark_run()  # type: ignore[attr-defined]
+                for node in epilogue_nodes:
                     node.mark_run()  # type: ignore[attr-defined]
                 src_code = render()
@@ -3,7 +3,7 @@ import contextlib
 import logging
 import math
 from functools import lru_cache
-from typing import Any, Callable, cast, Dict, List, Optional, Union
+from typing import Any, Callable, cast, Dict, List, Optional, TypeVar, Union
 from unittest.mock import patch
 
 import torch
@@ -299,9 +299,10 @@ def get_padded_n(n, block_n):
     return (n + block_n - 1) // block_n * block_n
 
 
-def transpose_w(
-    W: Union[ir.IRNode, torch.Tensor], trans_w: bool
-) -> Union[ir.IRNode, torch.Tensor]:
+_T = TypeVar("_T", ir.IRNode, torch.Tensor)
+
+
+def transpose_w(W: _T, trans_w: bool) -> _T:
     """
     Transpose W based on the trans_w flag.
     """
@@ -317,9 +318,7 @@ def transpose_w(
     return W
 
 
-def expand_bias(
-    B: Union[ir.IRNode, torch.Tensor, None], X: Union[ir.IRNode, torch.Tensor]
-) -> Optional[Union[ir.IRNode, torch.Tensor]]:
+def expand_bias(B: Optional[_T], X: _T) -> Optional[_T]:
     """
     Expand Bias to the same size of X.
     """
@@ -336,7 +335,7 @@ def expand_bias(
     return B
 
 
-def prune_tensors(input_nodes: List[ir.TensorBox], new_input_nodes: List[ir.TensorBox]):
+def prune_tensors(input_nodes: List[ir.IRNode], new_input_nodes: List[ir.IRNode]):
     """
     Prune unused tensors from `V.graph` since the GEMM Template use new packed weight.
     """
@@ -798,6 +797,7 @@ class CppGemmTemplate(CppTemplate):
         trans_w=False,
         input_indices=None,
         epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
+        act_mapping: Optional[dict[int, ir.IRNode]] = None,
     ):
         if input_indices is None:
             input_indices = list(range(len(input_nodes)))
@@ -1251,6 +1251,7 @@ class CppGemmTemplate(CppTemplate):
         # --> zero or more out-of-template epilogues (`epilogue_nodes`) -->
         # Y
         if epilogue_creators:
+            assert isinstance(template_buffer, ir.IRNode)
             gemm_output_name = f"{template_buffer.get_name()}_GemmOut"
             gemm_output_buffer = ir.Buffer(
                 name=gemm_output_name, layout=template_buffer.layout
@@ -1276,14 +1277,17 @@ class CppGemmTemplate(CppTemplate):
                     name=buffer_name, layout=template_buffer.layout
                 )
 
+        assert isinstance(Y, (ir.Buffer, ir.ReinterpretView))
         Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y
 
         if epilogue_nodes:
             if not template_buffer_has_other_users:
+                assert isinstance(template_buffer, ir.IRNode)
                 Y_aliases.add(template_buffer.get_name())
             epilogues.extend(epilogue_nodes)
             assert Y.get_numel() == epilogues[-1].get_numel()
             Y = cast(ir.Buffer, epilogues[-1])
+            assert isinstance(template_buffer, ir.Buffer)
             Y_2d, reindexers = gen_2d_view_of_epilogue_buf(
                 Y,
                 template_buffer,
torch/_inductor/codegen/cpp_grouped_gemm_template.py (new file, 463 lines)
@@ -0,0 +1,463 @@
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, List, Optional, TypeVar
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils
|
||||||
|
from torch.utils._ordered_set import OrderedSet
|
||||||
|
|
||||||
|
from ..._dynamo.utils import counters
|
||||||
|
from .. import config, ir
|
||||||
|
from ..kernel.mm_common import mm_args
|
||||||
|
from ..select_algorithm import ChoiceCaller, DataProcessorTemplateWrapper
|
||||||
|
from ..utils import parallel_num_threads
|
||||||
|
from ..virtualized import V
|
||||||
|
from .cpp import get_export_declaration
|
||||||
|
from .cpp_gemm_template import CppGemmTemplate, expand_bias, prune_tensors, transpose_w
|
||||||
|
from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm
|
||||||
|
from .cpp_template_kernel import CppTemplateKernel
|
||||||
|
from .cpp_utils import (
|
||||||
|
DTYPE_TO_CPP,
|
||||||
|
GemmBlocking,
|
||||||
|
get_gemm_template_output_and_compute_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
GEMM_TEMPLATE = r"""
|
||||||
|
{{template.header().getvalue()}}
|
||||||
|
{{micro_gemm.codegen_define(kernel)}}
|
||||||
|
|
||||||
|
extern "C" {{export_declaration}}
|
||||||
|
{{kernel.def_kernel(inputs=kernel_args, outputs=Y_list, aliases=aliases)}}
|
||||||
|
{
|
||||||
|
{{kernel.maybe_codegen_profile()}}
|
||||||
|
{{ template.codegen_blocks(
|
||||||
|
num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOuts[0], config, L1_cache_size, L2_cache_size, X_list[0], W_list[0]
|
||||||
|
) }}
|
||||||
|
{%- if num_threads > 1 %}
|
||||||
|
#pragma omp parallel num_threads({{num_threads}})
|
||||||
|
{
|
||||||
|
{{ template.codegen_multi_threads_params()|indent(8, false) }}
|
||||||
|
{%- else %}
|
||||||
|
{
|
||||||
|
{{ template.codegen_single_thread_params(is_dynamic_M)|indent(8, false) }}
|
||||||
|
{%- endif %}
|
||||||
|
{{ micro_gemm.codegen_init(kernel) }}
|
||||||
|
{%- set acc_buf_name_list=[] %}
|
||||||
|
{%- set acc_buf_name_prefix = "local_acc_buf_" %}
|
||||||
|
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
|
||||||
|
{%- set acc_buf_name = acc_buf_name_prefix + gemm_idx|string %}
|
||||||
|
{{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }}
|
||||||
|
{%- set acc_buf_name_list=acc_buf_name_list.append(acc_buf_name) %}
|
||||||
|
{%- endfor %}
|
||||||
|
for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
|
||||||
|
{{ template.codegen_m_loop_params()|indent(12, false) }}
|
||||||
|
for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
|
||||||
|
{{ template.codegen_n_loop_params()|indent(16, false) }}
|
||||||
|
{%- set acc_list=[] %}
|
||||||
|
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
|
||||||
|
{%- set acc_list = acc_list.append( kernel.local_buffers[acc_buf_name_list[gemm_idx]] ) %}
|
||||||
|
{{ kernel.reinit_buffer_if_null(acc_buf_name_list[gemm_idx]) }}
|
||||||
|
{%- endfor %}
|
||||||
|
for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
|
||||||
|
int64_t k_start = kc * Kr;
|
||||||
|
int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
|
||||||
|
{%- set tile_X_list=[] %}
|
||||||
|
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
|
||||||
|
{%- set tile_X_list = tile_X_list.append( kernel.slice_nd(X_list[gemm_idx], [("m_start", "m_end"), ("k_start", "k_end")]) ) %}
|
||||||
|
{%- endfor %}
|
||||||
|
for (int64_t nci = nc; nci < nc_block_end; nci++) {
|
||||||
|
{%- set tile_W_3d_list=[] %}
|
||||||
|
{%- set tile_W_list=[] %}
|
||||||
|
{%- set acc_slice_list=[] %}
|
||||||
|
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
|
||||||
|
{%- set acc_slice_list = acc_slice_list.append(
|
||||||
|
kernel.slice_nd(acc_list[gemm_idx], [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")])
|
||||||
|
) %}
|
||||||
|
{%- set tile_W_3d_list = tile_W_3d_list.append(
|
||||||
|
kernel.slice_nd(W_list[gemm_idx], [("nci", "nci + 1"), ("k_start", "k_end"), ()])
|
||||||
|
) %}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
|
||||||
|
{%- set tile_W_list = tile_W_list.append(
|
||||||
|
kernel.view(tile_W_3d_list[gemm_idx], ["k_end - k_start", micro_gemm.register_blocking.block_n])
|
||||||
|
) %}
|
||||||
|
{%- endfor %}
|
||||||
|
if (kc == k_block_start) {
|
||||||
|
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
|
||||||
|
{{ micro_gemm.codegen_call(
|
||||||
|
kernel, tile_X_list[gemm_idx], tile_W_list[gemm_idx], acc_slice_list[gemm_idx], accum=False
|
||||||
|
)|indent(28, false) }}
|
||||||
|
{%- endfor %}
|
||||||
|
} else {
|
||||||
|
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
|
||||||
|
{{ micro_gemm.codegen_call(
|
||||||
|
kernel, tile_X_list[gemm_idx], tile_W_list[gemm_idx], acc_slice_list[gemm_idx], accum=True
|
||||||
|
)|indent(28, false) }}
|
||||||
|
{%- endfor %}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
{%- set tile_acc_list = [] %}
|
||||||
|
{%- set tile_Y_list = [] %}
|
||||||
|
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
|
||||||
|
{%- set tile_acc_list = tile_acc_list.append(
|
||||||
|
kernel.slice_nd(acc_list[gemm_idx], [("0", "m_end - m_start"), ("0", "n_end - n_start")])
|
||||||
|
) %}
|
||||||
|
{%- set tile_Y_list = tile_Y_list.append(
|
||||||
|
kernel.slice_nd(Y_2d_list[gemm_idx], [("m_start", "m_end"), ("n_start", "n_end")])
|
||||||
|
) %}
|
||||||
|
{%- endfor %}
|
||||||
|
{{ kernel.store_outputs(
|
||||||
|
tile_Y_list, tile_acc_list, GemmOuts, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
|
||||||
|
)|indent(20, false)
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{{ micro_gemm.codegen_finalize(kernel) }}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_deduplicated_act(act_mapping: dict[int, ir.IRNode]) -> List[ir.IRNode]:
|
||||||
|
act_deduplicated = []
|
||||||
|
act_deduplicated_name: OrderedSet[str] = OrderedSet()
|
||||||
|
for act_idx in range(len(act_mapping.values())):
|
||||||
|
act = act_mapping[act_idx]
|
||||||
|
if act.get_name() not in act_deduplicated_name:
|
||||||
|
act_deduplicated.append(act)
|
||||||
|
act_deduplicated_name.add(act.get_name())
|
||||||
|
return act_deduplicated
|
||||||
|
|
||||||
|
|
||||||
|
class CppGroupedGemmTemplate(CppGemmTemplate):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_nodes: List[ir.IRNode],
|
||||||
|
layout: ir.Layout,
|
||||||
|
num_threads: int,
|
||||||
|
register_blocking: GemmBlocking,
|
||||||
|
beta: int = 1,
|
||||||
|
alpha: int = 1,
|
||||||
|
has_bias: bool = False,
|
||||||
|
epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
|
||||||
|
act_mapping: Optional[dict[int, ir.IRNode]] = None,
|
||||||
|
gemm_grouped_num: int = 1,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Template for Group of GEMMs:
|
||||||
|
* Each GEMM has the same dimensions (m, n, k) and the same leading dimensions (lda, ldb, ldc)
|
||||||
|
for their A, B, and C matrices.
|
||||||
|
* Each GEMM has distinct or shared activations, has distinct weight, has unique bias or no bias, has distinct epilogues.
|
||||||
|
* In the current implementation, the outputs of all GEMMs are accumulated using pointwise epilogues.
|
||||||
|
This behavior can be extended in the future if needed.
|
||||||
|
"""
|
||||||
|
super().__init__(
|
||||||
|
input_nodes,
|
||||||
|
layout,
|
||||||
|
num_threads,
|
||||||
|
register_blocking,
|
||||||
|
beta,
|
||||||
|
alpha,
|
||||||
|
has_bias,
|
||||||
|
epilogue_creator,
|
||||||
|
)
|
||||||
|
self.act_mapping = act_mapping
|
||||||
|
self.gemm_grouped_num = gemm_grouped_num
|
||||||
|
self.output_node: List[ir.Buffer] = [
|
||||||
|
ir.Buffer(name="buf_out" + str(idx), layout=layout)
|
||||||
|
for idx in range(gemm_grouped_num)
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fake_get_dtype(fake_outs: List[ir.Buffer]) -> Callable[[str], torch.dtype]:
|
||||||
|
_get_dtype_real = V.graph.get_dtype
|
||||||
|
|
||||||
|
def get_dtype(name: str) -> torch.dtype:
|
||||||
|
for fake_out in fake_outs:
|
||||||
|
if name == fake_out.get_name():
|
||||||
|
return fake_out.get_dtype()
|
||||||
|
return _get_dtype_real(name)
|
||||||
|
|
||||||
|
return get_dtype
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_choices(
|
||||||
|
cls,
|
||||||
|
choices: List[ChoiceCaller],
|
||||||
|
layout: ir.Layout,
|
||||||
|
input_nodes: List[ir.IRNode],
|
||||||
|
beta: int = 1,
|
||||||
|
alpha: int = 1,
|
||||||
|
has_bias: tuple[bool, ...] = (False, False),
|
||||||
|
trans_w: bool = False,
|
||||||
|
input_indices: Optional[List[int]] = None,
|
||||||
|
epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
|
||||||
|
act_mapping: Optional[dict[int, ir.IRNode]] = None, # gemm idx to its act buf
|
||||||
|
) -> DataProcessorTemplateWrapper:
|
||||||
|
# Input nodes order: x, optional[x1], ... w0, w1, ... optional[b0], optional[b1], ...
|
||||||
|
gemm_grouped_num = len(has_bias)
|
||||||
|
assert act_mapping
|
||||||
|
act_deduplicated = get_deduplicated_act(act_mapping)
|
||||||
|
wgt_start_idx = len(act_deduplicated)
|
||||||
|
bias_start_idx = wgt_start_idx + gemm_grouped_num
|
||||||
|
input_indices = list(range(len(input_nodes)))
|
||||||
|
|
||||||
|
_T = TypeVar("_T", ir.IRNode, torch.Tensor)
|
||||||
|
_U = TypeVar("_U", ir.Layout, torch.Tensor)
|
||||||
|
|
||||||
|
def reorder_and_filter(
|
||||||
|
inputs: List[_T],
|
||||||
|
layout_or_out: _U,
|
||||||
|
) -> tuple[List[_T], _U]:
|
||||||
|
assert input_indices is not None, "input_indices must be set"
|
||||||
|
return [inputs[idx] for idx in input_indices], layout_or_out
|
||||||
|
|
||||||
|
new_inputs, new_layout = reorder_and_filter(input_nodes, layout)
|
||||||
|
|
||||||
|
def maybe_to_dense(
|
||||||
|
inputs: List[_T],
|
||||||
|
layout_or_out: _U,
|
||||||
|
) -> tuple[List[_T], _U]:
|
||||||
|
new_inputs = list(inputs)
|
||||||
|
for idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
|
||||||
|
if isinstance(inputs[idx], torch.Tensor):
|
||||||
|
W = inputs[idx]
|
||||||
|
assert isinstance(W, torch.Tensor), "W must be a torch.Tensor"
|
||||||
|
new_inputs[idx] = W.to_dense() if W.is_mkldnn else W
|
||||||
|
return new_inputs, layout_or_out
|
||||||
|
|
||||||
|
def normalize_shapes(
|
||||||
|
inputs: List[_T],
|
||||||
|
layout_or_out: _U,
|
||||||
|
) -> tuple[List[_T], _U]:
|
||||||
|
new_inputs: List[_T] = list(inputs)
|
||||||
|
if not trans_w:
|
||||||
|
return new_inputs, layout_or_out
|
||||||
|
X = new_inputs[0]
|
||||||
|
for wgt_idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
|
||||||
|
new_input = new_inputs[wgt_idx]
|
||||||
|
new_inputs[wgt_idx] = transpose_w(new_input, trans_w)
|
||||||
|
for bias_idx in range(bias_start_idx, len(new_inputs)):
|
||||||
|
new_bias = expand_bias(new_inputs[bias_idx], X)
|
||||||
|
assert new_bias is not None
|
||||||
|
new_inputs[bias_idx] = new_bias
|
||||||
|
return new_inputs, layout_or_out
|
||||||
|
|
||||||
|
num_threads = parallel_num_threads()
|
||||||
|
new_inputs, _ = normalize_shapes(*maybe_to_dense(new_inputs, new_layout))
|
||||||
|
m, n, k, *_ = mm_args(new_inputs[0], new_inputs[wgt_start_idx])
|
||||||
|
output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
|
||||||
|
new_inputs[0].get_dtype()
|
||||||
|
)
|
||||||
|
micro_gemm = create_micro_gemm(
|
||||||
|
"micro_gemm",
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
input_dtype=new_inputs[0].get_dtype(),
|
||||||
|
input2_dtype=new_inputs[wgt_start_idx].get_dtype(),
|
||||||
|
output_dtype=output_dtype,
|
||||||
|
compute_dtype=compute_dtype,
|
||||||
|
alpha=alpha,
|
||||||
|
num_threads=num_threads,
|
||||||
|
)
|
||||||
|
assert micro_gemm is not None
|
||||||
|
_, block_n, _ = micro_gemm.register_blocking
|
||||||
|
new_size, padded_n = cls.get_padded_size(
|
||||||
|
n, block_n, k, should_block_weight=True
|
||||||
|
)
|
||||||
|
padding = padded_n - n
|
||||||
|
|
||||||
|
def pack_weight(
|
||||||
|
inputs: List[_T],
|
||||||
|
layout_or_out: _U,
|
||||||
|
) -> tuple[List[_T], _U]:
|
||||||
|
new_W_list = []
|
||||||
|
new_inputs = list(inputs)
|
||||||
|
W_list = new_inputs[wgt_start_idx : wgt_start_idx + gemm_grouped_num]
|
||||||
|
for W in W_list:
|
||||||
|
blocked_w = cls.block_weight(W, new_size, padding)
|
||||||
|
new_W_list.append(cls.pack_vnni_weight(blocked_w, micro_gemm, new_size))
|
||||||
|
new_inputs[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = new_W_list
|
||||||
|
return new_inputs, layout_or_out
|
||||||
|
|
||||||
|
def preprocessor(
|
||||||
|
inputs: List[_T],
|
||||||
|
layout: _U,
|
||||||
|
) -> tuple[List[_T], _U]:
|
||||||
|
return pack_weight(
|
||||||
|
*normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout)))
|
||||||
|
)
|
||||||
|
|
||||||
|
def postprocessor(output: _T) -> _T:
|
||||||
|
if isinstance(output, ir.TensorBox):
|
||||||
|
template_buffer = ir.InputsKernel.unwrap_storage_for_input(output)
|
||||||
|
assert isinstance(template_buffer, ir.CppTemplateBuffer)
|
||||||
|
new_input_nodes, _ = reorder_and_filter(input_nodes, layout)
|
||||||
|
W_nodes = new_input_nodes[
|
||||||
|
wgt_start_idx : wgt_start_idx + gemm_grouped_num
|
||||||
|
]
|
||||||
|
W_tensor = []
|
||||||
|
for W_node in W_nodes:
|
||||||
|
assert W_node.get_name() in V.graph.constants
|
||||||
|
W_tensor.append(V.graph.constants[W_node.get_name()])
|
||||||
|
new_input_nodes[
|
||||||
|
wgt_start_idx : wgt_start_idx + gemm_grouped_num
|
||||||
|
] = W_tensor # type: ignore[assignment]
|
||||||
|
new_input_nodes, _ = pack_weight(
|
||||||
|
*normalize_shapes(*maybe_to_dense(new_input_nodes, layout))
|
||||||
|
)
|
||||||
|
# Prune unused tensors
|
||||||
|
prune_tensors(input_nodes, new_input_nodes)
|
||||||
|
for idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
|
||||||
|
W_packed = new_input_nodes[idx]
|
||||||
|
assert isinstance(W_packed, torch.Tensor)
|
||||||
|
W_packed_constant = V.graph.add_tensor_constant(W_packed)
|
||||||
|
template_buffer.inputs[
|
||||||
|
idx
|
||||||
|
] = ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
|
||||||
|
return output
|
||||||
|
|
||||||
|
template = DataProcessorTemplateWrapper(
|
||||||
|
CppGroupedGemmTemplate,
|
||||||
|
preprocessor,
|
||||||
|
postprocessor,
|
||||||
|
input_nodes=input_nodes,
|
||||||
|
layout=layout,
|
||||||
|
num_threads=num_threads,
|
||||||
|
register_blocking=micro_gemm.register_blocking,
|
||||||
|
beta=beta,
|
||||||
|
alpha=alpha,
|
||||||
|
has_bias=has_bias,
|
||||||
|
epilogue_creator=epilogue_creator,
|
||||||
|
act_mapping=act_mapping,
|
||||||
|
gemm_grouped_num=gemm_grouped_num,
|
||||||
|
)
|
||||||
|
template.maybe_append_choice(choices)
|
||||||
|
return template
|
||||||
|
|
||||||
|
def render( # type: ignore[override,return,no-untyped-def]
|
||||||
|
self,
|
||||||
|
kernel: CppTemplateKernel,
|
||||||
|
template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
|
||||||
|
flag_template_buffer_has_other_users: Optional[bool] = None,
|
||||||
|
epilogue_nodes: Optional[List[ir.IRNode]] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> str:
|
||||||
|
assert self.act_mapping
|
||||||
|
act_deduplicated = get_deduplicated_act(self.act_mapping)
|
||||||
|
wgt_start_idx = len(act_deduplicated)
|
||||||
|
bias_start_idx = wgt_start_idx + self.gemm_grouped_num
|
||||||
|
X_list = list(self.act_mapping.values())
|
||||||
|
W_list = self.input_nodes[wgt_start_idx : wgt_start_idx + self.gemm_grouped_num]
|
||||||
|
inp_list = []
|
||||||
|
cur_idx = bias_start_idx
|
||||||
|
for inp_idx in range(self.gemm_grouped_num):
|
||||||
|
inp = None
|
||||||
|
if self.has_bias[inp_idx]:
|
||||||
|
inp = self.input_nodes[cur_idx]
|
||||||
|
cur_idx += 1
|
||||||
|
inp_list.append(inp)
|
||||||
|
|
||||||
|
Y_list = self.output_node
|
||||||
|
if template_buffer_node is not None:
|
||||||
|
W_list = template_buffer_node.inputs[
|
||||||
|
wgt_start_idx : wgt_start_idx + self.gemm_grouped_num
|
||||||
|
]
|
||||||
|
assert isinstance(template_buffer_node.outputs, List)
|
||||||
|
Y_list = template_buffer_node.outputs
|
||||||
|
counters["inductor"]["cpp_grouped_gemm_template"] += 1
|
||||||
|
|
||||||
|
template_buffer = Y_list[0]
|
||||||
|
fake_buffers: List[ir.Buffer] = []
|
||||||
|
Y_2d_list = Y_list
|
||||||
|
output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
|
||||||
|
X_list[0].get_dtype()
|
||||||
|
)
|
||||||
|
micro_gemm = create_micro_gemm(
|
||||||
|
f"{kernel.kernel_name}_micro_gemm",
|
||||||
|
self.m,
|
||||||
|
self.n,
|
||||||
|
self.k,
|
||||||
|
input_dtype=X_list[0].get_dtype(),
|
||||||
|
input2_dtype=W_list[0].get_dtype(),
|
||||||
|
output_dtype=output_dtype,
|
||||||
|
compute_dtype=compute_dtype,
|
||||||
|
alpha=self.alpha,
|
||||||
|
num_threads=self.num_threads,
|
||||||
|
)
|
||||||
|
assert micro_gemm is not None
|
||||||
|
assert self.register_blocking == micro_gemm.register_blocking
|
||||||
|
self.log_blockings()
|
||||||
|
if isinstance(micro_gemm, CppMicroGemmAMX):
|
||||||
|
counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1
|
||||||
|
|
||||||
|
L1_cache_size = torch._C._cpu._L1d_cache_size() # per core cache size in Bytes
|
||||||
|
assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}"
|
||||||
|
|
||||||
|
L2_cache_size = torch._C._cpu._L2_cache_size() # per core cache size in Bytes
|
||||||
|
assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}"
|
||||||
|
|
||||||
|
epilogues: List[ir.IRNode] = []
|
||||||
|
reindexers: List[Optional[Callable[[List[Any]], List[Any]]]] = []
|
||||||
|
gemm_output_buffers: list[ir.Buffer] = []
|
||||||
|
for out_buf_idx in range(self.gemm_grouped_num):
|
||||||
|
gemm_output_name = f"{template_buffer.get_name()}_GemmOut" + str(
|
||||||
|
out_buf_idx
|
||||||
|
)
|
||||||
|
gemm_output_buffers.append(
|
||||||
|
ir.Buffer(name=gemm_output_name, layout=template_buffer.layout)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
not self.epilogue_creator and not epilogue_nodes
|
||||||
|
), "Epilogue fusion is not implemented yet in Grouped GEMM Template"
|
||||||
|
|
||||||
|
kernel_args: dict[str, Optional[ir.IRNode]] = {}
|
||||||
|
for x_idx in range(wgt_start_idx):
|
||||||
|
kernel_args["X" + str(x_idx)] = act_deduplicated[x_idx]
|
||||||
|
for w_idx in range(self.gemm_grouped_num):
|
||||||
|
kernel_args["W" + str(w_idx)] = W_list[w_idx]
|
||||||
|
for inp_idx in range(self.gemm_grouped_num):
|
||||||
|
kernel_args["inp" + str(inp_idx)] = inp_list[inp_idx]
|
||||||
|
|
||||||
|
options = dict(
|
||||||
|
N=self.n,
|
||||||
|
K=self.k,
|
||||||
|
PADDED_N=self.padded_n,
|
||||||
|
aliases={},
|
||||||
|
beta=self.beta,
|
||||||
|
alpha=self.alpha,
|
||||||
|
num_threads=self.num_threads,
|
||||||
|
micro_gemm=micro_gemm,
|
||||||
|
is_dynamic_M=self.is_dynamic_M,
|
||||||
|
template=self,
|
||||||
|
kernel=kernel,
|
||||||
|
export_declaration=get_export_declaration(),
|
||||||
|
acc_buf_dtype=torch.float,
|
||||||
|
DTYPE_TO_CPP=DTYPE_TO_CPP,
|
||||||
|
L1_cache_size=L1_cache_size,
|
||||||
|
L2_cache_size=L2_cache_size,
|
||||||
|
config=config,
|
||||||
|
epilogue_nodes=epilogues,
|
||||||
|
GemmOuts=gemm_output_buffers,
|
||||||
|
reindexers=reindexers,
|
||||||
|
kernel_args=kernel_args,
|
||||||
|
X_list=X_list,
|
||||||
|
W_list=W_list,
|
||||||
|
gemm_grouped_num=self.gemm_grouped_num,
|
||||||
|
Y_list={"Y" + str(idx): Y for idx, Y in enumerate(Y_list)},
|
||||||
|
Y_2d_list=Y_2d_list,
|
||||||
|
)
|
||||||
|
with contextlib.ExitStack() as stack:
|
||||||
|
stack.enter_context(
|
||||||
|
patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_buffers))
|
||||||
|
)
|
||||||
|
return self._template_from_string(GEMM_TEMPLATE).render(**options)
|
||||||
|
|
@@ -4,7 +4,7 @@ import functools
 import itertools
 import logging
 import sys
-from typing import Callable, List, Optional
+from typing import Callable, Iterable, List, Optional, Union
 from unittest.mock import patch
 
 import sympy
@@ -33,7 +33,9 @@ class CppTemplate(KernelTemplate):
     ) -> None:
         super().__init__(name)
         self.input_nodes = input_nodes
-        self.output_node: ir.Buffer = ir.Buffer(name="buf_out", layout=layout)
+        self.output_node: Union[ir.Buffer, List[ir.Buffer]] = ir.Buffer(
+            name="buf_out", layout=layout
+        )
         self.layout = layout
         self.num_threads = num_threads
         self.epilogue_creator = epilogue_creator
@@ -57,7 +59,10 @@ class CppTemplate(KernelTemplate):
         expected_args = list(
             unique(input_node.get_name() for input_node in self.input_nodes)
         )
-        expected_args.extend([self.output_node.get_name()])
+        if isinstance(self.output_node, Iterable):
+            expected_args.extend([node.get_name() for node in self.output_node])
+        else:
+            expected_args.extend([self.output_node.get_name()])
         assert list(call_args)[: len(expected_args)] == expected_args, (
             call_args,
             expected_args,
@@ -102,7 +107,9 @@ class CppTemplate(KernelTemplate):
             kernel_hash_name,
             self.name,
             self.input_nodes,
-            self.output_node.get_layout(),
+            self.output_node[0].get_layout()
+            if isinstance(self.output_node, Iterable)
+            else self.output_node.get_layout(),
             make_kernel_render,
             bmreq,
             self,
@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import itertools
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union
 
 import sympy
 from sympy.parsing.sympy_parser import parse_expr
@ -278,6 +278,82 @@ class CppTemplateKernel(CppKernel):
|
||||||
kernel_group.finalize_kernel(cpp_kernel_proxy, [])
|
kernel_group.finalize_kernel(cpp_kernel_proxy, [])
|
||||||
return kernel_group.loops_code.getvalue()
|
return kernel_group.loops_code.getvalue()
|
||||||
|
|
||||||
|
def store_grouped_gemm_pointwise_nodes(
|
||||||
|
self,
|
||||||
|
dst: tuple[ir.Buffer],
|
||||||
|
nodes: List[List[ir.IRNode]],
|
||||||
|
offsets: Optional[List[sympy.Expr]] = None,
|
||||||
|
reindexers: Optional[List[Optional[Callable[[List[Any]], List[Any]]]]] = None,
|
||||||
|
) -> str:
|
||||||
|
ref_dst = dst[0]
|
||||||
|
var_sizes = (tuple(ref_dst.get_size()), ())
|
||||||
|
var_ranges = {
|
||||||
|
sympy_index_symbol_with_prefix(SymT.INDEX, i): sz
|
||||||
|
for i, sz in enumerate(var_sizes[0])
|
||||||
|
}
|
||||||
|
if not offsets:
|
||||||
|
offsets = [sympy.S.Zero] * len(var_sizes[0])
|
||||||
|
if not reindexers:
|
||||||
|
reindexers = [None] * len(nodes)
|
||||||
|
assert len(offsets) == len(var_sizes[0])
|
||||||
|
output_index = ref_dst.get_layout().make_indexer()([*var_ranges.keys()])
|
||||||
|
kernel_group = KernelGroup()
|
||||||
|
kernel_group.args = self.args
|
||||||
|
cpp_kernel_proxy = CppKernelProxy(kernel_group)
|
||||||
|
bodies = []
|
||||||
|
var_sizes_list = []
|
||||||
|
assert isinstance(nodes[0], Iterable)
|
||||||
|
grouped_gemm_number = len(nodes)
|
||||||
|
epilogue_nodes = nodes[0]
|
||||||
|
assert isinstance(epilogue_nodes, Iterable)
|
||||||
|
for i, _ in enumerate(epilogue_nodes):
|
||||||
|
output_names = []
|
||||||
|
gemm_nodes = []
|
||||||
|
for gemm_idx in range(grouped_gemm_number):
|
||||||
|
single_gemm_nodes = nodes[gemm_idx]
|
||||||
|
assert isinstance(dst, Iterable)
|
||||||
|
single_gemm_dst = dst[gemm_idx]
|
||||||
|
assert isinstance(single_gemm_nodes, Iterable)
|
||||||
|
assert isinstance(single_gemm_dst, ir.IRNode)
|
||||||
|
gemm_nodes.append(single_gemm_nodes[i])
|
||||||
|
output_names.append(
|
||||||
|
single_gemm_nodes[i].get_name()
|
||||||
|
if i < len(single_gemm_nodes) - 1
|
||||||
|
else single_gemm_dst.get_name()
|
||||||
|
)
|
||||||
|
_node = gemm_nodes[gemm_idx]
|
||||||
|
gemm_nodes[gemm_idx] = (
|
||||||
|
_node.data if isinstance(_node, ir.ComputedBuffer) else _node
|
||||||
|
)
|
||||||
|
|
||||||
|
def fn(*args):
|
||||||
|
assert len(args) == 2
|
||||||
|
assert len(args[0]) == len(var_sizes[0])
|
||||||
|
assert len(args[1]) == 0
|
||||||
|
new_args = [arg + offset for arg, offset in zip(args[0], offsets)] # type: ignore[arg-type]
|
||||||
|
if reindexers[i] is not None:
|
||||||
|
new_args = reindexers[i](new_args) # type: ignore[misc]
|
||||||
|
for gemm_idx in range(grouped_gemm_number):
|
||||||
|
V.ops.store(
|
||||||
|
output_names[gemm_idx],
|
||||||
|
output_index,
|
||||||
|
gemm_nodes[gemm_idx].make_loader()(new_args).value,
|
||||||
|
)
|
||||||
|
|
||||||
|
body = LoopBody(
|
||||||
|
fn,
|
||||||
|
(list(var_ranges.keys()), ()),
|
||||||
|
var_ranges,
|
||||||
|
list(var_ranges.keys()),
|
||||||
|
tuple(),
|
||||||
|
)
|
||||||
|
bodies.append(body)
|
||||||
|
var_sizes_list.append(var_sizes)
|
||||||
|
|
||||||
|
cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list)
|
||||||
|
kernel_group.finalize_kernel(cpp_kernel_proxy, [])
|
||||||
|
return kernel_group.loops_code.getvalue()
|
||||||
|
|
||||||
def store_output(
|
def store_output(
|
||||||
self,
|
self,
|
||||||
dst: ir.Buffer,
|
dst: ir.Buffer,
|
||||||
|
|
@ -335,6 +411,43 @@ class CppTemplateKernel(CppKernel):
|
||||||
assert dst.layout == src.layout, f"{dst=}, {src=}"
|
assert dst.layout == src.layout, f"{dst=}, {src=}"
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
def store_outputs(
|
||||||
|
self,
|
||||||
|
dst: tuple[ir.Buffer],
|
||||||
|
src: tuple[ir.IRNode],
|
||||||
|
orig_src: Optional[tuple[ir.IRNode]] = None,
|
||||||
|
epilogue_nodes: Optional[List[ir.IRNode]] = None,
|
||||||
|
offsets: Optional[List[Any]] = None,
|
||||||
|
reindexers: Optional[List[Optional[Callable[[List[Any]], List[Any]]]]] = None,
|
||||||
|
):
|
||||||
|
# Grouped GEMM may have multi outputs to be localized
|
||||||
|
assert isinstance(src, Iterable)
|
||||||
|
assert isinstance(dst, Iterable)
|
||||||
|
assert all(_dst.get_size() == _src.get_size() for _src, _dst in zip(src, dst))
|
||||||
|
if offsets:
|
||||||
|
offsets = parse_expr_with_index_symbols(offsets)
|
||||||
|
if epilogue_nodes:
|
||||||
|
assert (
|
||||||
|
not epilogue_nodes
|
||||||
|
), "epilogue_nodes not supported for Grouped GEMM yet"
|
||||||
|
else:
|
||||||
|
if dst[0].get_name() != src[0].get_name():
|
||||||
|
copy_list = []
|
||||||
|
with LocalBufferContext(self.args) as scope:
|
||||||
|
for _src, _dst in zip(src, dst):
|
||||||
|
copy_list.append([L.copy(_dst, _src).data.data])
|
||||||
|
scope.add_local_buffer(_src)
|
||||||
|
return self.store_grouped_gemm_pointwise_nodes(dst, copy_list)
|
||||||
|
else:
|
||||||
|
assert all(
|
||||||
|
_src.get_name() == _dst.get_name() for _src, _dst in zip(src, dst)
|
||||||
|
)
|
||||||
|
assert all(
|
||||||
|
_src.get_layout() == _dst.get_layout()
|
||||||
|
for _src, _dst in zip(src, dst)
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
class CppTemplateCaller(ir.ChoiceCaller):
|
class CppTemplateCaller(ir.ChoiceCaller):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@@ -2193,9 +2193,12 @@ class PythonWrapperCodegen(CodeGen):
         ):
             return
         self.allocated.add(name)
-        if isinstance(
-            buffer.get_defining_op(),
-            (ir.ExternKernelAlloc, ir.MultiOutput),
+        if (
+            isinstance(
+                buffer.get_defining_op(),
+                (ir.ExternKernelAlloc, ir.MultiOutput),
+            )
+            and not buffer.should_allocate()
         ):
             return
@@ -884,6 +884,9 @@ class cpp:
         os.environ.get("TORCHINDUCTOR_CPP_ENABLE_TILING_HEURISTIC", "1") == "1"
     )
 
+    # Enable the Grouped GEMM Fusion
+    enable_grouped_gemm_template = False
+
     # Maximal allowed number of slices on K-dim for a GEMM kernel. This controls
     # the maximal parallelism of K-slicing. Since K-slicing requires extra thread
     # synchronization and buffers, the maximal number of slices is limited to
@@ -38,6 +38,75 @@ if torch._C._has_mkldnn:
     _linear_args = [Arg() for _ in range(6)]
     _conv_transpose_args = [Arg() for _ in range(11)]
 
+    def _is_valid_grouped_gemm_fusion(computation_nodes):
+        """
+        Here we check that:
+        1. More than one GEMM node has been found.
+        2. All the GEMM nodes share the same activation.
+        3. All the GEMM nodes have the same weight size but distinct weight nodes.
+        """
+        computation_op = mkldnn._linear_pointwise.default
+        act = computation_nodes[0].args[0]
+        wgt = computation_nodes[0].args[1]
+        wgt_size = wgt.meta.get("val").size()  # type: ignore[union-attr]
+        return len(computation_nodes) >= 2 and all(
+            (
+                node.target == computation_op
+                and node.args[0] == act
+                and (node.args[1].meta.get("val").size() == wgt_size)
+                and (node.args[1] != wgt or gemm_idx == 0)
+                and not node.args[2]  # <TODO> support bias through epilogue fusion
+            )
+            for gemm_idx, node in enumerate(computation_nodes)
+        )
+
+    def grouped_gemm_pass(graph: torch.fx.Graph):
+        """
+        Grouped GEMM has multiple output nodes, which makes it complicated to express
+        as a Pattern; use the approach below to connect the pattern to the lowering.
+        TODO: use MultiOutputPattern. The current limitation is that the pattern
+        requires a fixed number of output nodes; extend the pattern matcher to
+        support Grouped GEMM.
+        """
+        computation_op = mkldnn._linear_pointwise.default
+        from ..mkldnn_lowerings import grouped_gemm_lowering
+
+        for node in graph.find_nodes(op="call_function", target=computation_op):
+            if (
+                not node._erased
+                and isinstance(node.meta.get("val"), torch.Tensor)
+                and node.meta["val"].device.type == "cpu"
+            ):
+                act = node.args[0]
+                users = list(act.users)
+                if _is_valid_grouped_gemm_fusion(users):
+                    with graph.inserting_before(node):
+                        grouped_gemm_node = graph.create_node(
+                            "call_function",
+                            grouped_gemm_lowering,
+                            (
+                                act,
+                                [user.args[1] for user in users],
+                                [None for _ in users],
+                            ),
+                        )
+                    grouped_gemm_node.meta["val"] = [
+                        user.meta["val"] for user in users
+                    ]
+                    with graph.inserting_after(grouped_gemm_node):
+                        for gemm_idx, user in enumerate(users):
+                            assert user.target == computation_op
+                            get_item = graph.create_node(
+                                "call_function",
+                                operator.getitem,
+                                (
+                                    grouped_gemm_node,
+                                    gemm_idx,
+                                ),
+                            )
+                            user.replace_all_uses_with(get_item)
+                            graph.erase_node(user)
+        return
+
     def _conv_call(users=1):
         return CallFunction(
             mkldnn._convolution_pointwise.default, *_conv_args, _users=users
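To make the effect of `grouped_gemm_pass` above easier to picture, here is an explanatory sketch (not part of the patch) of the FX-graph shape before and after the rewrite:

```
# Before the pass: several mkldnn linear nodes consume the same activation.
#
#   act --> mkldnn._linear_pointwise(act, w0, None, ...) --> y0
#   act --> mkldnn._linear_pointwise(act, w1, None, ...) --> y1
#
# After the pass: a single grouped node plus getitem projections; the original
# linear nodes are erased and their users now read from the getitems.
#
#   act --> grouped_gemm_lowering(act, [w0, w1], [None, None]) --> getitem(0) --> y0
#                                                               \-> getitem(1) --> y1
```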
@@ -102,6 +102,16 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
             post_grad_custom_pre_pass
         )
 
+    if (
+        config.cpp.enable_grouped_gemm_template
+        and config.max_autotune
+        and "CPP" in config.max_autotune_gemm_backends
+        and torch._C._has_mkldnn
+    ):
+        from .mkldnn_fusion import grouped_gemm_pass
+
+        grouped_gemm_pass(gm.graph)
+
     if config.pattern_matcher:
         lazy_init()
         optimus_scuba_log["before_recompile_post_grad"] = upload_graph(gm.graph)
@@ -4496,6 +4496,18 @@ class CppTemplateBuffer(TemplateBuffer):
         super().__init__(layout, inputs, make_kernel_render)
         self.template = template
         self.choice = choice
+        self.outputs: Optional[List[Buffer]] = None
+
+    def get_layout(self) -> Layout:
+        if isinstance(self.layout, MultiOutputLayout):
+            assert isinstance(self.outputs, Iterable)
+            first_output = self.outputs[0]
+            assert isinstance(first_output, Buffer)
+            layout = first_output.layout
+            assert isinstance(layout, Layout)
+            return layout
+        else:
+            return super().get_layout()
 
 
 @ir_dataclass(frozen=False)
@@ -6832,6 +6844,10 @@ class MultiOutput(ExternKernel):
         return self.inputs[0].get_unbacked_symbol_uses()
 
     def should_allocate(self) -> bool:
+        if len(self.inputs) == 1 and (
+            isinstance(self.inputs[0], CppTemplateBuffer)  # Grouped GEMM
+        ):
+            return True
         return False
 
     def get_inputs_that_alias_output(self) -> Sequence[str]:
@@ -8,6 +8,7 @@ from torch._inductor.kernel.mm_common import mm_args
 
 from . import ir
 from .codegen.cpp_gemm_template import CppGemmTemplate
+from .codegen.cpp_grouped_gemm_template import CppGroupedGemmTemplate
 from .codegen.cpp_utils import create_epilogue_with_attr
 from .ir import TensorBox
 from .lowering import (
@ -28,6 +29,73 @@ from .utils import use_aten_gemm_kernels, use_cpp_gemm_template, use_max_autotun
|
||||||
from .virtualized import ops, V
|
from .virtualized import ops, V
|
||||||
|
|
||||||
|
|
||||||
|
def grouped_gemm_lowering(
|
||||||
|
x: TensorBox,
|
||||||
|
w: List[TensorBox],
|
||||||
|
b: List[TensorBox],
|
||||||
|
attr=None,
|
||||||
|
scalars=None,
|
||||||
|
algorithm=None,
|
||||||
|
layout=None,
|
||||||
|
):
|
||||||
|
x_size = x.get_size()
|
||||||
|
if len(x_size) > 2:
|
||||||
|
# GEMM template needs 2D input, normalize input shape here
|
||||||
|
x = view(x, [-1, x_size[-1]])
|
||||||
|
num_gemm = len(w)
|
||||||
|
|
||||||
|
assert use_max_autotune()
|
||||||
|
b = [bias if bias is None else ir.ExternKernel.realize_input(bias) for bias in b]
|
||||||
|
|
||||||
|
choices: List[ChoiceCaller] = []
|
||||||
|
*_, layout, x, _ = mm_args(x, permute(w[0], [1, 0]), layout=layout)
|
||||||
|
|
||||||
|
kwargs = dict(
|
||||||
|
has_bias=[bias is not None for bias in b],
|
||||||
|
trans_w=True,
|
||||||
|
epilogue_creator=None,
|
||||||
|
act_mapping={num: x for num in range(num_gemm)},
|
||||||
|
)
|
||||||
|
|
||||||
|
input_nodes = [x, *w]
|
||||||
|
input_nodes.extend([bias for bias in b if bias is not None])
|
||||||
|
|
||||||
|
CppGroupedGemmTemplate.add_choices(
|
||||||
|
choices,
|
||||||
|
layout,
|
||||||
|
input_nodes,
|
||||||
|
**kwargs, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(choices) != 0
|
||||||
|
result = autotune_select_algorithm(
|
||||||
|
"grouped_gemm",
|
||||||
|
choices,
|
||||||
|
input_nodes,
|
||||||
|
layout,
|
||||||
|
)
|
||||||
|
template_buf = result.data.data
|
||||||
|
return_bufs = [
|
||||||
|
ir.MultiOutput(layout, template_buf, [(list, gemm_idx)])
|
||||||
|
for gemm_idx in range(num_gemm)
|
||||||
|
]
|
||||||
|
template_buf.layout = ir.MultiOutputLayout(device=input_nodes[0].get_device())
|
||||||
|
template_buf.outputs = return_bufs
|
||||||
|
return_tensors = [
|
||||||
|
ir.TensorBox.create(return_bufs[gemm_idx]) for gemm_idx in range(num_gemm)
|
||||||
|
]
|
||||||
|
if len(x_size) > 2:
|
||||||
|
for gemm_idx in range(num_gemm):
|
||||||
|
return_tensors[gemm_idx] = view(
|
||||||
|
return_tensors[gemm_idx],
|
||||||
|
(*x_size[:-1], return_tensors[gemm_idx].get_size()[-1]),
|
||||||
|
)
|
||||||
|
return return_tensors
|
||||||
|
|
||||||
|
|
||||||
|
grouped_gemm_lowering._inductor_lowering_function = True # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
def register_onednn_fusion_ops():
|
def register_onednn_fusion_ops():
|
||||||
if torch._C._has_mkldnn:
|
if torch._C._has_mkldnn:
|
||||||
from . import mkldnn_ir
|
from . import mkldnn_ir
|
||||||
|
|
|
||||||
|
|
@ -3607,6 +3607,15 @@ class Scheduler:
|
||||||
# the current kernel from where 'allocate' retrieve those decisions.
|
# the current kernel from where 'allocate' retrieve those decisions.
|
||||||
# We have to make sure there is a non-NULL kernel handler to store
|
# We have to make sure there is a non-NULL kernel handler to store
|
||||||
# those inplace update decisions.
|
# those inplace update decisions.
|
||||||
|
|
||||||
|
if (
|
||||||
|
isinstance(scheduler_node.node, ir.MultiOutput)
|
||||||
|
and len(scheduler_node.node.inputs) == 1
|
||||||
|
and isinstance(scheduler_node.node.inputs[0], ir.CppTemplateBuffer)
|
||||||
|
):
|
||||||
|
# <TODO> Remove this code after Fuse MultiOutput and CppTemplateBuffer
|
||||||
|
return
|
||||||
|
|
||||||
counters["inductor"]["extern_calls"] += 1
|
counters["inductor"]["extern_calls"] += 1
|
||||||
with V.set_kernel_handler(Kernel(increase_kernel_count=False)):
|
with V.set_kernel_handler(Kernel(increase_kernel_count=False)):
|
||||||
scheduler_node.decide_inplace_update()
|
scheduler_node.decide_inplace_update()
|
||||||
|
|
|
||||||