mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[Inductor][CPP] Enable Local Buffer for Outer loop fusion (#126967)
**Summary** Currently, the Inductor CPP backend [generated code](https://gist.github.com/leslie-fang-intel/98f91d43dabed581a1ffe23daf133a65#file-bf16-softmax-generated-code-wo-local-buffer-py) for `Softmax` with BF16 data type is significantly slower than the [ATen Implementation](9a2beb862d/aten/src/ATen/native/cpu/SoftMaxKernel.cpp (L149)). Upon comparing the generated code with ATen, the performance bottleneck appears to be related to the usage of [local buffer in ATen](9a2beb862d/aten/src/ATen/native/cpu/SoftMaxKernel.cpp (L159-L160)). In the current implementation, the Inductor uses the output buffer of Kernel Group Args to store and load temporary result (such as `exp`), since this buffer is corresponding to a `SchedulerNode`. Each thread accesses a portion of this output buffer via indexing. However, since this buffer (take this `exp` as example) is only utilized internally within decomposed `softmax`, this buffer can be replaced with a thread-local buffer similar to ATen's approach. In this PR, we have introduced the optimizations of `LocalBuffer`. Following this enhancement, the [new generated Inductor code with local buffer](https://gist.github.com/leslie-fang-intel/98f91d43dabed581a1ffe23daf133a65#file-bf16-softmax-generated-code-w-local-buffer-py) for BF16 `Softmax` demonstrates significantly improved performance. Running the benchmark [here](https://gist.github.com/leslie-fang-intel/37d81441237b5139c8295f5e6c4cd31a) to test this BF16 `Softmax` case on an 8480 Xeon server shows similar performance between the Inductor CPP Backend and the ATen implementation. **TestPlan** ``` python -u -m pytest -s -v inductor/test_cpu_repro.py -k test_local_buffer_in_outer_loop_fusion ``` **Next Step** - [ ] Support more than one Local Buffer/Global Buffer Pull Request resolved: https://github.com/pytorch/pytorch/pull/126967 Approved by: https://github.com/jgong5, https://github.com/peterbell10
This commit is contained in:
parent
a3ce9eddd6
commit
98929ceae3
|
|
@ -2556,6 +2556,7 @@ class CPUReproTests(TestCase):
|
|||
self.common(fn, (x,))
|
||||
assert metrics.generated_cpp_vec_kernel_count == 0
|
||||
|
||||
@config.patch(fx_graph_cache=False)
|
||||
def test_outer_loop_fusion(self):
|
||||
def fn(x):
|
||||
max = torch.amax(x, dim=-1, keepdim=True)
|
||||
|
|
@ -2567,8 +2568,47 @@ class CPUReproTests(TestCase):
|
|||
torch._dynamo.reset()
|
||||
metrics.reset()
|
||||
self.common(fn, (x,))
|
||||
assert len(metrics.cpp_outer_loop_fused_inner_counts) == 1
|
||||
assert metrics.cpp_outer_loop_fused_inner_counts[0] == 2
|
||||
self.assertEqual(
|
||||
len(metrics.cpp_outer_loop_fused_inner_counts),
|
||||
1,
|
||||
)
|
||||
self.assertEqual(
|
||||
metrics.cpp_outer_loop_fused_inner_counts[0].inner_kernel_number,
|
||||
2,
|
||||
)
|
||||
|
||||
@config.patch(fx_graph_cache=False)
|
||||
def test_local_buffer_in_outer_loop_fusion(self):
|
||||
def fn(x):
|
||||
max = torch.nn.functional.softmax(x, dim=-1)
|
||||
return x - max
|
||||
|
||||
x = torch.randn(4, 12, 1023, 1022)
|
||||
|
||||
with config.patch({"cpp.simdlen": None}):
|
||||
torch._dynamo.reset()
|
||||
metrics.reset()
|
||||
self.common(fn, (x,))
|
||||
self.assertEqual(
|
||||
len(metrics.cpp_outer_loop_fused_inner_counts),
|
||||
1,
|
||||
)
|
||||
self.assertEqual(
|
||||
metrics.cpp_outer_loop_fused_inner_counts[0].inner_kernel_number,
|
||||
3,
|
||||
)
|
||||
self.assertEqual(
|
||||
metrics.cpp_outer_loop_fused_inner_counts[0].local_buffer_number,
|
||||
1,
|
||||
)
|
||||
# Check the number of global buffer allocation
|
||||
torch._dynamo.reset()
|
||||
metrics.reset()
|
||||
_, code = run_and_get_cpp_code(
|
||||
torch._dynamo.optimize("inductor")(fn),
|
||||
x,
|
||||
)
|
||||
self.assertEqual(code.count("empty_strided_cpu("), 3)
|
||||
|
||||
def test_argmin(self):
|
||||
def fn(x):
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import logging
|
|||
import math
|
||||
import re
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
from copy import copy, deepcopy
|
||||
from enum import Enum
|
||||
from typing import Any, cast, Dict, List, Optional, Sequence, Set, Tuple, Union
|
||||
|
|
@ -69,6 +70,7 @@ from .cpp_utils import (
|
|||
cexpr_index,
|
||||
DTYPE_TO_CPP,
|
||||
INDEX_TYPE,
|
||||
LocalBufferContext,
|
||||
unify_mask_base_type,
|
||||
value_to_cpp,
|
||||
)
|
||||
|
|
@ -435,8 +437,6 @@ class OuterLoopFusedSchedulerNode(FusedSchedulerNode):
|
|||
loop_nest_list: List[LoopNestWithSplit] = [
|
||||
kernel.loop_nest for kernel in cpp_kernel_proxy_list
|
||||
]
|
||||
metrics.cpp_outer_loop_fused_inner_counts.append(len(loop_nest_list))
|
||||
|
||||
kernel_group = cpp_kernel_proxy_list[0].kernel_group
|
||||
|
||||
def _merge_outer_fusion_loop_levels(
|
||||
|
|
@ -1915,7 +1915,10 @@ class CppKernel(Kernel):
|
|||
threads = parallel_num_threads()
|
||||
assert self.call_ranges is not None
|
||||
kernels = loop_nest.get_kernels()
|
||||
if any(isinstance(kernel, OuterLoopFusedKernel) for kernel in kernels):
|
||||
has_outer_loop_kernel = any(
|
||||
isinstance(kernel, OuterLoopFusedKernel) for kernel in kernels
|
||||
)
|
||||
if has_outer_loop_kernel:
|
||||
assert len(kernels) == 1
|
||||
assert isinstance(kernels[0], OuterLoopFusedKernel)
|
||||
par_depth = kernels[0].decide_parallel_depth(
|
||||
|
|
@ -2045,6 +2048,31 @@ class CppKernel(Kernel):
|
|||
|
||||
stack.enter_context(code.indent())
|
||||
if loop_nest.root:
|
||||
if (
|
||||
has_outer_loop_kernel
|
||||
and isinstance(V.local_buffer_context, LocalBufferContext)
|
||||
and V.local_buffer_context.local_buffers
|
||||
):
|
||||
# Allocate local buffer
|
||||
local_buffers = V.local_buffer_context.local_buffers
|
||||
assert len(local_buffers.items()) == 1
|
||||
local_buffer = next(iter(local_buffers.items()))[1]
|
||||
# For dynamic size, rename s to ks
|
||||
local_buf_size = sympy_product(
|
||||
[
|
||||
self.rename_indexing(size_val)
|
||||
for size_val in local_buffer.get_layout().size
|
||||
]
|
||||
)
|
||||
local_buf_dtype = DTYPE_TO_CPP[local_buffer.get_layout().dtype]
|
||||
allocate = f"std::make_unique<{local_buf_dtype} []>({cexpr(local_buf_size)})"
|
||||
code.splice(
|
||||
f"std::unique_ptr<{local_buf_dtype} []> local_buffer = {allocate};"
|
||||
)
|
||||
local_buffer_name = local_buffer.get_name()
|
||||
code.splice(
|
||||
f"{local_buf_dtype}* {local_buffer_name} = local_buffer.get();"
|
||||
)
|
||||
gen_loops(loop_nest.root)
|
||||
else:
|
||||
gen_kernel(loop_nest.kernel)
|
||||
|
|
@ -3500,6 +3528,18 @@ class CppKernelProxy(CppKernel):
|
|||
return node.codegen(index_vars)
|
||||
|
||||
fn_list = [functools.partial(fn, node) for node in nodes]
|
||||
|
||||
if (
|
||||
isinstance(V.local_buffer_context, LocalBufferContext)
|
||||
and V.local_buffer_context.local_buffers
|
||||
):
|
||||
fn_list = [
|
||||
V.local_buffer_context.localize_function(
|
||||
fn,
|
||||
)
|
||||
for fn in fn_list
|
||||
]
|
||||
|
||||
var_sizes_list = [node.group[1] for node in nodes]
|
||||
self.codegen_functions(fn_list, var_sizes_list, vec_dtype)
|
||||
|
||||
|
|
@ -3807,6 +3847,159 @@ class CppScheduling(BaseScheduling):
|
|||
self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction()
|
||||
) or self.can_fuse_vertical_outer_loop(node1, node2)
|
||||
|
||||
def codegen_outer_loop_node(
|
||||
self,
|
||||
node: OuterLoopFusedSchedulerNode,
|
||||
):
|
||||
"""
|
||||
Generate the code for the outer loop fused scheduler node.
|
||||
1. Codegen with fused outer loop: depends on the analysis of
|
||||
the outer loop fused scheduler node, with or without the local buffer.
|
||||
2. If failed, fallback to standard codegen.
|
||||
"""
|
||||
kernel_group = self.kernel_group
|
||||
generated_cpp_vec_kernel_count = metrics.generated_cpp_vec_kernel_count
|
||||
cpp_kernel_proxy_list: List[CppKernelProxy] = []
|
||||
nodes_list: List[List[SchedulerNode]] = []
|
||||
assert isinstance(node, OuterLoopFusedSchedulerNode)
|
||||
|
||||
def try_outer_loop_fusion_with_local_buf(node: OuterLoopFusedSchedulerNode):
|
||||
"""
|
||||
Codegen code with fused outer loop and local Buffer.
|
||||
"""
|
||||
assert isinstance(node, OuterLoopFusedSchedulerNode)
|
||||
cpp_kernel_proxy_list.clear()
|
||||
nodes_list.clear()
|
||||
|
||||
def get_call_ranges(node: BaseSchedulerNode):
|
||||
assert isinstance(node, (SchedulerNode, FusedSchedulerNode))
|
||||
nodes: List[SchedulerNode] = node.get_nodes() # type: ignore[assignment]
|
||||
_, (group, reduction_group) = max(
|
||||
nodes, key=lambda x: int(x.is_reduction())
|
||||
).group
|
||||
call_ranges = tuple(group) + tuple(reduction_group)
|
||||
return call_ranges
|
||||
|
||||
LocalBuffer = namedtuple("LocalBuffer", ["local_buf", "global_buf"])
|
||||
local_buffers: List[LocalBuffer] = []
|
||||
if all(
|
||||
len(get_call_ranges(_node)) == node.outer_loop_fusion_depth + 1
|
||||
for _node in node.get_outer_nodes()
|
||||
):
|
||||
# Ref to the typical case of local buffer
|
||||
# in https://github.com/pytorch/pytorch/blob/
|
||||
# 1115a25c36340554442f28f9570abd42f0aface2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159
|
||||
# where the buffer is with size of last dim and contiguous.
|
||||
# Only support this typical case at first.
|
||||
for scheduler_node in node.get_nodes():
|
||||
# all users inside same OuterLoopFusedSchedulerNode
|
||||
if not scheduler_node.is_reduction() and all(
|
||||
user.node in node.get_nodes() for user in scheduler_node.users
|
||||
):
|
||||
global_buffer = scheduler_node.node
|
||||
assert isinstance(global_buffer, ir.ComputedBuffer)
|
||||
global_buffer_layout = global_buffer.get_layout()
|
||||
size_offset = node.outer_loop_fusion_depth - len(
|
||||
get_call_ranges(scheduler_node)
|
||||
)
|
||||
|
||||
def is_all_write_read_contiguous(scheduler_node):
|
||||
contiguous_index_expr = 0
|
||||
stride = 1
|
||||
for var, range in reversed(
|
||||
scheduler_node._body.var_ranges.items()
|
||||
):
|
||||
contiguous_index_expr += stride * var
|
||||
stride *= range
|
||||
write_index_expr = scheduler_node._body.writes_name2expr[
|
||||
scheduler_node.get_name()
|
||||
]
|
||||
|
||||
def is_contiguous_index(x):
|
||||
return x == contiguous_index_expr
|
||||
|
||||
return is_contiguous_index(write_index_expr) and all(
|
||||
is_contiguous_index(
|
||||
user.node._body.reads_name2expr[
|
||||
scheduler_node.get_name()
|
||||
],
|
||||
)
|
||||
for user in scheduler_node.users
|
||||
)
|
||||
|
||||
if not (
|
||||
global_buffer_layout.is_contiguous()
|
||||
and not scheduler_node.is_reduction()
|
||||
and is_all_write_read_contiguous(scheduler_node)
|
||||
):
|
||||
continue
|
||||
# Local Buffer is a view of global buffer
|
||||
local_buffer_layout = ir.FixedLayout(
|
||||
global_buffer_layout.device,
|
||||
global_buffer_layout.dtype,
|
||||
global_buffer_layout.size[size_offset:],
|
||||
global_buffer_layout.stride[size_offset:],
|
||||
)
|
||||
local_buffers.append(
|
||||
LocalBuffer(
|
||||
local_buf=ir.Buffer(
|
||||
"local_buffer_data", local_buffer_layout
|
||||
),
|
||||
global_buf=global_buffer,
|
||||
)
|
||||
)
|
||||
# At most 1 node with local buf for each OuterLoopFusedSchedulerNode
|
||||
break
|
||||
assert len(local_buffers) in [0, 1]
|
||||
|
||||
with LocalBufferContext(kernel_group.args) as scope:
|
||||
if len(local_buffers) > 0:
|
||||
scope.add_local_buffer(
|
||||
local_buffers[0].local_buf, local_buffers[0].global_buf
|
||||
)
|
||||
for _node in node.get_outer_nodes():
|
||||
assert isinstance(_node, (FusedSchedulerNode, SchedulerNode))
|
||||
cpp_kernel_proxy = CppKernelProxy(kernel_group)
|
||||
cpp_kernel_proxy.codegen_nodes(_node.get_nodes()) # type: ignore[arg-type]
|
||||
cpp_kernel_proxy_list.append(cpp_kernel_proxy)
|
||||
nodes_list.append(_node.get_nodes()) # type: ignore[arg-type]
|
||||
|
||||
if not node.check_outer_fusion_loop_level_attr(
|
||||
cpp_kernel_proxy_list, node.outer_loop_fusion_depth
|
||||
):
|
||||
return False
|
||||
metrics.cpp_outer_loop_fused_inner_counts.append(
|
||||
metrics.CppOuterLoopFusedCount(
|
||||
len(cpp_kernel_proxy_list),
|
||||
local_buffer_number=len(local_buffers),
|
||||
)
|
||||
)
|
||||
outer_fusion_cpp_kernel_proxy = node.merge_outer_fusion_kernels(
|
||||
cpp_kernel_proxy_list,
|
||||
)
|
||||
kernel_group.finalize_kernel(
|
||||
outer_fusion_cpp_kernel_proxy,
|
||||
[_node for _nodes in nodes_list for _node in _nodes],
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
if not try_outer_loop_fusion_with_local_buf(node):
|
||||
# Reset generated_cpp_vec_kernel_count to codegen again
|
||||
metrics.generated_cpp_vec_kernel_count = generated_cpp_vec_kernel_count
|
||||
cpp_kernel_proxy_list.clear()
|
||||
nodes_list.clear()
|
||||
# Similar as comment in
|
||||
# https://github.com/pytorch/pytorch/blob/469383755fe416eb1c41fa724762ad3eaecdff07/torch/_inductor/codegen/cpp.py#L3269-L3272
|
||||
# Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args.
|
||||
with torch._inductor.config.patch(inplace_buffers=False):
|
||||
for _node in node.get_outer_nodes():
|
||||
assert isinstance(_node, (FusedSchedulerNode, SchedulerNode))
|
||||
_nodes: List[SchedulerNode] = _node.get_nodes() # type: ignore[assignment]
|
||||
cpp_kernel_proxy = CppKernelProxy(kernel_group)
|
||||
cpp_kernel_proxy.codegen_nodes(_nodes)
|
||||
kernel_group.finalize_kernel(cpp_kernel_proxy, _nodes)
|
||||
|
||||
def codegen_node(
|
||||
self,
|
||||
node: Union[OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode],
|
||||
|
|
@ -3817,38 +4010,7 @@ class CppScheduling(BaseScheduling):
|
|||
kernel_group = self.kernel_group
|
||||
|
||||
if isinstance(node, OuterLoopFusedSchedulerNode):
|
||||
cpp_kernel_proxy_list: List[CppKernelProxy] = []
|
||||
nodes_list: List[List[SchedulerNode]] = []
|
||||
|
||||
for _node in node.get_outer_nodes():
|
||||
assert isinstance(_node, (FusedSchedulerNode, SchedulerNode))
|
||||
_nodes: List[SchedulerNode] = _node.get_nodes() # type: ignore[assignment]
|
||||
cpp_kernel_proxy = CppKernelProxy(kernel_group)
|
||||
cpp_kernel_proxy.codegen_nodes(_nodes)
|
||||
|
||||
cpp_kernel_proxy_list.append(cpp_kernel_proxy)
|
||||
nodes_list.append(_nodes)
|
||||
|
||||
# Note that, in the future, when every kernel can be vectorized,
|
||||
# the function select_tiling will be much easier, and we'll be able to lift
|
||||
# check_outer_fusion_loop_level_attr to the fusion phase,
|
||||
# avoiding grouping kernels at fusion time that "look like we'll be able to fuse them"
|
||||
# but then we actually won't.
|
||||
if node.check_outer_fusion_loop_level_attr(
|
||||
cpp_kernel_proxy_list, node.outer_loop_fusion_depth
|
||||
):
|
||||
# Merge the cpp_kernel_proxy_list into cpp_kernel_proxy
|
||||
outer_fusion_cpp_kernel_proxy = node.merge_outer_fusion_kernels(
|
||||
cpp_kernel_proxy_list,
|
||||
)
|
||||
kernel_group.finalize_kernel(
|
||||
outer_fusion_cpp_kernel_proxy,
|
||||
[_node for _nodes in nodes_list for _node in _nodes],
|
||||
)
|
||||
else:
|
||||
# Fall back to standard loop codegen
|
||||
for _kernel_proxy, _nodes in zip(cpp_kernel_proxy_list, nodes_list):
|
||||
kernel_group.finalize_kernel(_kernel_proxy, _nodes)
|
||||
self.codegen_outer_loop_node(node)
|
||||
else:
|
||||
nodes: List[SchedulerNode] = node.get_nodes() # type: ignore[assignment]
|
||||
cpp_kernel_proxy = CppKernelProxy(kernel_group)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from ..select_algorithm import PartialRender
|
|||
from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix
|
||||
from ..virtualized import V
|
||||
from .cpp import CppKernel, CppKernelProxy, KernelGroup
|
||||
from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferScope
|
||||
from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferContext
|
||||
|
||||
|
||||
def parse_expr_with_index_symbols(expr):
|
||||
|
|
@ -270,13 +270,11 @@ class CppTemplateKernel(CppKernel):
|
|||
if offsets:
|
||||
offsets = parse_expr_with_index_symbols(offsets)
|
||||
if epilogue_nodes:
|
||||
with LocalBufferScope(self) as scope:
|
||||
with LocalBufferContext(self.args) as scope:
|
||||
assert orig_src is not None
|
||||
if orig_src.get_name() != src.get_name():
|
||||
scope.add_local_buffer(src)
|
||||
epilogue_nodes = scope.localize_buffer(
|
||||
orig_src, src, epilogue_nodes
|
||||
)
|
||||
scope.add_local_buffer(src, orig_src)
|
||||
epilogue_nodes = scope.localize_nodes(epilogue_nodes)
|
||||
return self.store_pointwise_nodes(
|
||||
dst, epilogue_nodes, offsets, reindexers # type: ignore[arg-type]
|
||||
)
|
||||
|
|
@ -284,7 +282,7 @@ class CppTemplateKernel(CppKernel):
|
|||
if dst.get_name() != src.get_name():
|
||||
# src is local
|
||||
copy = L.copy(dst, src).data.data
|
||||
with LocalBufferScope(self) as scope:
|
||||
with LocalBufferContext(self.args) as scope:
|
||||
scope.add_local_buffer(src)
|
||||
return self.store_pointwise_nodes(dst, [copy])
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import copy
|
|||
import math
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from unittest.mock import patch
|
||||
|
||||
import sympy
|
||||
|
|
@ -12,11 +12,10 @@ import sympy
|
|||
import torch
|
||||
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
||||
from .. import ir
|
||||
from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix
|
||||
from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs
|
||||
from ..virtualized import V
|
||||
|
||||
from .common import CSEVariable, ExprPrinter, Kernel
|
||||
|
||||
from .common import CSEVariable, ExprPrinter, Kernel, KernelArgs
|
||||
|
||||
DTYPE_TO_CPP = {
|
||||
torch.float32: "float",
|
||||
|
|
@ -304,7 +303,88 @@ def value_to_cpp(value, cpp_type):
|
|||
return f"static_cast<{cpp_type}>({repr(value)})"
|
||||
|
||||
|
||||
class LocalBufferScope:
|
||||
def rewrite_index_for_function(
|
||||
localize_buffer_handler: "LocalizeBufferHandler",
|
||||
index: sympy.Expr,
|
||||
):
|
||||
# Local buffer at the inner dimensions
|
||||
snode = V.graph.scheduler.name_to_node.get(
|
||||
localize_buffer_handler.global_buf.get_name()
|
||||
)
|
||||
assert snode is not None
|
||||
scheduler_nodes = snode.get_nodes()
|
||||
_, (group, reduction_group) = max(
|
||||
scheduler_nodes, key=lambda x: int(x.is_reduction())
|
||||
).group
|
||||
call_ranges = tuple(group) + tuple(reduction_group)
|
||||
indices_to_keep = [
|
||||
f"x{len(call_ranges) - (idx + 1)}"
|
||||
for idx in range(len(localize_buffer_handler.local_buf.get_layout().size))
|
||||
]
|
||||
sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) # type: ignore[attr-defined]
|
||||
replacements = {}
|
||||
for x in sorted_symbols:
|
||||
if x.name.startswith("x") and x.name not in indices_to_keep: # type: ignore[attr-defined]
|
||||
# Only keep index used by local buffer
|
||||
replacements[x] = sympy.core.numbers.Zero()
|
||||
index = sympy_subs(index, replacements) # type: ignore[arg-type]
|
||||
return index
|
||||
|
||||
|
||||
def rewrite_index_for_nodes(
|
||||
localize_buffer_handler: "LocalizeBufferHandler",
|
||||
index: sympy.Expr,
|
||||
):
|
||||
used_vars = {s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)}
|
||||
index_vars = []
|
||||
for i in range(len(localize_buffer_handler.local_buf.get_size())):
|
||||
var = sympy_index_symbol_with_prefix(SymT.INDEX, i)
|
||||
index_vars.append(var if var in used_vars else 0)
|
||||
index = localize_buffer_handler.local_buf.layout.make_indexer()(index_vars)
|
||||
return index
|
||||
|
||||
|
||||
class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined]
|
||||
def __init__(
|
||||
self,
|
||||
inner,
|
||||
global_buf: ir.Buffer,
|
||||
local_buf: ir.Buffer,
|
||||
rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr], sympy.Expr],
|
||||
):
|
||||
super().__init__(inner)
|
||||
self.global_buf = global_buf
|
||||
self.local_buf = local_buf
|
||||
self.rewrite_index = rewrite_index
|
||||
|
||||
def localize(self, name: str, index: sympy.Expr):
|
||||
if self.global_buf and name == self.global_buf.get_name():
|
||||
assert self.rewrite_index is not None
|
||||
name = self.local_buf.get_name()
|
||||
index = self.rewrite_index(self, index)
|
||||
return name, index
|
||||
|
||||
def load(self, name: str, index: sympy.Expr):
|
||||
return self._inner.load(*self.localize(name, index))
|
||||
|
||||
def store(self, name, index, value, mode=None):
|
||||
local_buffer_name, local_buffer_index = self.localize(name, index)
|
||||
res = self._inner.store(local_buffer_name, local_buffer_index, value, mode)
|
||||
if (
|
||||
self.global_buf
|
||||
and name == self.global_buf.get_name()
|
||||
and isinstance(V.kernel, Kernel)
|
||||
):
|
||||
# Remove name of local buffer from Kernel.store_buffer_names
|
||||
# local_buffer_name is added to Kernel.store_buffer_names in Kernel.CSEProxy.store.
|
||||
V.kernel.store_buffer_names.discard(local_buffer_name)
|
||||
return res
|
||||
|
||||
def store_reduction(self, name, index, value):
|
||||
return self._inner.store_reduction(*self.localize(name, index), value)
|
||||
|
||||
|
||||
class LocalBufferContext:
|
||||
"""
|
||||
This class creates a context that helps to generate code involving Inductor IR with
|
||||
function local buffers. These buffers are constructed during the codegen process and
|
||||
|
|
@ -314,10 +394,13 @@ class LocalBufferScope:
|
|||
these buffers without exposure to the outside world.
|
||||
"""
|
||||
|
||||
def __init__(self, kernel: Kernel):
|
||||
self.kernel = kernel
|
||||
def __init__(self, kernel_args: KernelArgs):
|
||||
self.kernel_args = kernel_args
|
||||
self.exit_stack = contextlib.ExitStack()
|
||||
# Map Local Buffer name to Local Buffer
|
||||
self.local_buffers: Dict[str, ir.Buffer] = {}
|
||||
# Map Local Buffer name to Global Buffer
|
||||
self.local_to_global: Dict[str, ir.Buffer] = {}
|
||||
|
||||
def __enter__(self):
|
||||
self.exit_stack.__enter__()
|
||||
|
|
@ -330,23 +413,26 @@ class LocalBufferScope:
|
|||
|
||||
self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype))
|
||||
|
||||
original_input = self.kernel.args.input
|
||||
original_input = self.kernel_args.input
|
||||
|
||||
def input(name):
|
||||
if name in self.local_buffers:
|
||||
return name
|
||||
return original_input(name)
|
||||
|
||||
self.exit_stack.enter_context(patch.object(self.kernel.args, "input", input))
|
||||
self.exit_stack.enter_context(patch.object(self.kernel_args, "input", input))
|
||||
|
||||
original_output = self.kernel.args.output
|
||||
original_output = self.kernel_args.output
|
||||
|
||||
def output(name):
|
||||
if name in self.local_buffers:
|
||||
return name
|
||||
return original_output(name)
|
||||
|
||||
self.exit_stack.enter_context(patch.object(self.kernel.args, "output", output))
|
||||
self.exit_stack.enter_context(patch.object(self.kernel_args, "output", output))
|
||||
|
||||
# Set current LocalBufferContext into V
|
||||
self.exit_stack.enter_context(V.set_local_buffer_context(self))
|
||||
|
||||
return self
|
||||
|
||||
|
|
@ -354,53 +440,64 @@ class LocalBufferScope:
|
|||
self.local_buffers.clear()
|
||||
self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
def add_local_buffer(self, buffer: ir.Buffer):
|
||||
assert buffer.get_name() not in self.local_buffers
|
||||
self.local_buffers[buffer.get_name()] = buffer
|
||||
def add_local_buffer(
|
||||
self, local_buffer: ir.Buffer, global_buffer: Optional[ir.Buffer] = None
|
||||
):
|
||||
assert local_buffer.get_name() not in self.local_buffers
|
||||
self.local_buffers[local_buffer.get_name()] = local_buffer
|
||||
if global_buffer:
|
||||
self.local_to_global[local_buffer.get_name()] = global_buffer
|
||||
V.graph.removed_buffers.add(global_buffer.get_name())
|
||||
|
||||
def localize_buffer(
|
||||
self, global_buf: ir.Buffer, local_buf: ir.Buffer, nodes: List[ir.IRNode]
|
||||
def localize_function(
|
||||
self,
|
||||
fn: Callable[..., Any],
|
||||
rewrite_index: Callable[
|
||||
["LocalizeBufferHandler", sympy.Expr], sympy.Expr
|
||||
] = rewrite_index_for_function,
|
||||
):
|
||||
local_buffers = list(self.local_buffers.values())
|
||||
global_buffers = list(self.local_to_global.values())
|
||||
local_buf = local_buffers[0]
|
||||
global_buf = global_buffers[0]
|
||||
|
||||
def inner(node, *index_vars):
|
||||
with V.set_ops_handler(
|
||||
LocalizeBufferHandler(
|
||||
V.get_ops_handler(),
|
||||
global_buf=global_buf,
|
||||
local_buf=local_buf,
|
||||
rewrite_index=rewrite_index,
|
||||
)
|
||||
):
|
||||
return fn(node, *index_vars)
|
||||
|
||||
return inner
|
||||
|
||||
def localize_nodes(
|
||||
self,
|
||||
nodes: List[ir.IRNode],
|
||||
rewrite_index: Callable[
|
||||
["LocalizeBufferHandler", sympy.Expr], sympy.Expr
|
||||
] = rewrite_index_for_nodes,
|
||||
) -> List[ir.IRNode]:
|
||||
"""
|
||||
Localizes the buffer `global_buf` to `local_buf` in the given `nodes` and returns
|
||||
a new list of IR nodes that work on `local_buf` instead of `global_buf`, i.e., all
|
||||
the loads and stores are redirected to `local_buf`. This helps the fused loops to
|
||||
work on smaller-sized local buffers for better data locality.
|
||||
Given `local_buf` and `global_buf` registered in current `LocalBufferContext`
|
||||
though the method of `add_local_buffer`, localizes the `global_buf` to `local_buf`
|
||||
for the given `nodes` and returns a new list of IR nodes that work on `local_buf`
|
||||
instead of `global_buf`, i.e., all the loads and stores are redirected to
|
||||
`local_buf`. This helps the fused loops to work on smaller-sized local buffers
|
||||
for better data locality.
|
||||
|
||||
The `local_buf` should already be registered in the local scope and the data access
|
||||
is assumed to be contiguous with the same order as the `global_buf`.
|
||||
The the data access of `local_buf` is assumed to be contiguous with the
|
||||
same order as the `global_buf`.
|
||||
"""
|
||||
assert local_buf.get_name() in self.local_buffers
|
||||
assert len(global_buf.get_size()) == len(local_buf.get_size())
|
||||
local_buffers = list(self.local_buffers.values())
|
||||
global_buffers = list(self.local_to_global.values())
|
||||
assert len(global_buffers[0].get_size()) == len(local_buffers[0].get_size())
|
||||
assert len(nodes) > 0
|
||||
|
||||
class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined]
|
||||
def __init__(self, inner):
|
||||
super().__init__(inner)
|
||||
|
||||
def localize(self, name: str, index: sympy.Expr):
|
||||
if name == global_buf.get_name():
|
||||
name = local_buf.get_name()
|
||||
used_vars = {
|
||||
s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)
|
||||
}
|
||||
index_vars = []
|
||||
for i in range(len(local_buf.get_size())):
|
||||
var = sympy_index_symbol_with_prefix(SymT.INDEX, i)
|
||||
index_vars.append(var if var in used_vars else 0)
|
||||
index = local_buf.layout.make_indexer()(index_vars)
|
||||
return name, index
|
||||
|
||||
def load(self, name: str, index: sympy.Expr):
|
||||
return self._inner.load(*self.localize(name, index))
|
||||
|
||||
def store(self, name, index, value, mode=None):
|
||||
return self._inner.store(*self.localize(name, index), value, mode)
|
||||
|
||||
def store_reduction(self, name, index, value):
|
||||
return self._inner.store_reduction(*self.localize(name, index), value)
|
||||
|
||||
def wrap_inner_fn_for_node(node: ir.IRNode, inner_fn_wrapper):
|
||||
def wrap_inner_fn_for_node(node: ir.IRNode):
|
||||
loops = node.data if isinstance(node, ir.ComputedBuffer) else node
|
||||
assert isinstance(loops, ir.Loops)
|
||||
new_loops = copy.copy(loops)
|
||||
|
|
@ -411,17 +508,13 @@ class LocalBufferScope:
|
|||
else:
|
||||
new_node = new_loops # type: ignore[assignment]
|
||||
|
||||
new_loops.inner_fn = inner_fn_wrapper(new_loops.inner_fn)
|
||||
new_loops.inner_fn = self.localize_function(
|
||||
new_loops.inner_fn,
|
||||
rewrite_index,
|
||||
)
|
||||
return new_node
|
||||
|
||||
def inner_fn_wrapper(inner_fn):
|
||||
def inner(index):
|
||||
with V.set_ops_handler(LocalizeBufferHandler(V.get_ops_handler())):
|
||||
return inner_fn(index)
|
||||
|
||||
return inner
|
||||
|
||||
return [wrap_inner_fn_for_node(node, inner_fn_wrapper) for node in nodes]
|
||||
return [wrap_inner_fn_for_node(node) for node in nodes]
|
||||
|
||||
|
||||
def unify_mask_base_type(
|
||||
|
|
|
|||
|
|
@ -41,9 +41,15 @@ ir_nodes_pre_fusion = 0
|
|||
# counters for tracking to_dtype inserted
|
||||
cpp_to_dtype_count = 0
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CppOuterLoopFusedCount:
|
||||
inner_kernel_number: int
|
||||
local_buffer_number: int = 0
|
||||
|
||||
|
||||
# The length counts the number of outer loop fusions.
|
||||
# Each element counts the number of inner kernels in each outer loop fusion.
|
||||
cpp_outer_loop_fused_inner_counts: List[int] = []
|
||||
cpp_outer_loop_fused_inner_counts: List[CppOuterLoopFusedCount] = []
|
||||
|
||||
num_comprehensive_padding = 0
|
||||
num_matches_for_scatter_upon_const_tensor = 0
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ from .ops_handler import ( # noqa: F401
|
|||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
from torch._inductor.codegen.cpp_utils import LocalBufferContext
|
||||
from torch._inductor.debug import DebugContext
|
||||
from torch._inductor.graph import GraphLowering
|
||||
from torch._inductor.ir import InterpreterShim
|
||||
|
|
@ -162,6 +163,9 @@ _debug: Virtualized[DebugContext] = Virtualized("debug", NullHandler)
|
|||
_interpreter: Virtualized[InterpreterShim] = Virtualized("interpreter", NullHandler)
|
||||
_aot_compilation: Virtualized[bool] = Virtualized("aot_compilation", NullHandler)
|
||||
_current_node: Virtualized[torch.fx.Node] = Virtualized("current_node", NullHandler)
|
||||
_local_buffer_context: Virtualized[LocalBufferContext] = Virtualized(
|
||||
"local_buffer_context", NullHandler
|
||||
)
|
||||
|
||||
|
||||
class OpsValue:
|
||||
|
|
@ -306,6 +310,8 @@ class _V:
|
|||
get_aot_compilation: Callable[[], Any] = _aot_compilation._get_handler
|
||||
set_current_node: Callable[[Any], Any] = _current_node._set_handler
|
||||
get_current_node: Callable[[], Any] = _current_node._get_handler
|
||||
set_local_buffer_context: Callable[[Any], Any] = _local_buffer_context._set_handler
|
||||
get_local_buffer_context: Callable[[], Any] = _local_buffer_context._get_handler
|
||||
|
||||
@property
|
||||
def ops(self) -> OpsHandler[Any]:
|
||||
|
|
@ -348,5 +354,9 @@ class _V:
|
|||
def current_node(self):
|
||||
return _current_node._get_handler()
|
||||
|
||||
@property
|
||||
def local_buffer_context(self):
|
||||
return _local_buffer_context._get_handler()
|
||||
|
||||
|
||||
V = _V()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user