[Inductor][CPP] Enable Grouped GEMM Template (#143796)
**Summary**
Enable the CPP Grouped GEMM fusion, lowering, and Grouped GEMM template following the RFC: https://github.com/pytorch/pytorch/issues/144012
- Support a flexible number of GEMMs
- Share the activation across GEMMs
  - The Grouped GEMM template itself supports independent activations; however, the pattern matcher requires an anchor node, which is the activation shared across GEMMs
- Each GEMM can have a unique weight, but all weights must have the same sizes
- Each GEMM can have a unique bias or None
  - The current PR does not yet support biases; this will be addressed in a follow-up epilogue-fusion PR
- Each GEMM has its own epilogues
  - Epilogue fusion is not yet supported in this PR and will be enabled in an upcoming follow-up epilogue-fusion PR

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear_invalid
python -u -m pytest -s -v test/inductor/test_cpu_cpp_wrapper.py -k test_grouped_linear
```

**Example**
Here is an example and the generated code:
```
batch_size = 4
in_features = 512
out_features = 1024
dtype = torch.bfloat16

class M(torch.nn.Module):
    def __init__(self, bias):
        super().__init__()
        self.linear0 = torch.nn.Linear(in_features, out_features, bias=False)
        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)

    def forward(self, x):
        return self.linear0(x), self.linear1(x)

if __name__ == "__main__":
    with torch.no_grad():
        input = torch.randn(batch_size, in_features, dtype=dtype)
        m = M(bias=bias).to(dtype=dtype).eval()
        cm = torch.compile(m)
        act_res = cm(input)
```
Generated Code: https://gist.github.com/leslie-fang-intel/ed2e8d23aeb3586eb504feeace692e16#file-grouped-gemm-generated-code-py

**Next Step**
- Support epilogue fusion

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143796
Approved by: https://github.com/jgong5, https://github.com/jansel
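For readers who want to try this out, the following is a minimal sketch of how the new template could be enabled, assembled from the test decorators and the guard added to `post_grad_passes` in this PR (freezing, the new `cpp.enable_grouped_gemm_template` flag, and max-autotune with the CPP backend). It is not part of the PR itself; it reuses the module `M` from the example above, and the exact flag combination may differ on your build.
```
# Hedged sketch (not from the PR): enable the CPP Grouped GEMM template.
# Flags mirror the test decorators and the post-grad pass guard in this PR.
import torch
import torch._inductor.config as inductor_config

inductor_config.freezing = True
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "CPP,ATEN"
inductor_config.cpp.enable_grouped_gemm_template = True

with torch.no_grad():
    # M as defined in the example above; bias is ignored there
    m = M(bias=False).to(dtype=torch.bfloat16).eval()
    cm = torch.compile(m)
    out0, out1 = cm(torch.randn(4, 512, dtype=torch.bfloat16))
```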
This commit is contained in:
parent 35b46a75f1
commit 25de671ea8
@@ -207,7 +207,12 @@ if RUN_CPU:
        *[
            BaseTest(func, "", test_cpu_select_algorithm.TestSelectAlgorithmCPU())
            for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU())
-           if func.startswith("test_linear_with_pointwise")
+           if func.startswith(
+               (
+                   "test_linear_with_pointwise",
+                   "test_grouped_linear",
+               )
+           )
        ],
        BaseTest("test_polar"),
        BaseTest(
@@ -1683,6 +1683,105 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
            self.assertEqual(actual, expected, atol=atol, rtol=rtol)
            self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)

    @inductor_config.patch({"freezing": True})
    @inductor_config.patch({"cpp.enable_grouped_gemm_template": True})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("batch_size", (16,))
    @parametrize("in_features", (52,))
    @parametrize("out_features", (32,))
    @parametrize("gemm_num", (2, 3))
    def test_grouped_linear_invalid(
        self,
        batch_size,
        in_features,
        out_features,
        gemm_num,
    ):
        class M(torch.nn.Module):
            def __init__(self, in_feature, out_feature, gemm_num):
                super().__init__()
                self.linears = [
                    torch.nn.Linear(in_feature, out_feature + gemm_idx, bias=False)
                    for gemm_idx in range(gemm_num)
                ]

            def forward(self, x):
                return [linear(x) for linear in self.linears]

        # each linear has a different number of out features, thus an invalid grouped gemm
        dtypes = []
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append(torch.bfloat16)
        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
            dtypes.append(torch.float16)
        for dtype in dtypes:
            torch._dynamo.reset()
            torch._inductor.metrics.reset()
            counters.clear()
            mod = M(in_features, out_features, gemm_num).eval()
            v = torch.randn(batch_size, in_features).to(dtype)
            with verify(dtype) as (atol, rtol), torch.autocast(
                device_type="cpu", dtype=dtype
            ), torch.no_grad():
                self.common(mod, (v,), atol=atol, rtol=rtol)
                # gemm_num independent templates instead of the grouped gemm template
                self.assertEqual(
                    counters["inductor"]["select_algorithm_autotune"], gemm_num
                )
                self.assertEqual(counters["inductor"]["cpp_grouped_gemm_template"], 0)

    @inductor_config.patch({"freezing": True})
    @inductor_config.patch({"cpp.enable_grouped_gemm_template": True})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("batch_size", (16,))
    @parametrize("in_features", (52,))
    @parametrize("out_features", (32,))
    @parametrize("input_3d", (False, True))
    @parametrize("gemm_num", (2, 3))
    def test_grouped_linear(
        self,
        batch_size,
        in_features,
        out_features,
        input_3d,
        gemm_num,
    ):
        class M(torch.nn.Module):
            def __init__(self, in_feature, out_feature, gemm_num):
                super().__init__()
                self.linears = [
                    torch.nn.Linear(in_feature, out_feature, bias=False)
                    for _ in range(gemm_num)
                ]

            def forward(self, x):
                return [linear(x) for linear in self.linears]

        dtypes = []
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append(torch.bfloat16)
        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
            dtypes.append(torch.float16)
        for dtype in dtypes:
            if dtype == torch.float16 and input_3d:
                # reduce the number of tests
                continue
            torch._dynamo.reset()
            torch._inductor.metrics.reset()
            counters.clear()
            mod = M(in_features, out_features, gemm_num).eval()
            B = (2, batch_size) if input_3d else (batch_size,)
            v = torch.randn(*B, in_features).to(dtype)
            with verify(dtype) as (atol, rtol), torch.autocast(
                device_type="cpu", dtype=dtype
            ), torch.no_grad():
                self.common(mod, (v,), atol=atol, rtol=rtol)
                self.assertEqual(counters["inductor"]["cpp_grouped_gemm_template"], 1)

    @inductor_config.patch({"freezing": False})
    @patches
    @torch.no_grad
@@ -2031,6 +2130,7 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase):
    test_quantized_linear_amx_dynamic_shapes = (
        TestSelectAlgorithm.test_quantized_linear_amx
    )
    test_grouped_linear_dynamic_shapes = TestSelectAlgorithm.test_grouped_linear
    test_linear_k_slicing_dynamic_shapes = TestSelectAlgorithm.test_linear_k_slicing
    test_linear_cache_blocking_dynamic_shapes = (
        TestSelectAlgorithm.test_linear_cache_blocking
@@ -500,7 +500,13 @@ class BenchmarkRequest:
        self.input_tensor_meta = input_tensor_meta

        if isinstance(output_tensor_meta, (tuple, list)):
-           assert len(output_tensor_meta) == 1
+           if len(output_tensor_meta) > 1:
+               # Each output with same meta for Grouped GEMM
+               assert all(
+                   getattr(output_tensor_meta[0], attr) == getattr(x, attr)
+                   for x in output_tensor_meta
+                   for attr in ["device", "dtype", "sizes", "strides", "offset"]
+               )
            output_tensor_meta = output_tensor_meta[0]
        self.output_tensor_meta = output_tensor_meta
@@ -26,6 +26,7 @@ from ..loop_body import LoopBody
from ..scheduler import (
    BaseSchedulerNode,
    BaseScheduling,
    ExternKernelSchedulerNode,
    ForeachKernelSchedulerNode,
    FusedSchedulerNode,
    Scheduler,
@@ -4905,7 +4906,24 @@ class CppScheduling(BaseScheduling):
            epilogue_nodes=epilogue_ir_nodes,
        )
        with kernel:
-           for node in [template_node, *epilogue_nodes]:
+           if isinstance(template_node.node, ir.CppTemplateBuffer) and isinstance(
+               template_node.node.layout, ir.MultiOutputLayout
+           ):
+               # For Grouped GEMM, allocate buffers for each GEMM
+               assert (
+                   len(template_node.outputs) == 1
+               ), "Grouped GEMM has 1 output template buffer"
+               for user in template_node.outputs[0].users:
+                   assert isinstance(
+                       user.node, ExternKernelSchedulerNode
+                   ), "Grouped GEMM should be with ExternKernelSchedulerNode"
+                   assert isinstance(
+                       user.node.node, ir.MultiOutput
+                   ), "Grouped GEMM has multi users with MultiOutput"
+                   user.node.mark_run()
+           else:
+               template_node.mark_run()  # type: ignore[attr-defined]
+           for node in epilogue_nodes:
                node.mark_run()  # type: ignore[attr-defined]
            src_code = render()
@@ -3,7 +3,7 @@ import contextlib
import logging
import math
from functools import lru_cache
-from typing import Any, Callable, cast, Dict, List, Optional, Union
+from typing import Any, Callable, cast, Dict, List, Optional, TypeVar, Union
from unittest.mock import patch

import torch
@@ -299,9 +299,10 @@ def get_padded_n(n, block_n):
    return (n + block_n - 1) // block_n * block_n


-def transpose_w(
-    W: Union[ir.IRNode, torch.Tensor], trans_w: bool
-) -> Union[ir.IRNode, torch.Tensor]:
+_T = TypeVar("_T", ir.IRNode, torch.Tensor)
+
+
+def transpose_w(W: _T, trans_w: bool) -> _T:
    """
    Transpose W based on the trans_w flag.
    """
@@ -317,9 +318,7 @@ def transpose_w(
    return W


-def expand_bias(
-    B: Union[ir.IRNode, torch.Tensor, None], X: Union[ir.IRNode, torch.Tensor]
-) -> Optional[Union[ir.IRNode, torch.Tensor]]:
+def expand_bias(B: Optional[_T], X: _T) -> Optional[_T]:
    """
    Expand Bias to the same size of X.
    """
@@ -336,7 +335,7 @@ def expand_bias(
    return B


-def prune_tensors(input_nodes: List[ir.TensorBox], new_input_nodes: List[ir.TensorBox]):
+def prune_tensors(input_nodes: List[ir.IRNode], new_input_nodes: List[ir.IRNode]):
    """
    Prune unused tensors from `V.graph` since the GEMM Template use new packed weight.
    """
@@ -798,6 +797,7 @@ class CppGemmTemplate(CppTemplate):
        trans_w=False,
        input_indices=None,
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
        act_mapping: Optional[dict[int, ir.IRNode]] = None,
    ):
        if input_indices is None:
            input_indices = list(range(len(input_nodes)))
@@ -1251,6 +1251,7 @@ class CppGemmTemplate(CppTemplate):
        # --> zero or more out-of-template epilogues (`epilogue_nodes`) -->
        # Y
        if epilogue_creators:
            assert isinstance(template_buffer, ir.IRNode)
            gemm_output_name = f"{template_buffer.get_name()}_GemmOut"
            gemm_output_buffer = ir.Buffer(
                name=gemm_output_name, layout=template_buffer.layout
@@ -1276,14 +1277,17 @@ class CppGemmTemplate(CppTemplate):
                name=buffer_name, layout=template_buffer.layout
            )

        assert isinstance(Y, (ir.Buffer, ir.ReinterpretView))
        Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y

        if epilogue_nodes:
            if not template_buffer_has_other_users:
                assert isinstance(template_buffer, ir.IRNode)
                Y_aliases.add(template_buffer.get_name())
            epilogues.extend(epilogue_nodes)
            assert Y.get_numel() == epilogues[-1].get_numel()
            Y = cast(ir.Buffer, epilogues[-1])
            assert isinstance(template_buffer, ir.Buffer)
            Y_2d, reindexers = gen_2d_view_of_epilogue_buf(
                Y,
                template_buffer,
463 torch/_inductor/codegen/cpp_grouped_gemm_template.py (new file)
@@ -0,0 +1,463 @@
import contextlib
import logging
from typing import Any, Callable, List, Optional, TypeVar
from unittest.mock import patch

import torch
import torch.utils
from torch.utils._ordered_set import OrderedSet

from ..._dynamo.utils import counters
from .. import config, ir
from ..kernel.mm_common import mm_args
from ..select_algorithm import ChoiceCaller, DataProcessorTemplateWrapper
from ..utils import parallel_num_threads
from ..virtualized import V
from .cpp import get_export_declaration
from .cpp_gemm_template import CppGemmTemplate, expand_bias, prune_tensors, transpose_w
from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm
from .cpp_template_kernel import CppTemplateKernel
from .cpp_utils import (
    DTYPE_TO_CPP,
    GemmBlocking,
    get_gemm_template_output_and_compute_dtype,
)


log = logging.getLogger(__name__)

GEMM_TEMPLATE = r"""
{{template.header().getvalue()}}
{{micro_gemm.codegen_define(kernel)}}

extern "C" {{export_declaration}}
{{kernel.def_kernel(inputs=kernel_args, outputs=Y_list, aliases=aliases)}}
{
    {{kernel.maybe_codegen_profile()}}
    {{ template.codegen_blocks(
        num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOuts[0], config, L1_cache_size, L2_cache_size, X_list[0], W_list[0]
    ) }}
    {%- if num_threads > 1 %}
    #pragma omp parallel num_threads({{num_threads}})
    {
        {{ template.codegen_multi_threads_params()|indent(8, false) }}
    {%- else %}
    {
        {{ template.codegen_single_thread_params(is_dynamic_M)|indent(8, false) }}
    {%- endif %}
        {{ micro_gemm.codegen_init(kernel) }}
        {%- set acc_buf_name_list=[] %}
        {%- set acc_buf_name_prefix = "local_acc_buf_" %}
        {%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
        {%- set acc_buf_name = acc_buf_name_prefix + gemm_idx|string %}
        {{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }}
        {%- set acc_buf_name_list=acc_buf_name_list.append(acc_buf_name) %}
        {%- endfor %}
        for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
            {{ template.codegen_m_loop_params()|indent(12, false) }}
            for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
                {{ template.codegen_n_loop_params()|indent(16, false) }}
                {%- set acc_list=[] %}
                {%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
                {%- set acc_list = acc_list.append( kernel.local_buffers[acc_buf_name_list[gemm_idx]] ) %}
                {{ kernel.reinit_buffer_if_null(acc_buf_name_list[gemm_idx]) }}
                {%- endfor %}
                for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
                    int64_t k_start = kc * Kr;
                    int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
                    {%- set tile_X_list=[] %}
                    {%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
                    {%- set tile_X_list = tile_X_list.append( kernel.slice_nd(X_list[gemm_idx], [("m_start", "m_end"), ("k_start", "k_end")]) ) %}
                    {%- endfor %}
                    for (int64_t nci = nc; nci < nc_block_end; nci++) {
                        {%- set tile_W_3d_list=[] %}
                        {%- set tile_W_list=[] %}
                        {%- set acc_slice_list=[] %}
                        {%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
                        {%- set acc_slice_list = acc_slice_list.append(
                            kernel.slice_nd(acc_list[gemm_idx], [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")])
                        ) %}
                        {%- set tile_W_3d_list = tile_W_3d_list.append(
                            kernel.slice_nd(W_list[gemm_idx], [("nci", "nci + 1"), ("k_start", "k_end"), ()])
                        ) %}
                        {%- endfor %}
                        {%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
                        {%- set tile_W_list = tile_W_list.append(
                            kernel.view(tile_W_3d_list[gemm_idx], ["k_end - k_start", micro_gemm.register_blocking.block_n])
                        ) %}
                        {%- endfor %}
                        if (kc == k_block_start) {
                            {%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
                            {{ micro_gemm.codegen_call(
                                kernel, tile_X_list[gemm_idx], tile_W_list[gemm_idx], acc_slice_list[gemm_idx], accum=False
                            )|indent(28, false) }}
                            {%- endfor %}
                        } else {
                            {%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
                            {{ micro_gemm.codegen_call(
                                kernel, tile_X_list[gemm_idx], tile_W_list[gemm_idx], acc_slice_list[gemm_idx], accum=True
                            )|indent(28, false) }}
                            {%- endfor %}
                        }
                    }
                }
                {
                    {%- set tile_acc_list = [] %}
                    {%- set tile_Y_list = [] %}
                    {%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
                    {%- set tile_acc_list = tile_acc_list.append(
                        kernel.slice_nd(acc_list[gemm_idx], [("0", "m_end - m_start"), ("0", "n_end - n_start")])
                    ) %}
                    {%- set tile_Y_list = tile_Y_list.append(
                        kernel.slice_nd(Y_2d_list[gemm_idx], [("m_start", "m_end"), ("n_start", "n_end")])
                    ) %}
                    {%- endfor %}
                    {{ kernel.store_outputs(
                        tile_Y_list, tile_acc_list, GemmOuts, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
                    )|indent(20, false)
                    }}
                }
            }
        }
        {{ micro_gemm.codegen_finalize(kernel) }}
    }
}
"""


def get_deduplicated_act(act_mapping: dict[int, ir.IRNode]) -> List[ir.IRNode]:
    act_deduplicated = []
    act_deduplicated_name: OrderedSet[str] = OrderedSet()
    for act_idx in range(len(act_mapping.values())):
        act = act_mapping[act_idx]
        if act.get_name() not in act_deduplicated_name:
            act_deduplicated.append(act)
            act_deduplicated_name.add(act.get_name())
    return act_deduplicated


class CppGroupedGemmTemplate(CppGemmTemplate):
    def __init__(
        self,
        input_nodes: List[ir.IRNode],
        layout: ir.Layout,
        num_threads: int,
        register_blocking: GemmBlocking,
        beta: int = 1,
        alpha: int = 1,
        has_bias: bool = False,
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
        act_mapping: Optional[dict[int, ir.IRNode]] = None,
        gemm_grouped_num: int = 1,
    ) -> None:
        """
        Template for Group of GEMMs:
        * Each GEMM has the same dimensions (m, n, k) and the same leading dimensions (lda, ldb, ldc)
          for their A, B, and C matrices.
        * Each GEMM has distinct or shared activations, has distinct weight, has unique bias or no bias, has distinct epilogues.
        * In the current implementation, the outputs of all GEMMs are accumulated using pointwise epilogues.
          This behavior can be extended in the future if needed.
        """
        super().__init__(
            input_nodes,
            layout,
            num_threads,
            register_blocking,
            beta,
            alpha,
            has_bias,
            epilogue_creator,
        )
        self.act_mapping = act_mapping
        self.gemm_grouped_num = gemm_grouped_num
        self.output_node: List[ir.Buffer] = [
            ir.Buffer(name="buf_out" + str(idx), layout=layout)
            for idx in range(gemm_grouped_num)
        ]

    @staticmethod
    def _fake_get_dtype(fake_outs: List[ir.Buffer]) -> Callable[[str], torch.dtype]:
        _get_dtype_real = V.graph.get_dtype

        def get_dtype(name: str) -> torch.dtype:
            for fake_out in fake_outs:
                if name == fake_out.get_name():
                    return fake_out.get_dtype()
            return _get_dtype_real(name)

        return get_dtype

    @classmethod
    def add_choices(
        cls,
        choices: List[ChoiceCaller],
        layout: ir.Layout,
        input_nodes: List[ir.IRNode],
        beta: int = 1,
        alpha: int = 1,
        has_bias: tuple[bool, ...] = (False, False),
        trans_w: bool = False,
        input_indices: Optional[List[int]] = None,
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
        act_mapping: Optional[dict[int, ir.IRNode]] = None,  # gemm idx to its act buf
    ) -> DataProcessorTemplateWrapper:
        # Input nodes order: x, optional[x1], ... w0, w1, ... optional[b0], optional[b1], ...
        gemm_grouped_num = len(has_bias)
        assert act_mapping
        act_deduplicated = get_deduplicated_act(act_mapping)
        wgt_start_idx = len(act_deduplicated)
        bias_start_idx = wgt_start_idx + gemm_grouped_num
        input_indices = list(range(len(input_nodes)))

        _T = TypeVar("_T", ir.IRNode, torch.Tensor)
        _U = TypeVar("_U", ir.Layout, torch.Tensor)

        def reorder_and_filter(
            inputs: List[_T],
            layout_or_out: _U,
        ) -> tuple[List[_T], _U]:
            assert input_indices is not None, "input_indices must be set"
            return [inputs[idx] for idx in input_indices], layout_or_out

        new_inputs, new_layout = reorder_and_filter(input_nodes, layout)

        def maybe_to_dense(
            inputs: List[_T],
            layout_or_out: _U,
        ) -> tuple[List[_T], _U]:
            new_inputs = list(inputs)
            for idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
                if isinstance(inputs[idx], torch.Tensor):
                    W = inputs[idx]
                    assert isinstance(W, torch.Tensor), "W must be a torch.Tensor"
                    new_inputs[idx] = W.to_dense() if W.is_mkldnn else W
            return new_inputs, layout_or_out

        def normalize_shapes(
            inputs: List[_T],
            layout_or_out: _U,
        ) -> tuple[List[_T], _U]:
            new_inputs: List[_T] = list(inputs)
            if not trans_w:
                return new_inputs, layout_or_out
            X = new_inputs[0]
            for wgt_idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
                new_input = new_inputs[wgt_idx]
                new_inputs[wgt_idx] = transpose_w(new_input, trans_w)
            for bias_idx in range(bias_start_idx, len(new_inputs)):
                new_bias = expand_bias(new_inputs[bias_idx], X)
                assert new_bias is not None
                new_inputs[bias_idx] = new_bias
            return new_inputs, layout_or_out

        num_threads = parallel_num_threads()
        new_inputs, _ = normalize_shapes(*maybe_to_dense(new_inputs, new_layout))
        m, n, k, *_ = mm_args(new_inputs[0], new_inputs[wgt_start_idx])
        output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
            new_inputs[0].get_dtype()
        )
        micro_gemm = create_micro_gemm(
            "micro_gemm",
            m,
            n,
            k,
            input_dtype=new_inputs[0].get_dtype(),
            input2_dtype=new_inputs[wgt_start_idx].get_dtype(),
            output_dtype=output_dtype,
            compute_dtype=compute_dtype,
            alpha=alpha,
            num_threads=num_threads,
        )
        assert micro_gemm is not None
        _, block_n, _ = micro_gemm.register_blocking
        new_size, padded_n = cls.get_padded_size(
            n, block_n, k, should_block_weight=True
        )
        padding = padded_n - n

        def pack_weight(
            inputs: List[_T],
            layout_or_out: _U,
        ) -> tuple[List[_T], _U]:
            new_W_list = []
            new_inputs = list(inputs)
            W_list = new_inputs[wgt_start_idx : wgt_start_idx + gemm_grouped_num]
            for W in W_list:
                blocked_w = cls.block_weight(W, new_size, padding)
                new_W_list.append(cls.pack_vnni_weight(blocked_w, micro_gemm, new_size))
            new_inputs[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = new_W_list
            return new_inputs, layout_or_out

        def preprocessor(
            inputs: List[_T],
            layout: _U,
        ) -> tuple[List[_T], _U]:
            return pack_weight(
                *normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout)))
            )

        def postprocessor(output: _T) -> _T:
            if isinstance(output, ir.TensorBox):
                template_buffer = ir.InputsKernel.unwrap_storage_for_input(output)
                assert isinstance(template_buffer, ir.CppTemplateBuffer)
                new_input_nodes, _ = reorder_and_filter(input_nodes, layout)
                W_nodes = new_input_nodes[
                    wgt_start_idx : wgt_start_idx + gemm_grouped_num
                ]
                W_tensor = []
                for W_node in W_nodes:
                    assert W_node.get_name() in V.graph.constants
                    W_tensor.append(V.graph.constants[W_node.get_name()])
                new_input_nodes[
                    wgt_start_idx : wgt_start_idx + gemm_grouped_num
                ] = W_tensor  # type: ignore[assignment]
                new_input_nodes, _ = pack_weight(
                    *normalize_shapes(*maybe_to_dense(new_input_nodes, layout))
                )
                # Prune unused tensors
                prune_tensors(input_nodes, new_input_nodes)
                for idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
                    W_packed = new_input_nodes[idx]
                    assert isinstance(W_packed, torch.Tensor)
                    W_packed_constant = V.graph.add_tensor_constant(W_packed)
                    template_buffer.inputs[
                        idx
                    ] = ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
            return output

        template = DataProcessorTemplateWrapper(
            CppGroupedGemmTemplate,
            preprocessor,
            postprocessor,
            input_nodes=input_nodes,
            layout=layout,
            num_threads=num_threads,
            register_blocking=micro_gemm.register_blocking,
            beta=beta,
            alpha=alpha,
            has_bias=has_bias,
            epilogue_creator=epilogue_creator,
            act_mapping=act_mapping,
            gemm_grouped_num=gemm_grouped_num,
        )
        template.maybe_append_choice(choices)
        return template

    def render(  # type: ignore[override,return,no-untyped-def]
        self,
        kernel: CppTemplateKernel,
        template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
        flag_template_buffer_has_other_users: Optional[bool] = None,
        epilogue_nodes: Optional[List[ir.IRNode]] = None,
        **kwargs,
    ) -> str:
        assert self.act_mapping
        act_deduplicated = get_deduplicated_act(self.act_mapping)
        wgt_start_idx = len(act_deduplicated)
        bias_start_idx = wgt_start_idx + self.gemm_grouped_num
        X_list = list(self.act_mapping.values())
        W_list = self.input_nodes[wgt_start_idx : wgt_start_idx + self.gemm_grouped_num]
        inp_list = []
        cur_idx = bias_start_idx
        for inp_idx in range(self.gemm_grouped_num):
            inp = None
            if self.has_bias[inp_idx]:
                inp = self.input_nodes[cur_idx]
                cur_idx += 1
            inp_list.append(inp)

        Y_list = self.output_node
        if template_buffer_node is not None:
            W_list = template_buffer_node.inputs[
                wgt_start_idx : wgt_start_idx + self.gemm_grouped_num
            ]
            assert isinstance(template_buffer_node.outputs, List)
            Y_list = template_buffer_node.outputs
            counters["inductor"]["cpp_grouped_gemm_template"] += 1

        template_buffer = Y_list[0]
        fake_buffers: List[ir.Buffer] = []
        Y_2d_list = Y_list
        output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
            X_list[0].get_dtype()
        )
        micro_gemm = create_micro_gemm(
            f"{kernel.kernel_name}_micro_gemm",
            self.m,
            self.n,
            self.k,
            input_dtype=X_list[0].get_dtype(),
            input2_dtype=W_list[0].get_dtype(),
            output_dtype=output_dtype,
            compute_dtype=compute_dtype,
            alpha=self.alpha,
            num_threads=self.num_threads,
        )
        assert micro_gemm is not None
        assert self.register_blocking == micro_gemm.register_blocking
        self.log_blockings()
        if isinstance(micro_gemm, CppMicroGemmAMX):
            counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1

        L1_cache_size = torch._C._cpu._L1d_cache_size()  # per core cache size in Bytes
        assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}"

        L2_cache_size = torch._C._cpu._L2_cache_size()  # per core cache size in Bytes
        assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}"

        epilogues: List[ir.IRNode] = []
        reindexers: List[Optional[Callable[[List[Any]], List[Any]]]] = []
        gemm_output_buffers: list[ir.Buffer] = []
        for out_buf_idx in range(self.gemm_grouped_num):
            gemm_output_name = f"{template_buffer.get_name()}_GemmOut" + str(
                out_buf_idx
            )
            gemm_output_buffers.append(
                ir.Buffer(name=gemm_output_name, layout=template_buffer.layout)
            )

        assert (
            not self.epilogue_creator and not epilogue_nodes
        ), "Epilogue fusion is not implemented yet in Grouped GEMM Template"

        kernel_args: dict[str, Optional[ir.IRNode]] = {}
        for x_idx in range(wgt_start_idx):
            kernel_args["X" + str(x_idx)] = act_deduplicated[x_idx]
        for w_idx in range(self.gemm_grouped_num):
            kernel_args["W" + str(w_idx)] = W_list[w_idx]
        for inp_idx in range(self.gemm_grouped_num):
            kernel_args["inp" + str(inp_idx)] = inp_list[inp_idx]

        options = dict(
            N=self.n,
            K=self.k,
            PADDED_N=self.padded_n,
            aliases={},
            beta=self.beta,
            alpha=self.alpha,
            num_threads=self.num_threads,
            micro_gemm=micro_gemm,
            is_dynamic_M=self.is_dynamic_M,
            template=self,
            kernel=kernel,
            export_declaration=get_export_declaration(),
            acc_buf_dtype=torch.float,
            DTYPE_TO_CPP=DTYPE_TO_CPP,
            L1_cache_size=L1_cache_size,
            L2_cache_size=L2_cache_size,
            config=config,
            epilogue_nodes=epilogues,
            GemmOuts=gemm_output_buffers,
            reindexers=reindexers,
            kernel_args=kernel_args,
            X_list=X_list,
            W_list=W_list,
            gemm_grouped_num=self.gemm_grouped_num,
            Y_list={"Y" + str(idx): Y for idx, Y in enumerate(Y_list)},
            Y_2d_list=Y_2d_list,
        )
        with contextlib.ExitStack() as stack:
            stack.enter_context(
                patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_buffers))
            )
            return self._template_from_string(GEMM_TEMPLATE).render(**options)
@@ -4,7 +4,7 @@ import functools
import itertools
import logging
import sys
-from typing import Callable, List, Optional
+from typing import Callable, Iterable, List, Optional, Union
from unittest.mock import patch

import sympy
@@ -33,7 +33,9 @@ class CppTemplate(KernelTemplate):
    ) -> None:
        super().__init__(name)
        self.input_nodes = input_nodes
-       self.output_node: ir.Buffer = ir.Buffer(name="buf_out", layout=layout)
+       self.output_node: Union[ir.Buffer, List[ir.Buffer]] = ir.Buffer(
+           name="buf_out", layout=layout
+       )
        self.layout = layout
        self.num_threads = num_threads
        self.epilogue_creator = epilogue_creator
@@ -57,7 +59,10 @@ class CppTemplate(KernelTemplate):
        expected_args = list(
            unique(input_node.get_name() for input_node in self.input_nodes)
        )
-       expected_args.extend([self.output_node.get_name()])
+       if isinstance(self.output_node, Iterable):
+           expected_args.extend([node.get_name() for node in self.output_node])
+       else:
+           expected_args.extend([self.output_node.get_name()])
        assert list(call_args)[: len(expected_args)] == expected_args, (
            call_args,
            expected_args,
@@ -102,7 +107,9 @@ class CppTemplate(KernelTemplate):
            kernel_hash_name,
            self.name,
            self.input_nodes,
-           self.output_node.get_layout(),
+           self.output_node[0].get_layout()
+           if isinstance(self.output_node, Iterable)
+           else self.output_node.get_layout(),
            make_kernel_render,
            bmreq,
            self,
@@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import itertools
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union

import sympy
from sympy.parsing.sympy_parser import parse_expr
@@ -278,6 +278,82 @@ class CppTemplateKernel(CppKernel):
        kernel_group.finalize_kernel(cpp_kernel_proxy, [])
        return kernel_group.loops_code.getvalue()

    def store_grouped_gemm_pointwise_nodes(
        self,
        dst: tuple[ir.Buffer],
        nodes: List[List[ir.IRNode]],
        offsets: Optional[List[sympy.Expr]] = None,
        reindexers: Optional[List[Optional[Callable[[List[Any]], List[Any]]]]] = None,
    ) -> str:
        ref_dst = dst[0]
        var_sizes = (tuple(ref_dst.get_size()), ())
        var_ranges = {
            sympy_index_symbol_with_prefix(SymT.INDEX, i): sz
            for i, sz in enumerate(var_sizes[0])
        }
        if not offsets:
            offsets = [sympy.S.Zero] * len(var_sizes[0])
        if not reindexers:
            reindexers = [None] * len(nodes)
        assert len(offsets) == len(var_sizes[0])
        output_index = ref_dst.get_layout().make_indexer()([*var_ranges.keys()])
        kernel_group = KernelGroup()
        kernel_group.args = self.args
        cpp_kernel_proxy = CppKernelProxy(kernel_group)
        bodies = []
        var_sizes_list = []
        assert isinstance(nodes[0], Iterable)
        grouped_gemm_number = len(nodes)
        epilogue_nodes = nodes[0]
        assert isinstance(epilogue_nodes, Iterable)
        for i, _ in enumerate(epilogue_nodes):
            output_names = []
            gemm_nodes = []
            for gemm_idx in range(grouped_gemm_number):
                single_gemm_nodes = nodes[gemm_idx]
                assert isinstance(dst, Iterable)
                single_gemm_dst = dst[gemm_idx]
                assert isinstance(single_gemm_nodes, Iterable)
                assert isinstance(single_gemm_dst, ir.IRNode)
                gemm_nodes.append(single_gemm_nodes[i])
                output_names.append(
                    single_gemm_nodes[i].get_name()
                    if i < len(single_gemm_nodes) - 1
                    else single_gemm_dst.get_name()
                )
                _node = gemm_nodes[gemm_idx]
                gemm_nodes[gemm_idx] = (
                    _node.data if isinstance(_node, ir.ComputedBuffer) else _node
                )

            def fn(*args):
                assert len(args) == 2
                assert len(args[0]) == len(var_sizes[0])
                assert len(args[1]) == 0
                new_args = [arg + offset for arg, offset in zip(args[0], offsets)]  # type: ignore[arg-type]
                if reindexers[i] is not None:
                    new_args = reindexers[i](new_args)  # type: ignore[misc]
                for gemm_idx in range(grouped_gemm_number):
                    V.ops.store(
                        output_names[gemm_idx],
                        output_index,
                        gemm_nodes[gemm_idx].make_loader()(new_args).value,
                    )

            body = LoopBody(
                fn,
                (list(var_ranges.keys()), ()),
                var_ranges,
                list(var_ranges.keys()),
                tuple(),
            )
            bodies.append(body)
            var_sizes_list.append(var_sizes)

        cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list)
        kernel_group.finalize_kernel(cpp_kernel_proxy, [])
        return kernel_group.loops_code.getvalue()

    def store_output(
        self,
        dst: ir.Buffer,
@@ -335,6 +411,43 @@ class CppTemplateKernel(CppKernel):
        assert dst.layout == src.layout, f"{dst=}, {src=}"
        return ""

    def store_outputs(
        self,
        dst: tuple[ir.Buffer],
        src: tuple[ir.IRNode],
        orig_src: Optional[tuple[ir.IRNode]] = None,
        epilogue_nodes: Optional[List[ir.IRNode]] = None,
        offsets: Optional[List[Any]] = None,
        reindexers: Optional[List[Optional[Callable[[List[Any]], List[Any]]]]] = None,
    ):
        # Grouped GEMM may have multi outputs to be localized
        assert isinstance(src, Iterable)
        assert isinstance(dst, Iterable)
        assert all(_dst.get_size() == _src.get_size() for _src, _dst in zip(src, dst))
        if offsets:
            offsets = parse_expr_with_index_symbols(offsets)
        if epilogue_nodes:
            assert (
                not epilogue_nodes
            ), "epilogue_nodes not supported for Grouped GEMM yet"
        else:
            if dst[0].get_name() != src[0].get_name():
                copy_list = []
                with LocalBufferContext(self.args) as scope:
                    for _src, _dst in zip(src, dst):
                        copy_list.append([L.copy(_dst, _src).data.data])
                        scope.add_local_buffer(_src)
                    return self.store_grouped_gemm_pointwise_nodes(dst, copy_list)
            else:
                assert all(
                    _src.get_name() == _dst.get_name() for _src, _dst in zip(src, dst)
                )
                assert all(
                    _src.get_layout() == _dst.get_layout()
                    for _src, _dst in zip(src, dst)
                )
                return ""


class CppTemplateCaller(ir.ChoiceCaller):
    """
@@ -2193,9 +2193,12 @@ class PythonWrapperCodegen(CodeGen):
        ):
            return
        self.allocated.add(name)
-       if isinstance(
-           buffer.get_defining_op(),
-           (ir.ExternKernelAlloc, ir.MultiOutput),
+       if (
+           isinstance(
+               buffer.get_defining_op(),
+               (ir.ExternKernelAlloc, ir.MultiOutput),
+           )
+           and not buffer.should_allocate()
        ):
            return
@@ -884,6 +884,9 @@ class cpp:
        os.environ.get("TORCHINDUCTOR_CPP_ENABLE_TILING_HEURISTIC", "1") == "1"
    )

    # Enable the Grouped GEMM Fusion
    enable_grouped_gemm_template = False

    # Maximal allowed number of slices on K-dim for a GEMM kernel. This controls
    # the maximal parallelism of K-slicing. Since K-slicing requires extra thread
    # synchronization and buffers, the maximal number of slices is limited to
@@ -38,6 +38,75 @@ if torch._C._has_mkldnn:
    _linear_args = [Arg() for _ in range(6)]
    _conv_transpose_args = [Arg() for _ in range(11)]

    def _is_valid_grouped_gemm_fusion(computation_nodes):
        """
        Here we check that:
        1. More than one GEMM node has been found.
        2. All the GEMM nodes share the same activation.
        3. All the GEMM nodes have the same weight size but different weight nodes.
        """
        computation_op = mkldnn._linear_pointwise.default
        act = computation_nodes[0].args[0]
        wgt = computation_nodes[0].args[1]
        wgt_size = wgt.meta.get("val").size()  # type: ignore[union-attr]
        return len(computation_nodes) >= 2 and all(
            (
                node.target == computation_op
                and node.args[0] == act
                and (node.args[1].meta.get("val").size() == wgt_size)
                and (node.args[1] != wgt or gemm_idx == 0)
                and not node.args[2]  # <TODO> support bias through epilogue fusion
            )
            for gemm_idx, node in enumerate(computation_nodes)
        )

    def grouped_gemm_pass(graph: torch.fx.Graph):
        """
        Grouped GEMM has multiple output nodes, which is complicated to define as a pattern.
        Use the approach below to connect the pattern to the lowering.
        TODO: Use MultiOutputPattern; the current limitation is that the pattern requires a
        fixed number of output nodes. Extend it to support Grouped GEMM in the pattern matcher.
        """
        computation_op = mkldnn._linear_pointwise.default
        from ..mkldnn_lowerings import grouped_gemm_lowering

        for node in graph.find_nodes(op="call_function", target=computation_op):
            if (
                not node._erased
                and isinstance(node.meta.get("val"), torch.Tensor)
                and node.meta["val"].device.type == "cpu"
            ):
                act = node.args[0]
                users = list(act.users)
                if _is_valid_grouped_gemm_fusion(users):
                    with graph.inserting_before(node):
                        grouped_gemm_node = graph.create_node(
                            "call_function",
                            grouped_gemm_lowering,
                            (
                                act,
                                [user.args[1] for user in users],
                                [None for _ in users],
                            ),
                        )
                        grouped_gemm_node.meta["val"] = [
                            user.meta["val"] for user in users
                        ]
                        with graph.inserting_after(grouped_gemm_node):
                            for gemm_idx, user in enumerate(users):
                                assert user.target == computation_op
                                get_item = graph.create_node(
                                    "call_function",
                                    operator.getitem,
                                    (
                                        grouped_gemm_node,
                                        gemm_idx,
                                    ),
                                )
                                user.replace_all_uses_with(get_item)
                                graph.erase_node(user)
        return

    def _conv_call(users=1):
        return CallFunction(
            mkldnn._convolution_pointwise.default, *_conv_args, _users=users
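For illustration only (not part of the diff), the rewrite performed by `grouped_gemm_pass` above can be sketched as follows; node and variable names here are hypothetical.
```
# Sketch of the FX rewrite done by grouped_gemm_pass (hypothetical node names).
#
# Before: one shared activation feeds several mkldnn._linear_pointwise nodes
# with equally sized weights and no bias:
#   y0 = mkldnn._linear_pointwise(act, w0, None, ...)
#   y1 = mkldnn._linear_pointwise(act, w1, None, ...)
#
# After: a single grouped node plus one getitem per GEMM, so downstream users
# are preserved while the lowering sees one fused op:
#   grouped = grouped_gemm_lowering(act, [w0, w1], [None, None])
#   y0 = operator.getitem(grouped, 0)
#   y1 = operator.getitem(grouped, 1)
```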
@@ -102,6 +102,16 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
            post_grad_custom_pre_pass
        )

    if (
        config.cpp.enable_grouped_gemm_template
        and config.max_autotune
        and "CPP" in config.max_autotune_gemm_backends
        and torch._C._has_mkldnn
    ):
        from .mkldnn_fusion import grouped_gemm_pass

        grouped_gemm_pass(gm.graph)

    if config.pattern_matcher:
        lazy_init()
        optimus_scuba_log["before_recompile_post_grad"] = upload_graph(gm.graph)
@@ -4496,6 +4496,18 @@ class CppTemplateBuffer(TemplateBuffer):
        super().__init__(layout, inputs, make_kernel_render)
        self.template = template
        self.choice = choice
        self.outputs: Optional[List[Buffer]] = None

    def get_layout(self) -> Layout:
        if isinstance(self.layout, MultiOutputLayout):
            assert isinstance(self.outputs, Iterable)
            first_output = self.outputs[0]
            assert isinstance(first_output, Buffer)
            layout = first_output.layout
            assert isinstance(layout, Layout)
            return layout
        else:
            return super().get_layout()


@ir_dataclass(frozen=False)
@@ -6832,6 +6844,10 @@ class MultiOutput(ExternKernel):
        return self.inputs[0].get_unbacked_symbol_uses()

    def should_allocate(self) -> bool:
        if len(self.inputs) == 1 and (
            isinstance(self.inputs[0], CppTemplateBuffer)  # Grouped GEMM
        ):
            return True
        return False

    def get_inputs_that_alias_output(self) -> Sequence[str]:
@@ -8,6 +8,7 @@ from torch._inductor.kernel.mm_common import mm_args

from . import ir
from .codegen.cpp_gemm_template import CppGemmTemplate
from .codegen.cpp_grouped_gemm_template import CppGroupedGemmTemplate
from .codegen.cpp_utils import create_epilogue_with_attr
from .ir import TensorBox
from .lowering import (
@@ -28,6 +29,73 @@ from .utils import use_aten_gemm_kernels, use_cpp_gemm_template, use_max_autotun
from .virtualized import ops, V


def grouped_gemm_lowering(
    x: TensorBox,
    w: List[TensorBox],
    b: List[TensorBox],
    attr=None,
    scalars=None,
    algorithm=None,
    layout=None,
):
    x_size = x.get_size()
    if len(x_size) > 2:
        # GEMM template needs 2D input, normalize input shape here
        x = view(x, [-1, x_size[-1]])
    num_gemm = len(w)

    assert use_max_autotune()
    b = [bias if bias is None else ir.ExternKernel.realize_input(bias) for bias in b]

    choices: List[ChoiceCaller] = []
    *_, layout, x, _ = mm_args(x, permute(w[0], [1, 0]), layout=layout)

    kwargs = dict(
        has_bias=[bias is not None for bias in b],
        trans_w=True,
        epilogue_creator=None,
        act_mapping={num: x for num in range(num_gemm)},
    )

    input_nodes = [x, *w]
    input_nodes.extend([bias for bias in b if bias is not None])

    CppGroupedGemmTemplate.add_choices(
        choices,
        layout,
        input_nodes,
        **kwargs,  # type: ignore[arg-type]
    )

    assert len(choices) != 0
    result = autotune_select_algorithm(
        "grouped_gemm",
        choices,
        input_nodes,
        layout,
    )
    template_buf = result.data.data
    return_bufs = [
        ir.MultiOutput(layout, template_buf, [(list, gemm_idx)])
        for gemm_idx in range(num_gemm)
    ]
    template_buf.layout = ir.MultiOutputLayout(device=input_nodes[0].get_device())
    template_buf.outputs = return_bufs
    return_tensors = [
        ir.TensorBox.create(return_bufs[gemm_idx]) for gemm_idx in range(num_gemm)
    ]
    if len(x_size) > 2:
        for gemm_idx in range(num_gemm):
            return_tensors[gemm_idx] = view(
                return_tensors[gemm_idx],
                (*x_size[:-1], return_tensors[gemm_idx].get_size()[-1]),
            )
    return return_tensors


grouped_gemm_lowering._inductor_lowering_function = True  # type: ignore[attr-defined]


def register_onednn_fusion_ops():
    if torch._C._has_mkldnn:
        from . import mkldnn_ir
@@ -3607,6 +3607,15 @@ class Scheduler:
        # the current kernel from where 'allocate' retrieve those decisions.
        # We have to make sure there is a non-NULL kernel handler to store
        # those inplace update decisions.

        if (
            isinstance(scheduler_node.node, ir.MultiOutput)
            and len(scheduler_node.node.inputs) == 1
            and isinstance(scheduler_node.node.inputs[0], ir.CppTemplateBuffer)
        ):
            # <TODO> Remove this code after Fuse MultiOutput and CppTemplateBuffer
            return

        counters["inductor"]["extern_calls"] += 1
        with V.set_kernel_handler(Kernel(increase_kernel_count=False)):
            scheduler_node.decide_inplace_update()