[Inductor][CPP] Enable Grouped GEMM Template (#143796)

**Summary**
Enable the CPP Grouped GEMM Fusion, lowering, and Grouped GEMM Template, following the RFC: https://github.com/pytorch/pytorch/issues/144012. The template is opt-in via inductor config flags; a configuration sketch follows the summary list below.

- Support a flexible number of GEMMs
- Share the activation across GEMMs
  - The Grouped GEMM Template itself supports independent activations
  - However, the pattern matcher requires an anchor node, which serves as the shared activation across GEMMs
- Each GEMM can have a unique weight, but all weights must have the same sizes
- Each GEMM can have a unique bias or None
  - This PR does not yet support bias; it will be addressed in the follow-up epilogue fusion PR
- Each GEMM can have its own epilogues
  - Epilogue fusion is not yet supported in this PR; it will be enabled in the follow-up epilogue fusion PR
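
The grouped GEMM path is opt-in: it is gated behind max-autotune with the CPP backend and the new `cpp.enable_grouped_gemm_template` flag (see the `config.py` and `post_grad.py` changes in the diff below), and the new unit tests additionally enable freezing. A minimal sketch of how to enable it; the flag names come from this PR, while the exact backend string and the use of freezing mirror the unit tests rather than hard requirements:

```
from torch._inductor import config as inductor_config

# Flags gating the Grouped GEMM template in this PR; the unit tests patch the same ones.
inductor_config.freezing = True
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "CPP,ATEN"
inductor_config.cpp.enable_grouped_gemm_template = True
```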

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear_invalid
python -u -m pytest -s -v test/inductor/test_cpu_cpp_wrapper.py -k test_grouped_linear
```

**Example**
Here is an example; the generated code is linked below.
```
import torch

batch_size = 4
in_features = 512
out_features = 1024
dtype = torch.bfloat16

class M(torch.nn.Module):
    def __init__(self, bias):
        super().__init__()
        self.linear0 = torch.nn.Linear(in_features, out_features, bias=False)
        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)

    def forward(self, x):
        return self.linear0(x), self.linear1(x)

if __name__ == "__main__":
    with torch.no_grad():
        input = torch.randn(batch_size, in_features, dtype=dtype)
        m = M(bias=False).to(dtype=dtype).eval()
        cm = torch.compile(m)
        act_res = cm(input)
```
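
To confirm that the grouped GEMM template was actually selected rather than `gemm_num` independent GEMM templates, the new unit tests check an inductor counter. A minimal sketch of the same check applied to the example above; the counter name comes from this PR, and the counters are cleared before the first compiled call because compilation happens on that call:

```
from torch._dynamo.utils import counters

counters.clear()
cm = torch.compile(m)
with torch.no_grad():
    cm(input)  # the grouped GEMM template is chosen during this first (compiling) call
# CppGroupedGemmTemplate.render() bumps this counter when the grouped template is used
assert counters["inductor"]["cpp_grouped_gemm_template"] == 1
```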

Generated Code:  https://gist.github.com/leslie-fang-intel/ed2e8d23aeb3586eb504feeace692e16#file-grouped-gemm-generated-code-py
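
At the FX level, the new post-grad pass (`grouped_gemm_pass` in `mkldnn_fusion.py`, shown in the diff below) rewrites the `_linear_pointwise` nodes that share the activation `x` into a single grouped node plus `getitem`s. A rough sketch of that rewrite for the example above; node names are illustrative:

```
# Before the pass (post-grad graph after mkldnn weight prepacking, simplified):
#   out0 = mkldnn._linear_pointwise.default(x, w0, None, ...)
#   out1 = mkldnn._linear_pointwise.default(x, w1, None, ...)
#
# After grouped_gemm_pass:
#   grouped = grouped_gemm_lowering(x, [w0, w1], [None, None])
#   out0 = operator.getitem(grouped, 0)
#   out1 = operator.getitem(grouped, 1)
```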

**Next Step**

- Support Epilogue fusion

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143796
Approved by: https://github.com/jgong5, https://github.com/jansel
Author: leslie-fang-intel
Date: 2025-01-13 00:03:44 -08:00 (committed by PyTorch MergeBot)
Commit: 25de671ea8 (parent: 35b46a75f1)
15 changed files with 913 additions and 19 deletions

@ -207,7 +207,12 @@ if RUN_CPU:
*[
    BaseTest(func, "", test_cpu_select_algorithm.TestSelectAlgorithmCPU())
    for func in dir(test_cpu_select_algorithm.TestSelectAlgorithmCPU())
-   if func.startswith("test_linear_with_pointwise")
+   if func.startswith(
+       (
+           "test_linear_with_pointwise",
+           "test_grouped_linear",
+       )
+   )
],
BaseTest("test_polar"),
BaseTest(


@ -1683,6 +1683,105 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
self.assertEqual(actual, expected, atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
@inductor_config.patch({"freezing": True})
@inductor_config.patch({"cpp.enable_grouped_gemm_template": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (16,))
@parametrize("in_features", (52,))
@parametrize("out_features", (32,))
@parametrize("gemm_num", (2, 3))
def test_grouped_linear_invalid(
self,
batch_size,
in_features,
out_features,
gemm_num,
):
class M(torch.nn.Module):
def __init__(self, in_feature, out_feature, gemm_num):
super().__init__()
self.linears = [
torch.nn.Linear(in_feature, out_feature + gemm_idx, bias=False)
for gemm_idx in range(gemm_num)
]
def forward(self, x):
return [linear(x) for linear in self.linears]
# each linear has a different number of out features, so the grouped GEMM is invalid
dtypes = []
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
dtypes.append(torch.bfloat16)
if torch.ops.mkldnn._is_mkldnn_fp16_supported():
dtypes.append(torch.float16)
for dtype in dtypes:
torch._dynamo.reset()
torch._inductor.metrics.reset()
counters.clear()
mod = M(in_features, out_features, gemm_num).eval()
v = torch.randn(batch_size, in_features).to(dtype)
with verify(dtype) as (atol, rtol), torch.autocast(
device_type="cpu", dtype=dtype
), torch.no_grad():
self.common(mod, (v,), atol=atol, rtol=rtol)
# gemm_num independent template instead of grouped gemm template
self.assertEqual(
counters["inductor"]["select_algorithm_autotune"], gemm_num
)
self.assertEqual(counters["inductor"]["cpp_grouped_gemm_template"], 0)
@inductor_config.patch({"freezing": True})
@inductor_config.patch({"cpp.enable_grouped_gemm_template": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("batch_size", (16,))
@parametrize("in_features", (52,))
@parametrize("out_features", (32,))
@parametrize("input_3d", (False, True))
@parametrize("gemm_num", (2, 3))
def test_grouped_linear(
self,
batch_size,
in_features,
out_features,
input_3d,
gemm_num,
):
class M(torch.nn.Module):
def __init__(self, in_feature, out_feature, gemm_num):
super().__init__()
self.linears = [
torch.nn.Linear(in_feature, out_feature, bias=False)
for _ in range(gemm_num)
]
def forward(self, x):
return [linear(x) for linear in self.linears]
dtypes = []
if torch.ops.mkldnn._is_mkldnn_bf16_supported():
dtypes.append(torch.bfloat16)
if torch.ops.mkldnn._is_mkldnn_fp16_supported():
dtypes.append(torch.float16)
for dtype in dtypes:
if dtype == torch.float16 and input_3d:
# reduce the number of tests
continue
torch._dynamo.reset()
torch._inductor.metrics.reset()
counters.clear()
mod = M(in_features, out_features, gemm_num).eval()
B = (2, batch_size) if input_3d else (batch_size,)
v = torch.randn(*B, in_features).to(dtype)
with verify(dtype) as (atol, rtol), torch.autocast(
device_type="cpu", dtype=dtype
), torch.no_grad():
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["cpp_grouped_gemm_template"], 1)
@inductor_config.patch({"freezing": False})
@patches
@torch.no_grad
@ -2031,6 +2130,7 @@ class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase):
test_quantized_linear_amx_dynamic_shapes = (
    TestSelectAlgorithm.test_quantized_linear_amx
)
test_grouped_linear_dynamic_shapes = TestSelectAlgorithm.test_grouped_linear
test_linear_k_slicing_dynamic_shapes = TestSelectAlgorithm.test_linear_k_slicing
test_linear_cache_blocking_dynamic_shapes = (
    TestSelectAlgorithm.test_linear_cache_blocking


@ -500,7 +500,13 @@ class BenchmarkRequest:
self.input_tensor_meta = input_tensor_meta
if isinstance(output_tensor_meta, (tuple, list)):
-   assert len(output_tensor_meta) == 1
+   if len(output_tensor_meta) > 1:
+       # Each output with same meta for Grouped GEMM
+       assert all(
+           getattr(output_tensor_meta[0], attr) == getattr(x, attr)
+           for x in output_tensor_meta
+           for attr in ["device", "dtype", "sizes", "strides", "offset"]
+       )
    output_tensor_meta = output_tensor_meta[0]
self.output_tensor_meta = output_tensor_meta


@ -26,6 +26,7 @@ from ..loop_body import LoopBody
from ..scheduler import (
    BaseSchedulerNode,
    BaseScheduling,
    ExternKernelSchedulerNode,
    ForeachKernelSchedulerNode,
    FusedSchedulerNode,
    Scheduler,
@ -4905,7 +4906,24 @@ class CppScheduling(BaseScheduling):
    epilogue_nodes=epilogue_ir_nodes,
)
with kernel:
-   for node in [template_node, *epilogue_nodes]:
+   if isinstance(template_node.node, ir.CppTemplateBuffer) and isinstance(
+       template_node.node.layout, ir.MultiOutputLayout
+   ):
+       # For Grouped GEMM, allocate buffers for each GEMM
+       assert (
+           len(template_node.outputs) == 1
+       ), "Grouped GEMM has 1 output template buffer"
+       for user in template_node.outputs[0].users:
+           assert isinstance(
+               user.node, ExternKernelSchedulerNode
+           ), "Grouped GEMM should be with ExternKernelSchedulerNode"
+           assert isinstance(
+               user.node.node, ir.MultiOutput
+           ), "Grouped GEMM has multi users with MultiOutput"
+           user.node.mark_run()
+   else:
+       template_node.mark_run()  # type: ignore[attr-defined]
+   for node in epilogue_nodes:
        node.mark_run()  # type: ignore[attr-defined]
    src_code = render()


@ -3,7 +3,7 @@ import contextlib
import logging
import math
from functools import lru_cache
-from typing import Any, Callable, cast, Dict, List, Optional, Union
+from typing import Any, Callable, cast, Dict, List, Optional, TypeVar, Union
from unittest.mock import patch
import torch
@ -299,9 +299,10 @@ def get_padded_n(n, block_n):
return (n + block_n - 1) // block_n * block_n
+_T = TypeVar("_T", ir.IRNode, torch.Tensor)
-def transpose_w(
-    W: Union[ir.IRNode, torch.Tensor], trans_w: bool
-) -> Union[ir.IRNode, torch.Tensor]:
+def transpose_w(W: _T, trans_w: bool) -> _T:
    """
    Transpose W based on the trans_w flag.
    """
@ -317,9 +318,7 @@ def transpose_w(
return W
-def expand_bias(
-    B: Union[ir.IRNode, torch.Tensor, None], X: Union[ir.IRNode, torch.Tensor]
-) -> Optional[Union[ir.IRNode, torch.Tensor]]:
+def expand_bias(B: Optional[_T], X: _T) -> Optional[_T]:
    """
    Expand Bias to the same size of X.
    """
@ -336,7 +335,7 @@ def expand_bias(
return B
-def prune_tensors(input_nodes: List[ir.TensorBox], new_input_nodes: List[ir.TensorBox]):
+def prune_tensors(input_nodes: List[ir.IRNode], new_input_nodes: List[ir.IRNode]):
    """
    Prune unused tensors from `V.graph` since the GEMM Template use new packed weight.
    """
@ -798,6 +797,7 @@ class CppGemmTemplate(CppTemplate):
trans_w=False,
input_indices=None,
epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
act_mapping: Optional[dict[int, ir.IRNode]] = None,
):
if input_indices is None:
    input_indices = list(range(len(input_nodes)))
@ -1251,6 +1251,7 @@ class CppGemmTemplate(CppTemplate):
# --> zero or more out-of-template epilogues (`epilogue_nodes`) -->
# Y
if epilogue_creators:
    assert isinstance(template_buffer, ir.IRNode)
    gemm_output_name = f"{template_buffer.get_name()}_GemmOut"
    gemm_output_buffer = ir.Buffer(
        name=gemm_output_name, layout=template_buffer.layout
@ -1276,14 +1277,17 @@ class CppGemmTemplate(CppTemplate):
    name=buffer_name, layout=template_buffer.layout
)
assert isinstance(Y, (ir.Buffer, ir.ReinterpretView))
Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y
if epilogue_nodes:
    if not template_buffer_has_other_users:
        assert isinstance(template_buffer, ir.IRNode)
        Y_aliases.add(template_buffer.get_name())
    epilogues.extend(epilogue_nodes)
    assert Y.get_numel() == epilogues[-1].get_numel()
    Y = cast(ir.Buffer, epilogues[-1])
    assert isinstance(template_buffer, ir.Buffer)
    Y_2d, reindexers = gen_2d_view_of_epilogue_buf(
        Y,
        template_buffer,


@ -0,0 +1,463 @@
import contextlib
import logging
from typing import Any, Callable, List, Optional, TypeVar
from unittest.mock import patch
import torch
import torch.utils
from torch.utils._ordered_set import OrderedSet
from ..._dynamo.utils import counters
from .. import config, ir
from ..kernel.mm_common import mm_args
from ..select_algorithm import ChoiceCaller, DataProcessorTemplateWrapper
from ..utils import parallel_num_threads
from ..virtualized import V
from .cpp import get_export_declaration
from .cpp_gemm_template import CppGemmTemplate, expand_bias, prune_tensors, transpose_w
from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm
from .cpp_template_kernel import CppTemplateKernel
from .cpp_utils import (
DTYPE_TO_CPP,
GemmBlocking,
get_gemm_template_output_and_compute_dtype,
)
log = logging.getLogger(__name__)
GEMM_TEMPLATE = r"""
{{template.header().getvalue()}}
{{micro_gemm.codegen_define(kernel)}}
extern "C" {{export_declaration}}
{{kernel.def_kernel(inputs=kernel_args, outputs=Y_list, aliases=aliases)}}
{
{{kernel.maybe_codegen_profile()}}
{{ template.codegen_blocks(
num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOuts[0], config, L1_cache_size, L2_cache_size, X_list[0], W_list[0]
) }}
{%- if num_threads > 1 %}
#pragma omp parallel num_threads({{num_threads}})
{
{{ template.codegen_multi_threads_params()|indent(8, false) }}
{%- else %}
{
{{ template.codegen_single_thread_params(is_dynamic_M)|indent(8, false) }}
{%- endif %}
{{ micro_gemm.codegen_init(kernel) }}
{%- set acc_buf_name_list=[] %}
{%- set acc_buf_name_prefix = "local_acc_buf_" %}
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
{%- set acc_buf_name = acc_buf_name_prefix + gemm_idx|string %}
{{ kernel.define_buffer(acc_buf_name, ["Mc_blocks*Mr", "Nc_blocks*Nr"], acc_buf_dtype) }}
{%- set acc_buf_name_list=acc_buf_name_list.append(acc_buf_name) %}
{%- endfor %}
for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
{{ template.codegen_m_loop_params()|indent(12, false) }}
for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
{{ template.codegen_n_loop_params()|indent(16, false) }}
{%- set acc_list=[] %}
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
{%- set acc_list = acc_list.append( kernel.local_buffers[acc_buf_name_list[gemm_idx]] ) %}
{{ kernel.reinit_buffer_if_null(acc_buf_name_list[gemm_idx]) }}
{%- endfor %}
for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
int64_t k_start = kc * Kr;
int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
{%- set tile_X_list=[] %}
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
{%- set tile_X_list = tile_X_list.append( kernel.slice_nd(X_list[gemm_idx], [("m_start", "m_end"), ("k_start", "k_end")]) ) %}
{%- endfor %}
for (int64_t nci = nc; nci < nc_block_end; nci++) {
{%- set tile_W_3d_list=[] %}
{%- set tile_W_list=[] %}
{%- set acc_slice_list=[] %}
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
{%- set acc_slice_list = acc_slice_list.append(
kernel.slice_nd(acc_list[gemm_idx], [("0", "m_end - m_start"), ("(nci - nc)*Nr", "(nci - nc + 1)*Nr")])
) %}
{%- set tile_W_3d_list = tile_W_3d_list.append(
kernel.slice_nd(W_list[gemm_idx], [("nci", "nci + 1"), ("k_start", "k_end"), ()])
) %}
{%- endfor %}
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
{%- set tile_W_list = tile_W_list.append(
kernel.view(tile_W_3d_list[gemm_idx], ["k_end - k_start", micro_gemm.register_blocking.block_n])
) %}
{%- endfor %}
if (kc == k_block_start) {
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
{{ micro_gemm.codegen_call(
kernel, tile_X_list[gemm_idx], tile_W_list[gemm_idx], acc_slice_list[gemm_idx], accum=False
)|indent(28, false) }}
{%- endfor %}
} else {
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
{{ micro_gemm.codegen_call(
kernel, tile_X_list[gemm_idx], tile_W_list[gemm_idx], acc_slice_list[gemm_idx], accum=True
)|indent(28, false) }}
{%- endfor %}
}
}
}
{
{%- set tile_acc_list = [] %}
{%- set tile_Y_list = [] %}
{%- for gemm_idx in range(0, gemm_grouped_num, 1) %}
{%- set tile_acc_list = tile_acc_list.append(
kernel.slice_nd(acc_list[gemm_idx], [("0", "m_end - m_start"), ("0", "n_end - n_start")])
) %}
{%- set tile_Y_list = tile_Y_list.append(
kernel.slice_nd(Y_2d_list[gemm_idx], [("m_start", "m_end"), ("n_start", "n_end")])
) %}
{%- endfor %}
{{ kernel.store_outputs(
tile_Y_list, tile_acc_list, GemmOuts, epilogue_nodes, offsets=("m_start", "n_start"), reindexers=reindexers
)|indent(20, false)
}}
}
}
}
{{ micro_gemm.codegen_finalize(kernel) }}
}
}
"""
def get_deduplicated_act(act_mapping: dict[int, ir.IRNode]) -> List[ir.IRNode]:
act_deduplicated = []
act_deduplicated_name: OrderedSet[str] = OrderedSet()
for act_idx in range(len(act_mapping.values())):
act = act_mapping[act_idx]
if act.get_name() not in act_deduplicated_name:
act_deduplicated.append(act)
act_deduplicated_name.add(act.get_name())
return act_deduplicated
class CppGroupedGemmTemplate(CppGemmTemplate):
def __init__(
self,
input_nodes: List[ir.IRNode],
layout: ir.Layout,
num_threads: int,
register_blocking: GemmBlocking,
beta: int = 1,
alpha: int = 1,
has_bias: bool = False,
epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
act_mapping: Optional[dict[int, ir.IRNode]] = None,
gemm_grouped_num: int = 1,
) -> None:
"""
Template for Group of GEMMs:
* Each GEMM has the same dimensions (m, n, k) and the same leading dimensions (lda, ldb, ldc)
for their A, B, and C matrices.
* Each GEMM has distinct or shared activations, has distinct weight, has unique bias or no bias, has distinct epilogues.
* In the current implementation, the outputs of all GEMMs are accumulated using pointwise epilogues.
This behavior can be extended in the future if needed.
"""
super().__init__(
input_nodes,
layout,
num_threads,
register_blocking,
beta,
alpha,
has_bias,
epilogue_creator,
)
self.act_mapping = act_mapping
self.gemm_grouped_num = gemm_grouped_num
self.output_node: List[ir.Buffer] = [
ir.Buffer(name="buf_out" + str(idx), layout=layout)
for idx in range(gemm_grouped_num)
]
@staticmethod
def _fake_get_dtype(fake_outs: List[ir.Buffer]) -> Callable[[str], torch.dtype]:
_get_dtype_real = V.graph.get_dtype
def get_dtype(name: str) -> torch.dtype:
for fake_out in fake_outs:
if name == fake_out.get_name():
return fake_out.get_dtype()
return _get_dtype_real(name)
return get_dtype
@classmethod
def add_choices(
cls,
choices: List[ChoiceCaller],
layout: ir.Layout,
input_nodes: List[ir.IRNode],
beta: int = 1,
alpha: int = 1,
has_bias: tuple[bool, ...] = (False, False),
trans_w: bool = False,
input_indices: Optional[List[int]] = None,
epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
act_mapping: Optional[dict[int, ir.IRNode]] = None, # gemm idx to its act buf
) -> DataProcessorTemplateWrapper:
# Input nodes order: x, optional[x1], ... w0, w1, ... optional[b0], optional[b1], ...
gemm_grouped_num = len(has_bias)
assert act_mapping
act_deduplicated = get_deduplicated_act(act_mapping)
wgt_start_idx = len(act_deduplicated)
bias_start_idx = wgt_start_idx + gemm_grouped_num
input_indices = list(range(len(input_nodes)))
_T = TypeVar("_T", ir.IRNode, torch.Tensor)
_U = TypeVar("_U", ir.Layout, torch.Tensor)
def reorder_and_filter(
inputs: List[_T],
layout_or_out: _U,
) -> tuple[List[_T], _U]:
assert input_indices is not None, "input_indices must be set"
return [inputs[idx] for idx in input_indices], layout_or_out
new_inputs, new_layout = reorder_and_filter(input_nodes, layout)
def maybe_to_dense(
inputs: List[_T],
layout_or_out: _U,
) -> tuple[List[_T], _U]:
new_inputs = list(inputs)
for idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
if isinstance(inputs[idx], torch.Tensor):
W = inputs[idx]
assert isinstance(W, torch.Tensor), "W must be a torch.Tensor"
new_inputs[idx] = W.to_dense() if W.is_mkldnn else W
return new_inputs, layout_or_out
def normalize_shapes(
inputs: List[_T],
layout_or_out: _U,
) -> tuple[List[_T], _U]:
new_inputs: List[_T] = list(inputs)
if not trans_w:
return new_inputs, layout_or_out
X = new_inputs[0]
for wgt_idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
new_input = new_inputs[wgt_idx]
new_inputs[wgt_idx] = transpose_w(new_input, trans_w)
for bias_idx in range(bias_start_idx, len(new_inputs)):
new_bias = expand_bias(new_inputs[bias_idx], X)
assert new_bias is not None
new_inputs[bias_idx] = new_bias
return new_inputs, layout_or_out
num_threads = parallel_num_threads()
new_inputs, _ = normalize_shapes(*maybe_to_dense(new_inputs, new_layout))
m, n, k, *_ = mm_args(new_inputs[0], new_inputs[wgt_start_idx])
output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
new_inputs[0].get_dtype()
)
micro_gemm = create_micro_gemm(
"micro_gemm",
m,
n,
k,
input_dtype=new_inputs[0].get_dtype(),
input2_dtype=new_inputs[wgt_start_idx].get_dtype(),
output_dtype=output_dtype,
compute_dtype=compute_dtype,
alpha=alpha,
num_threads=num_threads,
)
assert micro_gemm is not None
_, block_n, _ = micro_gemm.register_blocking
new_size, padded_n = cls.get_padded_size(
n, block_n, k, should_block_weight=True
)
padding = padded_n - n
def pack_weight(
inputs: List[_T],
layout_or_out: _U,
) -> tuple[List[_T], _U]:
new_W_list = []
new_inputs = list(inputs)
W_list = new_inputs[wgt_start_idx : wgt_start_idx + gemm_grouped_num]
for W in W_list:
blocked_w = cls.block_weight(W, new_size, padding)
new_W_list.append(cls.pack_vnni_weight(blocked_w, micro_gemm, new_size))
new_inputs[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = new_W_list
return new_inputs, layout_or_out
def preprocessor(
inputs: List[_T],
layout: _U,
) -> tuple[List[_T], _U]:
return pack_weight(
*normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout)))
)
def postprocessor(output: _T) -> _T:
if isinstance(output, ir.TensorBox):
template_buffer = ir.InputsKernel.unwrap_storage_for_input(output)
assert isinstance(template_buffer, ir.CppTemplateBuffer)
new_input_nodes, _ = reorder_and_filter(input_nodes, layout)
W_nodes = new_input_nodes[
wgt_start_idx : wgt_start_idx + gemm_grouped_num
]
W_tensor = []
for W_node in W_nodes:
assert W_node.get_name() in V.graph.constants
W_tensor.append(V.graph.constants[W_node.get_name()])
new_input_nodes[
wgt_start_idx : wgt_start_idx + gemm_grouped_num
] = W_tensor # type: ignore[assignment]
new_input_nodes, _ = pack_weight(
*normalize_shapes(*maybe_to_dense(new_input_nodes, layout))
)
# Prune unused tensors
prune_tensors(input_nodes, new_input_nodes)
for idx in range(wgt_start_idx, wgt_start_idx + gemm_grouped_num):
W_packed = new_input_nodes[idx]
assert isinstance(W_packed, torch.Tensor)
W_packed_constant = V.graph.add_tensor_constant(W_packed)
template_buffer.inputs[
idx
] = ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
return output
template = DataProcessorTemplateWrapper(
CppGroupedGemmTemplate,
preprocessor,
postprocessor,
input_nodes=input_nodes,
layout=layout,
num_threads=num_threads,
register_blocking=micro_gemm.register_blocking,
beta=beta,
alpha=alpha,
has_bias=has_bias,
epilogue_creator=epilogue_creator,
act_mapping=act_mapping,
gemm_grouped_num=gemm_grouped_num,
)
template.maybe_append_choice(choices)
return template
def render( # type: ignore[override,return,no-untyped-def]
self,
kernel: CppTemplateKernel,
template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
flag_template_buffer_has_other_users: Optional[bool] = None,
epilogue_nodes: Optional[List[ir.IRNode]] = None,
**kwargs,
) -> str:
assert self.act_mapping
act_deduplicated = get_deduplicated_act(self.act_mapping)
wgt_start_idx = len(act_deduplicated)
bias_start_idx = wgt_start_idx + self.gemm_grouped_num
X_list = list(self.act_mapping.values())
W_list = self.input_nodes[wgt_start_idx : wgt_start_idx + self.gemm_grouped_num]
inp_list = []
cur_idx = bias_start_idx
for inp_idx in range(self.gemm_grouped_num):
inp = None
if self.has_bias[inp_idx]:
inp = self.input_nodes[cur_idx]
cur_idx += 1
inp_list.append(inp)
Y_list = self.output_node
if template_buffer_node is not None:
W_list = template_buffer_node.inputs[
wgt_start_idx : wgt_start_idx + self.gemm_grouped_num
]
assert isinstance(template_buffer_node.outputs, List)
Y_list = template_buffer_node.outputs
counters["inductor"]["cpp_grouped_gemm_template"] += 1
template_buffer = Y_list[0]
fake_buffers: List[ir.Buffer] = []
Y_2d_list = Y_list
output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
X_list[0].get_dtype()
)
micro_gemm = create_micro_gemm(
f"{kernel.kernel_name}_micro_gemm",
self.m,
self.n,
self.k,
input_dtype=X_list[0].get_dtype(),
input2_dtype=W_list[0].get_dtype(),
output_dtype=output_dtype,
compute_dtype=compute_dtype,
alpha=self.alpha,
num_threads=self.num_threads,
)
assert micro_gemm is not None
assert self.register_blocking == micro_gemm.register_blocking
self.log_blockings()
if isinstance(micro_gemm, CppMicroGemmAMX):
counters["inductor"]["cpp_micro_gemm_amx_counter"] += 1
L1_cache_size = torch._C._cpu._L1d_cache_size() # per core cache size in Bytes
assert L1_cache_size > 0, f"Expect L1_cache_size > 0 but got {L1_cache_size}"
L2_cache_size = torch._C._cpu._L2_cache_size() # per core cache size in Bytes
assert L2_cache_size > 0, f"Expect L2_cache_size > 0 but got {L2_cache_size}"
epilogues: List[ir.IRNode] = []
reindexers: List[Optional[Callable[[List[Any]], List[Any]]]] = []
gemm_output_buffers: list[ir.Buffer] = []
for out_buf_idx in range(self.gemm_grouped_num):
gemm_output_name = f"{template_buffer.get_name()}_GemmOut" + str(
out_buf_idx
)
gemm_output_buffers.append(
ir.Buffer(name=gemm_output_name, layout=template_buffer.layout)
)
assert (
not self.epilogue_creator and not epilogue_nodes
), "Epilogue fusion is not implemented yet in Grouped GEMM Template"
kernel_args: dict[str, Optional[ir.IRNode]] = {}
for x_idx in range(wgt_start_idx):
kernel_args["X" + str(x_idx)] = act_deduplicated[x_idx]
for w_idx in range(self.gemm_grouped_num):
kernel_args["W" + str(w_idx)] = W_list[w_idx]
for inp_idx in range(self.gemm_grouped_num):
kernel_args["inp" + str(inp_idx)] = inp_list[inp_idx]
options = dict(
N=self.n,
K=self.k,
PADDED_N=self.padded_n,
aliases={},
beta=self.beta,
alpha=self.alpha,
num_threads=self.num_threads,
micro_gemm=micro_gemm,
is_dynamic_M=self.is_dynamic_M,
template=self,
kernel=kernel,
export_declaration=get_export_declaration(),
acc_buf_dtype=torch.float,
DTYPE_TO_CPP=DTYPE_TO_CPP,
L1_cache_size=L1_cache_size,
L2_cache_size=L2_cache_size,
config=config,
epilogue_nodes=epilogues,
GemmOuts=gemm_output_buffers,
reindexers=reindexers,
kernel_args=kernel_args,
X_list=X_list,
W_list=W_list,
gemm_grouped_num=self.gemm_grouped_num,
Y_list={"Y" + str(idx): Y for idx, Y in enumerate(Y_list)},
Y_2d_list=Y_2d_list,
)
with contextlib.ExitStack() as stack:
stack.enter_context(
patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_buffers))
)
return self._template_from_string(GEMM_TEMPLATE).render(**options)


@ -4,7 +4,7 @@ import functools
import itertools
import logging
import sys
-from typing import Callable, List, Optional
+from typing import Callable, Iterable, List, Optional, Union
from unittest.mock import patch
import sympy
@ -33,7 +33,9 @@ class CppTemplate(KernelTemplate):
) -> None:
    super().__init__(name)
    self.input_nodes = input_nodes
-   self.output_node: ir.Buffer = ir.Buffer(name="buf_out", layout=layout)
+   self.output_node: Union[ir.Buffer, List[ir.Buffer]] = ir.Buffer(
+       name="buf_out", layout=layout
+   )
    self.layout = layout
    self.num_threads = num_threads
    self.epilogue_creator = epilogue_creator
@ -57,7 +59,10 @@ class CppTemplate(KernelTemplate):
expected_args = list(
    unique(input_node.get_name() for input_node in self.input_nodes)
)
-expected_args.extend([self.output_node.get_name()])
+if isinstance(self.output_node, Iterable):
+    expected_args.extend([node.get_name() for node in self.output_node])
+else:
+    expected_args.extend([self.output_node.get_name()])
assert list(call_args)[: len(expected_args)] == expected_args, (
    call_args,
    expected_args,
@ -102,7 +107,9 @@ class CppTemplate(KernelTemplate):
kernel_hash_name,
self.name,
self.input_nodes,
-self.output_node.get_layout(),
+self.output_node[0].get_layout()
+if isinstance(self.output_node, Iterable)
+else self.output_node.get_layout(),
make_kernel_render,
bmreq,
self,


@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import itertools
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union
import sympy
from sympy.parsing.sympy_parser import parse_expr
@ -278,6 +278,82 @@ class CppTemplateKernel(CppKernel):
kernel_group.finalize_kernel(cpp_kernel_proxy, [])
return kernel_group.loops_code.getvalue()
def store_grouped_gemm_pointwise_nodes(
self,
dst: tuple[ir.Buffer],
nodes: List[List[ir.IRNode]],
offsets: Optional[List[sympy.Expr]] = None,
reindexers: Optional[List[Optional[Callable[[List[Any]], List[Any]]]]] = None,
) -> str:
ref_dst = dst[0]
var_sizes = (tuple(ref_dst.get_size()), ())
var_ranges = {
sympy_index_symbol_with_prefix(SymT.INDEX, i): sz
for i, sz in enumerate(var_sizes[0])
}
if not offsets:
offsets = [sympy.S.Zero] * len(var_sizes[0])
if not reindexers:
reindexers = [None] * len(nodes)
assert len(offsets) == len(var_sizes[0])
output_index = ref_dst.get_layout().make_indexer()([*var_ranges.keys()])
kernel_group = KernelGroup()
kernel_group.args = self.args
cpp_kernel_proxy = CppKernelProxy(kernel_group)
bodies = []
var_sizes_list = []
assert isinstance(nodes[0], Iterable)
grouped_gemm_number = len(nodes)
epilogue_nodes = nodes[0]
assert isinstance(epilogue_nodes, Iterable)
for i, _ in enumerate(epilogue_nodes):
output_names = []
gemm_nodes = []
for gemm_idx in range(grouped_gemm_number):
single_gemm_nodes = nodes[gemm_idx]
assert isinstance(dst, Iterable)
single_gemm_dst = dst[gemm_idx]
assert isinstance(single_gemm_nodes, Iterable)
assert isinstance(single_gemm_dst, ir.IRNode)
gemm_nodes.append(single_gemm_nodes[i])
output_names.append(
single_gemm_nodes[i].get_name()
if i < len(single_gemm_nodes) - 1
else single_gemm_dst.get_name()
)
_node = gemm_nodes[gemm_idx]
gemm_nodes[gemm_idx] = (
_node.data if isinstance(_node, ir.ComputedBuffer) else _node
)
def fn(*args):
assert len(args) == 2
assert len(args[0]) == len(var_sizes[0])
assert len(args[1]) == 0
new_args = [arg + offset for arg, offset in zip(args[0], offsets)] # type: ignore[arg-type]
if reindexers[i] is not None:
new_args = reindexers[i](new_args) # type: ignore[misc]
for gemm_idx in range(grouped_gemm_number):
V.ops.store(
output_names[gemm_idx],
output_index,
gemm_nodes[gemm_idx].make_loader()(new_args).value,
)
body = LoopBody(
fn,
(list(var_ranges.keys()), ()),
var_ranges,
list(var_ranges.keys()),
tuple(),
)
bodies.append(body)
var_sizes_list.append(var_sizes)
cpp_kernel_proxy.codegen_loop_bodies(bodies, var_sizes_list)
kernel_group.finalize_kernel(cpp_kernel_proxy, [])
return kernel_group.loops_code.getvalue()
def store_output(
    self,
    dst: ir.Buffer,
@ -335,6 +411,43 @@ class CppTemplateKernel(CppKernel):
assert dst.layout == src.layout, f"{dst=}, {src=}"
return ""
def store_outputs(
self,
dst: tuple[ir.Buffer],
src: tuple[ir.IRNode],
orig_src: Optional[tuple[ir.IRNode]] = None,
epilogue_nodes: Optional[List[ir.IRNode]] = None,
offsets: Optional[List[Any]] = None,
reindexers: Optional[List[Optional[Callable[[List[Any]], List[Any]]]]] = None,
):
# Grouped GEMM may have multi outputs to be localized
assert isinstance(src, Iterable)
assert isinstance(dst, Iterable)
assert all(_dst.get_size() == _src.get_size() for _src, _dst in zip(src, dst))
if offsets:
offsets = parse_expr_with_index_symbols(offsets)
if epilogue_nodes:
assert (
not epilogue_nodes
), "epilogue_nodes not supported for Grouped GEMM yet"
else:
if dst[0].get_name() != src[0].get_name():
copy_list = []
with LocalBufferContext(self.args) as scope:
for _src, _dst in zip(src, dst):
copy_list.append([L.copy(_dst, _src).data.data])
scope.add_local_buffer(_src)
return self.store_grouped_gemm_pointwise_nodes(dst, copy_list)
else:
assert all(
_src.get_name() == _dst.get_name() for _src, _dst in zip(src, dst)
)
assert all(
_src.get_layout() == _dst.get_layout()
for _src, _dst in zip(src, dst)
)
return ""
class CppTemplateCaller(ir.ChoiceCaller):
    """


@ -2193,9 +2193,12 @@ class PythonWrapperCodegen(CodeGen):
):
    return
self.allocated.add(name)
-if isinstance(
-    buffer.get_defining_op(),
-    (ir.ExternKernelAlloc, ir.MultiOutput),
+if (
+    isinstance(
+        buffer.get_defining_op(),
+        (ir.ExternKernelAlloc, ir.MultiOutput),
+    )
+    and not buffer.should_allocate()
):
    return


@ -884,6 +884,9 @@ class cpp:
os.environ.get("TORCHINDUCTOR_CPP_ENABLE_TILING_HEURISTIC", "1") == "1"
)
# Enable the Grouped GEMM Fusion
enable_grouped_gemm_template = False
# Maximal allowed number of slices on K-dim for a GEMM kernel. This controls
# the maximal parallelism of K-slicing. Since K-slicing requires extra thread
# synchronization and buffers, the maximal number of slices is limited to


@ -38,6 +38,75 @@ if torch._C._has_mkldnn:
_linear_args = [Arg() for _ in range(6)]
_conv_transpose_args = [Arg() for _ in range(11)]
def _is_valid_grouped_gemm_fusion(computation_nodes):
"""
Here we check:
1. More than one GEMM node has been found.
2. All GEMM nodes share the same activation.
3. All GEMM nodes have the same weight size but distinct weight nodes.
"""
computation_op = mkldnn._linear_pointwise.default
act = computation_nodes[0].args[0]
wgt = computation_nodes[0].args[1]
wgt_size = wgt.meta.get("val").size() # type: ignore[union-attr]
return len(computation_nodes) >= 2 and all(
(
node.target == computation_op
and node.args[0] == act
and (node.args[1].meta.get("val").size() == wgt_size)
and (node.args[1] != wgt or gemm_idx == 0)
and not node.args[2] # <TODO> support bias through epilogue fusion
)
for gemm_idx, node in enumerate(computation_nodes)
)
def grouped_gemm_pass(graph: torch.fx.Graph):
"""
Grouped GEMM has multiple output nodes, which is complicated to express as a Pattern.
Use the approach below to connect the pattern to the lowering.
TODO: Use MultiOutputPattern; the current limitation is that it requires a
fixed number of output nodes. Extend it so the pattern matcher supports Grouped GEMM.
"""
computation_op = mkldnn._linear_pointwise.default
from ..mkldnn_lowerings import grouped_gemm_lowering
for node in graph.find_nodes(op="call_function", target=computation_op):
if (
not node._erased
and isinstance(node.meta.get("val"), torch.Tensor)
and node.meta["val"].device.type == "cpu"
):
act = node.args[0]
users = list(act.users)
if _is_valid_grouped_gemm_fusion(users):
with graph.inserting_before(node):
grouped_gemm_node = graph.create_node(
"call_function",
grouped_gemm_lowering,
(
act,
[user.args[1] for user in users],
[None for _ in users],
),
)
grouped_gemm_node.meta["val"] = [
user.meta["val"] for user in users
]
with graph.inserting_after(grouped_gemm_node):
for gemm_idx, user in enumerate(users):
assert user.target == computation_op
get_item = graph.create_node(
"call_function",
operator.getitem,
(
grouped_gemm_node,
gemm_idx,
),
)
user.replace_all_uses_with(get_item)
graph.erase_node(user)
return
def _conv_call(users=1):
    return CallFunction(
        mkldnn._convolution_pointwise.default, *_conv_args, _users=users


@ -102,6 +102,16 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
    post_grad_custom_pre_pass
)
if (
config.cpp.enable_grouped_gemm_template
and config.max_autotune
and "CPP" in config.max_autotune_gemm_backends
and torch._C._has_mkldnn
):
from .mkldnn_fusion import grouped_gemm_pass
grouped_gemm_pass(gm.graph)
if config.pattern_matcher:
    lazy_init()
    optimus_scuba_log["before_recompile_post_grad"] = upload_graph(gm.graph)


@ -4496,6 +4496,18 @@ class CppTemplateBuffer(TemplateBuffer):
super().__init__(layout, inputs, make_kernel_render)
self.template = template
self.choice = choice
self.outputs: Optional[List[Buffer]] = None
def get_layout(self) -> Layout:
if isinstance(self.layout, MultiOutputLayout):
assert isinstance(self.outputs, Iterable)
first_output = self.outputs[0]
assert isinstance(first_output, Buffer)
layout = first_output.layout
assert isinstance(layout, Layout)
return layout
else:
return super().get_layout()
@ir_dataclass(frozen=False)
@ -6832,6 +6844,10 @@ class MultiOutput(ExternKernel):
return self.inputs[0].get_unbacked_symbol_uses()
def should_allocate(self) -> bool:
if len(self.inputs) == 1 and (
isinstance(self.inputs[0], CppTemplateBuffer) # Grouped GEMM
):
return True
return False
def get_inputs_that_alias_output(self) -> Sequence[str]:


@ -8,6 +8,7 @@ from torch._inductor.kernel.mm_common import mm_args
from . import ir
from .codegen.cpp_gemm_template import CppGemmTemplate
from .codegen.cpp_grouped_gemm_template import CppGroupedGemmTemplate
from .codegen.cpp_utils import create_epilogue_with_attr
from .ir import TensorBox
from .lowering import (
@ -28,6 +29,73 @@ from .utils import use_aten_gemm_kernels, use_cpp_gemm_template, use_max_autotun
from .virtualized import ops, V
def grouped_gemm_lowering(
x: TensorBox,
w: List[TensorBox],
b: List[TensorBox],
attr=None,
scalars=None,
algorithm=None,
layout=None,
):
x_size = x.get_size()
if len(x_size) > 2:
# GEMM template needs 2D input, normalize input shape here
x = view(x, [-1, x_size[-1]])
num_gemm = len(w)
assert use_max_autotune()
b = [bias if bias is None else ir.ExternKernel.realize_input(bias) for bias in b]
choices: List[ChoiceCaller] = []
*_, layout, x, _ = mm_args(x, permute(w[0], [1, 0]), layout=layout)
kwargs = dict(
has_bias=[bias is not None for bias in b],
trans_w=True,
epilogue_creator=None,
act_mapping={num: x for num in range(num_gemm)},
)
input_nodes = [x, *w]
input_nodes.extend([bias for bias in b if bias is not None])
CppGroupedGemmTemplate.add_choices(
choices,
layout,
input_nodes,
**kwargs, # type: ignore[arg-type]
)
assert len(choices) != 0
result = autotune_select_algorithm(
"grouped_gemm",
choices,
input_nodes,
layout,
)
template_buf = result.data.data
return_bufs = [
ir.MultiOutput(layout, template_buf, [(list, gemm_idx)])
for gemm_idx in range(num_gemm)
]
template_buf.layout = ir.MultiOutputLayout(device=input_nodes[0].get_device())
template_buf.outputs = return_bufs
return_tensors = [
ir.TensorBox.create(return_bufs[gemm_idx]) for gemm_idx in range(num_gemm)
]
if len(x_size) > 2:
for gemm_idx in range(num_gemm):
return_tensors[gemm_idx] = view(
return_tensors[gemm_idx],
(*x_size[:-1], return_tensors[gemm_idx].get_size()[-1]),
)
return return_tensors
grouped_gemm_lowering._inductor_lowering_function = True # type: ignore[attr-defined]
def register_onednn_fusion_ops():
    if torch._C._has_mkldnn:
        from . import mkldnn_ir


@ -3607,6 +3607,15 @@ class Scheduler:
# the current kernel from where 'allocate' retrieve those decisions.
# We have to make sure there is a non-NULL kernel handler to store
# those inplace update decisions.
if (
isinstance(scheduler_node.node, ir.MultiOutput)
and len(scheduler_node.node.inputs) == 1
and isinstance(scheduler_node.node.inputs[0], ir.CppTemplateBuffer)
):
# <TODO> Remove this code after Fuse MultiOutput and CppTemplateBuffer
return
counters["inductor"]["extern_calls"] += 1 counters["inductor"]["extern_calls"] += 1
with V.set_kernel_handler(Kernel(increase_kernel_count=False)): with V.set_kernel_handler(Kernel(increase_kernel_count=False)):
scheduler_node.decide_inplace_update() scheduler_node.decide_inplace_update()