[Inductor][CPP] Enable Grouped GEMM Template (#143796)

**Summary**
Enable the CPP Grouped GEMM fusion, lowering, and Grouped GEMM Template, following the RFC: https://github.com/pytorch/pytorch/issues/144012

- Support a flexible number of GEMMs
- Share the activation across GEMMs
  - The Grouped GEMM Template itself supports independent activations
  - However, the pattern matcher requires an anchor node, which serves as the shared activation across GEMMs (a conceptual sketch of this rewrite follows the list)
- Each GEMM can have a unique weight, but all weights must have the same sizes
- Each GEMM can have a unique bias or None
  - The current PR does not yet support biases; this will be addressed in a follow-up epilogue fusion PR
- Each GEMM can have its own epilogues
  - Epilogue fusion is not yet supported in this PR and will be enabled in the follow-up epilogue fusion PR
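
As a rough illustration of the rewrite (a hand-written sketch, not this PR's implementation): `grouped_gemm` below is a hypothetical stand-in for the grouped GEMM node (`grouped_gemm_lowering`) that the pass inserts, and plain `F.linear` stands in for the `mkldnn._linear_pointwise` nodes in the frozen post-grad graph. The sizes mirror the example further down.

```
import torch
import torch.nn.functional as F

def before_fusion(x, w0, w1):
    # two independent GEMM nodes, both anchored on the shared activation x
    return F.linear(x, w0), F.linear(x, w1)

def grouped_gemm(x, weights, biases):
    # stand-in for the grouped GEMM template: one kernel call, len(weights) outputs
    return [F.linear(x, w, b) for w, b in zip(weights, biases)]

def after_fusion(x, w0, w1):
    outs = grouped_gemm(x, [w0, w1], [None, None])  # single grouped node
    return outs[0], outs[1]                         # per-GEMM getitem

x = torch.randn(4, 512)
w0 = torch.randn(1024, 512)
w1 = torch.randn(1024, 512)
assert all(
    torch.equal(a, b)
    for a, b in zip(before_fusion(x, w0, w1), after_fusion(x, w0, w1))
)
```

In the FX graph, the pass emits one grouped node plus one `getitem` per GEMM, as can be seen in the `grouped_gemm_pass` hunk further down in this diff.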

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear_invalid
python -u -m pytest -s -v test/inductor/test_cpu_cpp_wrapper.py -k test_grouped_linear
```

**Example**
Here is the example; the generated code is linked below.
```
batch_size = 4
in_features = 512
out_features = 1024
dtype = torch.bfloat16
bias = False

class M(torch.nn.Module):
    def __init__(self, bias):
        super().__init__()
        self.linear0 = torch.nn.Linear(in_features, out_features, bias=False)
        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)

    def forward(self, x):
        return self.linear0(x), self.linear1(x)

if __name__ == "__main__":
    with torch.no_grad():
        input = torch.randn(batch_size, in_features, dtype=dtype)
        m = M(bias=bias).to(dtype=dtype).eval()
        cm = torch.compile(m)
        act_res = cm(input)
```

Generated Code:  https://gist.github.com/leslie-fang-intel/ed2e8d23aeb3586eb504feeace692e16#file-grouped-gemm-generated-code-py
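
To actually take the grouped GEMM path for the example above, a few configs must be enabled. This is a minimal sketch based on the gating added in `post_grad_passes` and the patches used by the new unit tests; treat the exact values as assumptions rather than recommended defaults.

```
import torch._inductor.config as inductor_config

# the fusion pass only runs with max-autotune enabled and the "CPP" GEMM backend allowed
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "CPP,ATEN"  # must contain "CPP"
# new flag added by this PR, disabled by default
inductor_config.cpp.enable_grouped_gemm_template = True
# the unit tests also enable freezing so the weights become constants that can be packed
inductor_config.freezing = True

# after the compiled module has run once, template selection can be checked via:
# torch._dynamo.utils.counters["inductor"]["cpp_grouped_gemm_template"] == 1
```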

**Next Step**

- Support Epilogue fusion

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143796
Approved by: https://github.com/jgong5, https://github.com/jansel
leslie-fang-intel 2025-01-13 00:03:44 -08:00 committed by PyTorch MergeBot
parent 35b46a75f1
commit 25de671ea8
15 changed files with 913 additions and 19 deletions

View File

@ -207,7 +207,12 @@ if RUN_CPU:
*[
BaseTest(func, "", test_cpu_select_algorithm.TestSelectAlgorithmCPU())
for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU())
if func.startswith("test_linear_with_pointwise")
if func.startswith(
(
"test_linear_with_pointwise",
"test_grouped_linear",
)
)
],
BaseTest("test_polar"),
BaseTest(

View File

@ -1683,6 +1683,105 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
@inductor_config.patch({"freezing": True})
@inductor_config.patch({"cpp.enable_grouped_gemm_template": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (16,))
@parametrize("in_features", (52,))
@parametrize("out_features", (32,))
@parametrize("gemm_num", (2, 3))
def test_grouped_linear_invalid(
self,
batch_size,
in_features,
out_features,
gemm_num,
):
class M(torch.nn.Module):
def __init__(self, in_feature, out_feature, gemm_num):
super().__init__()
self.linears = [
torch.nn.Linear(in_feature, out_feature + gemm_idx, bias=False)
for gemm_idx in range(gemm_num)
]
def forward(self, x):
return [linear(x) for linear in self.linears]
# each linear has a different number of out features, thus an invalid grouped gemm
dtypes = []
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
dtypes.append(torch.bfloat16)
if torch.ops.mkldnn._is_mkldnn_fp16_supported():
dtypes.append(torch.float16)
for dtype in dtypes:
torch._dynamo.reset()
torch._inductor.metrics.reset()
counters.clear()
mod = M(in_features, out_features, gemm_num).eval()
v = torch.randn(batch_size, in_features).to(dtype)
with verify(dtype) as (atol, rtol), torch.autocast(
device_type="cpu", dtype=dtype
), torch.no_grad():
self.common(mod, (v,), atol=atol, rtol=rtol)
# gemm_num independent templates instead of a grouped gemm template
self.assertEqual(
counters["inductor"]["select_algorithm_autotune"], gemm_num
)
self.assertEqual(counters["inductor"]["cpp_grouped_gemm_template"], 0)
@inductor_config.patch({"freezing": True})
@inductor_config.patch({"cpp.enable_grouped_gemm_template": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (16,))
@parametrize("in_features", (52,))
@parametrize("out_features", (32,))
@parametrize("input_3d", (False, True))
@parametrize("gemm_num", (2, 3))
def test_grouped_linear(
self,
batch_size,
in_features,
out_features,
input_3d,
gemm_num,
):
class M(torch.nn.Module):
def __init__(self, in_feature, out_feature, gemm_num):
super().__init__()
self.linears = [
torch.nn.Linear(in_feature, out_feature, bias=False)
for _ in range(gemm_num)
]
def forward(self, x):
return [linear(x) for linear in self.linears]
dtypes = []
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
dtypes.append(torch.bfloat16)
if torch.ops.mkldnn._is_mkldnn_fp16_supported():
dtypes.append(torch.float16)
for dtype in dtypes:
if dtype == torch.float16 and input_3d:
# reduce the number of tests
continue
torch._dynamo.reset()
torch._inductor.metrics.reset()
counters.clear()
mod = M(in_features, out_features, gemm_num).eval()
B = (2, batch_size) if input_3d else (batch_size,)
v = torch.randn(*B, in_features).to(dtype)
with verify(dtype) as (atol, rtol), torch.autocast(
device_type="cpu", dtype=dtype
), torch.no_grad():
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_grouped_gemm_template"], 1)
@inductor_config.patch({"freezing": False})
@patches
@torch.no_grad
@ -2031,6 +2130,7 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase):
test_quantized_linear_amx_dynamic_shapes = (
TestSelectAlgorithm.test_quantized_linear_amx
)
test_grouped_linear_dynamic_shapes = TestSelectAlgorithm.test_grouped_linear
test_linear_k_slicing_dynamic_shapes = TestSelectAlgorithm.test_linear_k_slicing
test_linear_cache_blocking_dynamic_shapes = (
TestSelectAlgorithm.test_linear_cache_blocking

View File

@ -500,7 +500,13 @@ class BenchmarkRequest:
self.input_tensor_meta = input_tensor_meta
if isinstance(output_tensor_meta, (tuple, list)):
assert len(output_tensor_meta) == 1
if len(output_tensor_meta) > 1:
# Each output has the same meta for Grouped GEMM
assert all(
getattr(output_tensor_meta[0], attr) == getattr(x, attr)
for x in output_tensor_meta
for attr in ["device", "dtype", "sizes", "strides", "offset"]
)
output_tensor_meta = output_tensor_meta[0]
self.output_tensor_meta = output_tensor_meta

View File

@ -26,6 +26,7 @@ from ..loop_body import LoopBody
from ..scheduler import (
BaseSchedulerNode,
BaseScheduling,
ExternKernelSchedulerNode,
ForeachKernelSchedulerNode,
FusedSchedulerNode,
Scheduler,
@ -4905,7 +4906,24 @@ class CppScheduling(BaseScheduling):
epilogue_nodes=epilogue_ir_nodes,
)
with kernel:
for node in [template_node, *epilogue_nodes]:
if isinstance(template_node.node, ir.CppTemplateBuffer) and isinstance(
template_node.node.layout, ir.MultiOutputLayout
):
# For Grouped GEMM, allocate buffers for each GEMM
assert (
len(template_node.outputs) == 1
), "Grouped GEMM has 1 output template buffer"
for user in template_node.outputs[0].users:
assert isinstance(
user.node, ExternKernelSchedulerNode
), "Grouped GEMM should be with ExternKernelSchedulerNode"
assert isinstance(
user.node.node, ir.MultiOutput
), "Grouped GEMM has multi users with MultiOutput"
user.node.mark_run()
else:
template_node.mark_run() # type: ignore[attr-defined]
for node in epilogue_nodes:
node.mark_run() # type: ignore[attr-defined]
src_code = render()

View File

@ -3,7 +3,7 @@ import contextlib
import logging
import math
from functools import lru_cache
from typing import Any, Callable, cast, Dict, List, Optional, Union
from typing import Any, Callable, cast, Dict, List, Optional, TypeVar, Union
from unittest.mock import patch
import torch
@ -299,9 +299,10 @@ def get_padded_n(n, block_n):
return (n + block_n - 1) // block_n * block_n
def transpose_w(
W: Union[ir.IRNode, torch.Tensor], trans_w: bool
) -> Union[ir.IRNode, torch.Tensor]:
_T = TypeVar("_T", ir.IRNode, torch.Tensor)
def transpose_w(W: _T, trans_w: bool) -> _T:
"""
Transpose W based on the trans_w flag.
"""
@ -317,9 +318,7 @@ def transpose_w(
return W
def expand_bias(
B: Union[ir.IRNode, torch.Tensor, None], X: Union[ir.IRNode, torch.Tensor]
) -> Optional[Union[ir.IRNode, torch.Tensor]]:
def expand_bias(B: Optional[_T], X: _T) -> Optional[_T]:
"""
Expand Bias to the same size of X.
"""
@ -336,7 +335,7 @@ def expand_bias(
return B
def prune_tensors(input_nodes: List[ir.TensorBox], new_input_nodes: List[ir.TensorBox]):
def prune_tensors(input_nodes: List[ir.IRNode], new_input_nodes: List[ir.IRNode]):
"""
Prune unused tensors from `V.graph` since the GEMM Template use new packed weight.
"""
@ -798,6 +797,7 @@ class CppGemmTemplate(CppTemplate):
trans_w=False,
input_indices=None,
epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
act_mapping: Optional[dict[int, ir.IRNode]] = None,
):
if input_indices is None:
input_indices = list(range(len(input_nodes)))
@ -1251,6 +1251,7 @@ class CppGemmTemplate(CppTemplate):
# --> zero or more out-of-template epilogues (`epilogue_nodes`) -->
# Y
if epilogue_creators:
assert isinstance(template_buffer, ir.IRNode)
gemm_output_name = f"{template_buffer.get_name()}_GemmOut"
gemm_output_buffer = ir.Buffer(
name=gemm_output_name, layout=template_buffer.layout
@ -1276,14 +1277,17 @@ class CppGemmTemplate(CppTemplate):
name=buffer_name, layout=template_buffer.layout
)
assert isinstance(Y, (ir.Buffer, ir.ReinterpretView))
Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y
if epilogue_nodes:
if not template_buffer_has_other_users:
assert isinstance(template_buffer, ir.IRNode)
Y_aliases.add(template_buffer.get_name())
epilogues.extend(epilogue_nodes)
assert Y.get_numel() == epilogues[-1].get_numel()
Y = cast(ir.Buffer, epilogues[-1])
assert isinstance(template_buffer, ir.Buffer)
Y_2d, reindexers = gen_2d_view_of_epilogue_buf(
Y,
template_buffer,

View File

@ -0,0 +1,463 @@
import contextlib
import logging
from typing import Any, Callable, List, Optional, TypeVar
from unittest.mock import patch
import torch
import torch.utils
from torch.utils._ordered_set import OrderedSet
from ..._dynamo.utils import counters
from .. import config, ir
from ..kernel.mm_common import mm_args
from ..select_algorithm import ChoiceCaller, DataProcessorTemplateWrapper
from ..utils import parallel_num_threads
from ..virtualized import V
from .cpp import get_export_declaration
from .cpp_gemm_template import CppGemmTemplate, expand_bias, prune_tensors, transpose_w
from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm
from .cpp_template_kernel import CppTemplateKernel
from .cpp_utils import (
DTYPE_TO_CPP,
GemmBlocking,
get_gemm_template_output_and_compute_dtype,
)
log = logging.getLogger(__name__)
GEMM_TEMPLATE = r"""
{{template.header().getvalue()}}
{{micro_gemm.codegen_define(kernel)}}
extern "C" {{export_declaration}}
{{kernel.def_kernel(inputs=kernel_args, outputs=Y_list, aliases=aliases)}}
{
{{kernel.maybe_codegen_profile()}}
{{ template.codegen_blocks(
num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOuts[0], config, L1_cache_size, L2_cache_size, X_list[0], W_list[0]
) }}
{%- if num_threads > 1 %}
#pragma omp parallel num_threads({{num_threads}})
{
{{ template.codegen_multi_threads_params()|indent(8, false) }}
{%- else %}
{
{{ template.codegen_single_thread_params(is_dynamic_M)|indent(8, false) }}
{%- endif %}
{{ micro_gemm.codegen_init(kernel) }}
{%- set acc_buf_name_list=[] %}
{%- set acc_buf_name_prefix = "local_acc_buf_" %}
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
{%- set acc_buf_name = acc_buf_name_prefix + gemm_idx|string %}
{{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }}
{%- set acc_buf_name_list=acc_buf_name_list.append(acc_buf_name) %}
{%- endfor %}
for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
{{ template.codegen_m_loop_params()|indent(12, false) }}
for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
{{ template.codegen_n_loop_params()|indent(16, false) }}
{%- set acc_list=[] %}
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
{%- set acc_list = acc_list.append( kernel.local_buffers[acc_buf_name_list[gemm_idx]] ) %}
{{ kernel.reinit_buffer_if_null(acc_buf_name_list[gemm_idx]) }}
{%- endfor %}
for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
int64_t k_start = kc * Kr;
int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
{%- set tile_X_list=[] %}
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
{%- set tile_X_list = tile_X_list.append( kernel.slice_nd(X_list[gemm_idx], [("m_start", "m_end"), ("k_start", "k_end")]) ) %}
{%- endfor %}
for (int64_t nci = nc; nci < nc_block_end; nci++) {
{%- set tile_W_3d_list=[] %}
{%- set tile_W_list=[] %}
{%- set acc_slice_list=[] %}
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
{%- set acc_slice_list = acc_slice_list.append(
kernel.slice_nd(acc_list[gemm_idx], [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")])
) %}
{%- set tile_W_3d_list = tile_W_3d_list.append(
kernel.slice_nd(W_list[gemm_idx], [("nci", "nci + 1"), ("k_start", "k_end"), ()])
) %}
{%- endfor %}
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
{%- set tile_W_list = tile_W_list.append(
kernel.view(tile_W_3d_list[gemm_idx], ["k_end - k_start", micro_gemm.register_blocking.block_n])
) %}
{%- endfor %}
if (kc == k_block_start) {
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
{{ micro_gemm.codegen_call(
kernel, tile_X_list[gemm_idx], tile_W_list[gemm_idx], acc_slice_list[gemm_idx], accum=False
)|indent(28, false) }}
{%- endfor %}
} else {
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
{{ micro_gemm.codegen_call(
kernel, tile_X_list[gemm_idx], tile_W_list[gemm_idx], acc_slice_list[gemm_idx], accum=True
)|indent(28, false) }}
{%- endfor %}
}
}
}
{
{%- set tile_acc_list = [] %}
{%- set tile_Y_list = [] %}
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
{%- set tile_acc_list = tile_acc_list.append(
kernel.slice_nd(acc_list[gemm_idx], [("0", "m_end - m_start"), ("0", "n_end - n_start")])
) %}
{%- set tile_Y_list = tile_Y_list.append(
kernel.slice_nd(Y_2d_list[gemm_idx], [("m_start", "m_end"), ("n_start", "n_end")])
) %}
{%- endfor %}
{{ kernel.store_outputs(
tile_Y_list, tile_acc_list, GemmOuts, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
)|indent(20, false)
}}
}
}
}
{{ micro_gemm.codegen_finalize(kernel) }}
}
}
"""
def get_deduplicated_act(act_mapping: dict[int, ir.IRNode]) -> List[ir.IRNode]:
act_deduplicated = []
act_deduplicated_name: OrderedSet[str] = OrderedSet()
for act_idx in range(len(act_mapping.values())):
act = act_mapping[act_idx]
if act.get_name() not in act_deduplicated_name:
act_deduplicated.append(act)
act_deduplicated_name.add(act.get_name())
return act_deduplicated
class CppGroupedGemmTemplate(CppGemmTemplate):
def __init__(
self,
input_nodes: List[ir.IRNode],
layout: ir.Layout,
num_threads: int,
register_blocking: GemmBlocking,
beta: int = 1,
alpha: int = 1,
has_bias: bool = False,
epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
act_mapping: Optional[dict[int, ir.IRNode]] = None,
gemm_grouped_num: int = 1,
) -> None:
"""
Template for Group of GEMMs:
* Each GEMM has the same dimensions (m, n, k) and the same leading dimensions (lda, ldb, ldc)
for their A, B, and C matrices.
* Each GEMM has a distinct or shared activation, a distinct weight, a unique bias or no bias, and distinct epilogues.
* In the current implementation, the outputs of all GEMMs are accumulated using pointwise epilogues.
This behavior can be extended in the future if needed.
"""
super().__init__(
input_nodes,
layout,
num_threads,
register_blocking,
beta,
alpha,
has_bias,
epilogue_creator,
)
self.act_mapping = act_mapping
self.gemm_grouped_num = gemm_grouped_num
self.output_node: List[ir.Buffer] = [
ir.Buffer(name="buf_out" + str(idx), layout=layout)
for idx in range(gemm_grouped_num)
]
@staticmethod
def _fake_get_dtype(fake_outs: List[ir.Buffer]) -> Callable[[str], torch.dtype]:
_get_dtype_real = V.graph.get_dtype
def get_dtype(name: str) -> torch.dtype:
for fake_out in fake_outs:
if name == fake_out.get_name():
return fake_out.get_dtype()
return _get_dtype_real(name)
return get_dtype
@classmethod
def add_choices(
cls,
choices: List[ChoiceCaller],
layout: ir.Layout,
input_nodes: List[ir.IRNode],
beta: int = 1,
alpha: int = 1,
has_bias: tuple[bool, ...] = (False, False),
trans_w: bool = False,
input_indices: Optional[List[int]] = None,
epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
act_mapping: Optional[dict[int, ir.IRNode]] = None, # gemm idx to its act buf
) -> DataProcessorTemplateWrapper:
# Input nodes order: x, optional[x1], ... w0, w1, ... optional[b0], optional[b1], ...
gemm_grouped_num = len(has_bias)
assert act_mapping
act_deduplicated = get_deduplicated_act(act_mapping)
wgt_start_idx = len(act_deduplicated)
bias_start_idx = wgt_start_idx + gemm_grouped_num
input_indices = list(range(len(input_nodes)))
_T = TypeVar("_T", ir.IRNode, torch.Tensor)
_U = TypeVar("_U", ir.Layout, torch.Tensor)
def reorder_and_filter(
inputs: List[_T],
layout_or_out: _U,
) -> tuple[List[_T], _U]:
assert input_indices is not None, "input_indices must be set"
return [inputs[idx] for idx in input_indices], layout_or_out
new_inputs, new_layout = reorder_and_filter(input_nodes, layout)
def maybe_to_dense(
inputs: List[_T],
layout_or_out: _U,
) -> tuple[List[_T], _U]:
new_inputs = list(inputs)
for idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
if isinstance(inputs[idx], torch.Tensor):
W = inputs[idx]
assert isinstance(W, torch.Tensor), "W must be a torch.Tensor"
new_inputs[idx] = W.to_dense() if W.is_mkldnn else W
return new_inputs, layout_or_out
def normalize_shapes(
inputs: List[_T],
layout_or_out: _U,
) -> tuple[List[_T], _U]:
new_inputs: List[_T] = list(inputs)
if not trans_w:
return new_inputs, layout_or_out
X = new_inputs[0]
for wgt_idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
new_input = new_inputs[wgt_idx]
new_inputs[wgt_idx] = transpose_w(new_input, trans_w)
for bias_idx in range(bias_start_idx, len(new_inputs)):
new_bias = expand_bias(new_inputs[bias_idx], X)
assert new_bias is not None
new_inputs[bias_idx] = new_bias
return new_inputs, layout_or_out
num_threads = parallel_num_threads()
new_inputs, _ = normalize_shapes(*maybe_to_dense(new_inputs, new_layout))
m, n, k, *_ = mm_args(new_inputs[0], new_inputs[wgt_start_idx])
output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
new_inputs[0].get_dtype()
)
micro_gemm = create_micro_gemm(
"micro_gemm",
m,
n,
k,
input_dtype=new_inputs[0].get_dtype(),
input2_dtype=new_inputs[wgt_start_idx].get_dtype(),
output_dtype=output_dtype,
compute_dtype=compute_dtype,
alpha=alpha,
num_threads=num_threads,
)
assert micro_gemm is not None
_, block_n, _ = micro_gemm.register_blocking
new_size, padded_n = cls.get_padded_size(
n, block_n, k, should_block_weight=True
)
padding = padded_n - n
def pack_weight(
inputs: List[_T],
layout_or_out: _U,
) -> tuple[List[_T], _U]:
new_W_list = []
new_inputs = list(inputs)
W_list = new_inputs[wgt_start_idx : wgt_start_idx + gemm_grouped_num]
for W in W_list:
blocked_w = cls.block_weight(W, new_size, padding)
new_W_list.append(cls.pack_vnni_weight(blocked_w, micro_gemm, new_size))
new_inputs[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = new_W_list
return new_inputs, layout_or_out
def preprocessor(
inputs: List[_T],
layout: _U,
) -> tuple[List[_T], _U]:
return pack_weight(
*normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout)))
)
def postprocessor(output: _T) -> _T:
if isinstance(output, ir.TensorBox):
template_buffer = ir.InputsKernel.unwrap_storage_for_input(output)
assert isinstance(template_buffer, ir.CppTemplateBuffer)
new_input_nodes, _ = reorder_and_filter(input_nodes, layout)
W_nodes = new_input_nodes[
wgt_start_idx : wgt_start_idx + gemm_grouped_num
]
W_tensor = []
for W_node in W_nodes:
assert W_node.get_name() in V.graph.constants
W_tensor.append(V.graph.constants[W_node.get_name()])
new_input_nodes[
wgt_start_idx : wgt_start_idx + gemm_grouped_num
] = W_tensor # type: ignore[assignment]
new_input_nodes, _ = pack_weight(
*normalize_shapes(*maybe_to_dense(new_input_nodes, layout))
)
# Prune unused tensors
prune_tensors(input_nodes, new_input_nodes)
for idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
W_packed = new_input_nodes[idx]
assert isinstance(W_packed, torch.Tensor)
W_packed_constant = V.graph.add_tensor_constant(W_packed)
template_buffer.inputs[
idx
] = ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
return output
template = DataProcessorTemplateWrapper(
CppGroupedGemmTemplate,
preprocessor,
postprocessor,
input_nodes=input_nodes,
layout=layout,
num_threads=num_threads,
register_blocking=micro_gemm.register_blocking,
beta=beta,
alpha=alpha,
has_bias=has_bias,
epilogue_creator=epilogue_creator,
act_mapping=act_mapping,
gemm_grouped_num=gemm_grouped_num,
)
template.maybe_append_choice(choices)
return template
def render( # type: ignore[override,return,no-untyped-def]
self,
kernel: CppTemplateKernel,
template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
flag_template_buffer_has_other_users: Optional[bool] = None,
epilogue_nodes: Optional[List[ir.IRNode]] = None,
**kwargs,
) -> str:
assert self.act_mapping
act_deduplicated = get_deduplicated_act(self.act_mapping)
wgt_start_idx = len(act_deduplicated)
bias_start_idx = wgt_start_idx + self.gemm_grouped_num
X_list = list(self.act_mapping.values())
W_list = self.input_nodes[wgt_start_idx : wgt_start_idx + self.gemm_grouped_num]
inp_list = []
cur_idx = bias_start_idx
for inp_idx in range(self.gemm_grouped_num):
inp = None
if self.has_bias[inp_idx]:
inp = self.input_nodes[cur_idx]
cur_idx += 1
inp_list.append(inp)
Y_list = self.output_node
if template_buffer_node is not None:
W_list = template_buffer_node.inputs[
wgt_start_idx : wgt_start_idx + self.gemm_grouped_num
]
assert isinstance(template_buffer_node.outputs, List)
Y_list = template_buffer_node.outputs
counters["inductor"]["cpp_grouped_gemm_template"] += 1
template_buffer = Y_list[0]
fake_buffers: List[ir.Buffer] = []
Y_2d_list = Y_list
output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
X_list[0].get_dtype()
)
micro_gemm = create_micro_gemm(
f"{kernel.kernel_name}_micro_gemm",
self.m,
self.n,
self.k,
input_dtype=X_list[0].get_dtype(),
input2_dtype=W_list[0].get_dtype(),
output_dtype=output_dtype,
compute_dtype=compute_dtype,
alpha=self.alpha,
num_threads=self.num_threads,
)
assert micro_gemm is not None
assert self.register_blocking == micro_gemm.register_blocking
self.log_blockings()
if isinstance(micro_gemm, CppMicroGemmAMX):
counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1
L1_cache_size = torch._C._cpu._L1d_cache_size() # per core cache size in Bytes
assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}"
L2_cache_size = torch._C._cpu._L2_cache_size() # per core cache size in Bytes
assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}"
epilogues: List[ir.IRNode] = []
reindexers: List[Optional[Callable[[List[Any]], List[Any]]]] = []
gemm_output_buffers: list[ir.Buffer] = []
for out_buf_idx in range(self.gemm_grouped_num):
gemm_output_name = f"{template_buffer.get_name()}_GemmOut" + str(
out_buf_idx
)
gemm_output_buffers.append(
ir.Buffer(name=gemm_output_name, layout=template_buffer.layout)
)
assert (
not self.epilogue_creator and not epilogue_nodes
), "Epilogue fusion is not implemented yet in Grouped GEMM Template"
kernel_args: dict[str, Optional[ir.IRNode]] = {}
for x_idx in range(wgt_start_idx):
kernel_args["X" + str(x_idx)] = act_deduplicated[x_idx]
for w_idx in range(self.gemm_grouped_num):
kernel_args["W" + str(w_idx)] = W_list[w_idx]
for inp_idx in range(self.gemm_grouped_num):
kernel_args["inp" + str(inp_idx)] = inp_list[inp_idx]
options = dict(
N=self.n,
K=self.k,
PADDED_N=self.padded_n,
aliases={},
beta=self.beta,
alpha=self.alpha,
num_threads=self.num_threads,
micro_gemm=micro_gemm,
is_dynamic_M=self.is_dynamic_M,
template=self,
kernel=kernel,
export_declaration=get_export_declaration(),
acc_buf_dtype=torch.float,
DTYPE_TO_CPP=DTYPE_TO_CPP,
L1_cache_size=L1_cache_size,
L2_cache_size=L2_cache_size,
config=config,
epilogue_nodes=epilogues,
GemmOuts=gemm_output_buffers,
reindexers=reindexers,
kernel_args=kernel_args,
X_list=X_list,
W_list=W_list,
gemm_grouped_num=self.gemm_grouped_num,
Y_list={"Y" + str(idx): Y for idx, Y in enumerate(Y_list)},
Y_2d_list=Y_2d_list,
)
with contextlib.ExitStack() as stack:
stack.enter_context(
patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_buffers))
)
return self._template_from_string(GEMM_TEMPLATE).render(**options)

View File

@ -4,7 +4,7 @@ import functools
import itertools
import logging
import sys
from typing import Callable, List, Optional
from typing import Callable, Iterable, List, Optional, Union
from unittest.mock import patch
import sympy
@ -33,7 +33,9 @@ class CppTemplate(KernelTemplate):
) -> None:
super().__init__(name)
self.input_nodes = input_nodes
self.output_node: ir.Buffer = ir.Buffer(name="buf_out", layout=layout)
self.output_node: Union[ir.Buffer, List[ir.Buffer]] = ir.Buffer(
name="buf_out", layout=layout
)
self.layout = layout
self.num_threads = num_threads
self.epilogue_creator = epilogue_creator
@ -57,7 +59,10 @@ class CppTemplate(KernelTemplate):
expected_args = list(
unique(input_node.get_name() for input_node in self.input_nodes)
)
expected_args.extend([self.output_node.get_name()])
if isinstance(self.output_node, Iterable):
expected_args.extend([node.get_name() for node in self.output_node])
else:
expected_args.extend([self.output_node.get_name()])
assert list(call_args)[: len(expected_args)] == expected_args, (
call_args,
expected_args,
@ -102,7 +107,9 @@ class CppTemplate(KernelTemplate):
kernel_hash_name,
self.name,
self.input_nodes,
self.output_node.get_layout(),
self.output_node[0].get_layout()
if isinstance(self.output_node, Iterable)
else self.output_node.get_layout(),
make_kernel_render,
bmreq,
self,

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import itertools
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import sympy
from sympy.parsing.sympy_parser import parse_expr
@ -278,6 +278,82 @@ class CppTemplateKernel(CppKernel):
kernel_group.finalize_kernel(cpp_kernel_proxy, [])
return kernel_group.loops_code.getvalue()
def store_grouped_gemm_pointwise_nodes(
self,
dst: tuple[ir.Buffer],
nodes: List[List[ir.IRNode]],
offsets: Optional[List[sympy.Expr]] = None,
reindexers: Optional[List[Optional[Callable[[List[Any]], List[Any]]]]] = None,
) -> str:
ref_dst = dst[0]
var_sizes = (tuple(ref_dst.get_size()), ())
var_ranges = {
sympy_index_symbol_with_prefix(SymT.INDEX, i): sz
for i, sz in enumerate(var_sizes[0])
}
if not offsets:
offsets = [sympy.S.Zero] * len(var_sizes[0])
if not reindexers:
reindexers = [None] * len(nodes)
assert len(offsets) == len(var_sizes[0])
output_index = ref_dst.get_layout().make_indexer()([*var_ranges.keys()])
kernel_group = KernelGroup()
kernel_group.args = self.args
cpp_kernel_proxy = CppKernelProxy(kernel_group)
bodies = []
var_sizes_list = []
assert isinstance(nodes[0], Iterable)
grouped_gemm_number = len(nodes)
epilogue_nodes = nodes[0]
assert isinstance(epilogue_nodes, Iterable)
for i, _ in enumerate(epilogue_nodes):
output_names = []
gemm_nodes = []
for gemm_idx in range(grouped_gemm_number):
single_gemm_nodes = nodes[gemm_idx]
assert isinstance(dst, Iterable)
single_gemm_dst = dst[gemm_idx]
assert isinstance(single_gemm_nodes, Iterable)
assert isinstance(single_gemm_dst, ir.IRNode)
gemm_nodes.append(single_gemm_nodes[i])
output_names.append(
single_gemm_nodes[i].get_name()
if i < len(single_gemm_nodes) - 1
else single_gemm_dst.get_name()
)
_node = gemm_nodes[gemm_idx]
gemm_nodes[gemm_idx] = (
_node.data if isinstance(_node, ir.ComputedBuffer) else _node
)
def fn(*args):
assert len(args) == 2
assert len(args[0]) == len(var_sizes[0])
assert len(args[1]) == 0
new_args = [arg + offset for arg, offset in zip(args[0], offsets)] # type: ignore[arg-type]
if reindexers[i] is not None:
new_args = reindexers[i](new_args) # type: ignore[misc]
for gemm_idx in range(grouped_gemm_number):
V.ops.store(
output_names[gemm_idx],
output_index,
gemm_nodes[gemm_idx].make_loader()(new_args).value,
)
body = LoopBody(
fn,
(list(var_ranges.keys()), ()),
var_ranges,
list(var_ranges.keys()),
tuple(),
)
bodies.append(body)
var_sizes_list.append(var_sizes)
cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list)
kernel_group.finalize_kernel(cpp_kernel_proxy, [])
return kernel_group.loops_code.getvalue()
def store_output(
self,
dst: ir.Buffer,
@ -335,6 +411,43 @@ class CppTemplateKernel(CppKernel):
assert dst.layout == src.layout, f"{dst=}, {src=}"
return ""
def store_outputs(
self,
dst: tuple[ir.Buffer],
src: tuple[ir.IRNode],
orig_src: Optional[tuple[ir.IRNode]] = None,
epilogue_nodes: Optional[List[ir.IRNode]] = None,
offsets: Optional[List[Any]] = None,
reindexers: Optional[List[Optional[Callable[[List[Any]], List[Any]]]]] = None,
):
# Grouped GEMM may have multiple outputs to be localized
assert isinstance(src, Iterable)
assert isinstance(dst, Iterable)
assert all(_dst.get_size() == _src.get_size() for _src, _dst in zip(src, dst))
if offsets:
offsets = parse_expr_with_index_symbols(offsets)
if epilogue_nodes:
assert (
not epilogue_nodes
), "epilogue_nodes not supported for Grouped GEMM yet"
else:
if dst[0].get_name() != src[0].get_name():
copy_list = []
with LocalBufferContext(self.args) as scope:
for _src, _dst in zip(src, dst):
copy_list.append([L.copy(_dst, _src).data.data])
scope.add_local_buffer(_src)
return self.store_grouped_gemm_pointwise_nodes(dst, copy_list)
else:
assert all(
_src.get_name() == _dst.get_name() for _src, _dst in zip(src, dst)
)
assert all(
_src.get_layout() == _dst.get_layout()
for _src, _dst in zip(src, dst)
)
return ""
class CppTemplateCaller(ir.ChoiceCaller):
"""

View File

@ -2193,9 +2193,12 @@ class PythonWrapperCodegen(CodeGen):
):
return
self.allocated.add(name)
if isinstance(
buffer.get_defining_op(),
(ir.ExternKernelAlloc, ir.MultiOutput),
if (
isinstance(
buffer.get_defining_op(),
(ir.ExternKernelAlloc, ir.MultiOutput),
)
and not buffer.should_allocate()
):
return

View File

@ -884,6 +884,9 @@ class cpp:
os.environ.get("TORCHINDUCTOR_CPP_ENABLE_TILING_HEURISTIC", "1") == "1"
)
# Enable the Grouped GEMM Fusion
enable_grouped_gemm_template = False
# Maximal allowed number of slices on K-dim for a GEMM kernel. This controls
# the maximal parallelism of K-slicing. Since K-slicing requires extra thread
# synchronization and buffers, the maximal number of slices is limited to

View File

@ -38,6 +38,75 @@ if torch._C._has_mkldnn:
_linear_args = [Arg() for _ in range(6)]
_conv_transpose_args = [Arg() for _ in range(11)]
def _is_valid_grouped_gemm_fusion(computation_nodes):
"""
Here we check:
1. More than one GEMM node has been found.
2. All the GEMM nodes share the same activation.
3. All the GEMM nodes have the same weight size but different weight nodes.
"""
computation_op = mkldnn._linear_pointwise.default
act = computation_nodes[0].args[0]
wgt = computation_nodes[0].args[1]
wgt_size = wgt.meta.get("val").size() # type: ignore[union-attr]
return len(computation_nodes) >= 2 and all(
(
node.target == computation_op
and node.args[0] == act
and (node.args[1].meta.get("val").size() == wgt_size)
and (node.args[1] != wgt or gemm_idx == 0)
and not node.args[2] # <TODO> support bias through epilogue fusion
)
for gemm_idx, node in enumerate(computation_nodes)
)
def grouped_gemm_pass(graph: torch.fx.Graph):
"""
Grouped GEMM has multiple output nodes, which makes it complicated to define as a Pattern.
Use the approach below to connect the pattern to the lowering.
TODO: Use MultiOutputPattern; the current limitation is that the pattern requires a
fixed number of output nodes. Extend the pattern matcher to support Grouped GEMM.
"""
computation_op = mkldnn._linear_pointwise.default
from ..mkldnn_lowerings import grouped_gemm_lowering
for node in graph.find_nodes(op="call_function", target=computation_op):
if (
not node._erased
and isinstance(node.meta.get("val"), torch.Tensor)
and node.meta["val"].device.type == "cpu"
):
act = node.args[0]
users = list(act.users)
if _is_valid_grouped_gemm_fusion(users):
with graph.inserting_before(node):
grouped_gemm_node = graph.create_node(
"call_function",
grouped_gemm_lowering,
(
act,
[user.args[1] for user in users],
[None for _ in users],
),
)
grouped_gemm_node.meta["val"] = [
user.meta["val"] for user in users
]
with graph.inserting_after(grouped_gemm_node):
for gemm_idx, user in enumerate(users):
assert user.target == computation_op
get_item = graph.create_node(
"call_function",
operator.getitem,
(
grouped_gemm_node,
gemm_idx,
),
)
user.replace_all_uses_with(get_item)
graph.erase_node(user)
return
def _conv_call(users=1):
return CallFunction(
mkldnn._convolution_pointwise.default, *_conv_args, _users=users

View File

@ -102,6 +102,16 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
post_grad_custom_pre_pass
)
if (
config.cpp.enable_grouped_gemm_template
and config.max_autotune
and "CPP" in config.max_autotune_gemm_backends
and torch._C._has_mkldnn
):
from .mkldnn_fusion import grouped_gemm_pass
grouped_gemm_pass(gm.graph)
if config.pattern_matcher:
lazy_init()
optimus_scuba_log["before_recompile_post_grad"] = upload_graph(gm.graph)

View File

@ -4496,6 +4496,18 @@ class CppTemplateBuffer(TemplateBuffer):
super().__init__(layout, inputs, make_kernel_render)
self.template = template
self.choice = choice
self.outputs: Optional[List[Buffer]] = None
def get_layout(self) -> Layout:
if isinstance(self.layout, MultiOutputLayout):
assert isinstance(self.outputs, Iterable)
first_output = self.outputs[0]
assert isinstance(first_output, Buffer)
layout = first_output.layout
assert isinstance(layout, Layout)
return layout
else:
return super().get_layout()
@ir_dataclass(frozen=False)
@ -6832,6 +6844,10 @@ class MultiOutput(ExternKernel):
return self.inputs[0].get_unbacked_symbol_uses()
def should_allocate(self) -> bool:
if len(self.inputs) == 1 and (
isinstance(self.inputs[0], CppTemplateBuffer) # Grouped GEMM
):
return True
return False
def get_inputs_that_alias_output(self) -> Sequence[str]:

View File

@ -8,6 +8,7 @@ from torch._inductor.kernel.mm_common import mm_args
from . import ir
from .codegen.cpp_gemm_template import CppGemmTemplate
from .codegen.cpp_grouped_gemm_template import CppGroupedGemmTemplate
from .codegen.cpp_utils import create_epilogue_with_attr
from .ir import TensorBox
from .lowering import (
@ -28,6 +29,73 @@ from .utils import use_aten_gemm_kernels, use_cpp_gemm_template, use_max_autotun
from .virtualized import ops, V
def grouped_gemm_lowering(
x: TensorBox,
w: List[TensorBox],
b: List[TensorBox],
attr=None,
scalars=None,
algorithm=None,
layout=None,
):
x_size = x.get_size()
if len(x_size) > 2:
# GEMM template needs 2D input, normalize input shape here
x = view(x, [-1, x_size[-1]])
num_gemm = len(w)
assert use_max_autotune()
b = [bias if bias is None else ir.ExternKernel.realize_input(bias) for bias in b]
choices: List[ChoiceCaller] = []
*_, layout, x, _ = mm_args(x, permute(w[0], [1, 0]), layout=layout)
kwargs = dict(
has_bias=[bias is not None for bias in b],
trans_w=True,
epilogue_creator=None,
act_mapping={num: x for num in range(num_gemm)},
)
input_nodes = [x, *w]
input_nodes.extend([bias for bias in b if bias is not None])
CppGroupedGemmTemplate.add_choices(
choices,
layout,
input_nodes,
**kwargs, # type: ignore[arg-type]
)
assert len(choices) != 0
result = autotune_select_algorithm(
"grouped_gemm",
choices,
input_nodes,
layout,
)
template_buf = result.data.data
return_bufs = [
ir.MultiOutput(layout, template_buf, [(list, gemm_idx)])
for gemm_idx in range(num_gemm)
]
template_buf.layout = ir.MultiOutputLayout(device=input_nodes[0].get_device())
template_buf.outputs = return_bufs
return_tensors = [
ir.TensorBox.create(return_bufs[gemm_idx]) for gemm_idx in range(num_gemm)
]
if len(x_size) > 2:
for gemm_idx in range(num_gemm):
return_tensors[gemm_idx] = view(
return_tensors[gemm_idx],
(*x_size[:-1], return_tensors[gemm_idx].get_size()[-1]),
)
return return_tensors
grouped_gemm_lowering._inductor_lowering_function = True # type: ignore[attr-defined]
def register_onednn_fusion_ops():
if torch._C._has_mkldnn:
from . import mkldnn_ir

View File

@ -3607,6 +3607,15 @@ class Scheduler:
# the current kernel from where 'allocate' retrieve those decisions.
# We have to make sure there is a non-NULL kernel handler to store
# those inplace update decisions.
if (
isinstance(scheduler_node.node, ir.MultiOutput)
and len(scheduler_node.node.inputs) == 1
and isinstance(scheduler_node.node.inputs[0], ir.CppTemplateBuffer)
):
# <TODO> Remove this code after fusing MultiOutput and CppTemplateBuffer
return
counters["inductor"]["extern_calls"] += 1
with V.set_kernel_handler(Kernel(increase_kernel_count=False)):
scheduler_node.decide_inplace_update()