From 98929ceae3873f18f4747b88cdff708fde107aa7 Mon Sep 17 00:00:00 2001
From: leslie-fang-intel
Date: Fri, 28 Jun 2024 05:43:07 -0700
Subject: [PATCH] [Inductor][CPP] Enable Local Buffer for Outer loop fusion (#126967)

**Summary**
Currently, the Inductor CPP backend [generated code](https://gist.github.com/leslie-fang-intel/98f91d43dabed581a1ffe23daf133a65#file-bf16-softmax-generated-code-wo-local-buffer-py) for `Softmax` with the BF16 data type is significantly slower than the [ATen implementation](https://github.com/pytorch/pytorch/blob/9a2beb862d9c30f037c9b2eac878ec3f9227a5e2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L149). Comparing the generated code with ATen, the performance bottleneck appears to be ATen's use of a [local buffer](https://github.com/pytorch/pytorch/blob/9a2beb862d9c30f037c9b2eac878ec3f9227a5e2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159-L160), which the Inductor-generated code lacks. In the current implementation, Inductor uses an output buffer from the kernel group args to store and load temporary results (such as `exp`), since that buffer corresponds to a `SchedulerNode`; each thread accesses its portion of the output buffer via indexing. However, because this buffer (take `exp` as an example) is only used internally within the decomposed `softmax`, it can be replaced with a thread-local buffer, similar to ATen's approach. This PR introduces the `LocalBuffer` optimization. With this change, the [new generated Inductor code with a local buffer](https://gist.github.com/leslie-fang-intel/98f91d43dabed581a1ffe23daf133a65#file-bf16-softmax-generated-code-w-local-buffer-py) for BF16 `Softmax` performs significantly better: running the benchmark [here](https://gist.github.com/leslie-fang-intel/37d81441237b5139c8295f5e6c4cd31a) for this BF16 `Softmax` case on an 8480 Xeon server shows comparable performance between the Inductor CPP backend and the ATen implementation.
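For context, a minimal sketch of the benchmark setup described above (the real harness lives in the linked gist; the input shape and dtype below follow the new `test_local_buffer_in_outer_loop_fusion` test, and the timing loop here is purely illustrative):

```python
# Illustrative only: compile the BF16 softmax pattern with the Inductor CPP
# backend (CPU tensors) and time it. The 4x12x1023x1022 shape matches the new
# test case; the linked gist is the actual benchmark behind the Xeon numbers.
import time

import torch


def fn(x):
    return torch.nn.functional.softmax(x, dim=-1)


x = torch.randn(4, 12, 1023, 1022).to(torch.bfloat16)
compiled_fn = torch.compile(fn)

# Warm-up runs so compilation time is excluded from the measurement.
for _ in range(3):
    compiled_fn(x)

iters = 10
start = time.time()
for _ in range(iters):
    compiled_fn(x)
print(f"avg latency: {(time.time() - start) / iters * 1e3:.3f} ms")
```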
**TestPlan** ``` python -u -m pytest -s -v inductor/test_cpu_repro.py -k test_local_buffer_in_outer_loop_fusion ``` **Next Step** - [ ] Support more than one Local Buffer/Global Buffer Pull Request resolved: https://github.com/pytorch/pytorch/pull/126967 Approved by: https://github.com/jgong5, https://github.com/peterbell10 --- test/inductor/test_cpu_repro.py | 44 +++- torch/_inductor/codegen/cpp.py | 232 +++++++++++++++--- .../_inductor/codegen/cpp_template_kernel.py | 12 +- torch/_inductor/codegen/cpp_utils.py | 213 +++++++++++----- torch/_inductor/metrics.py | 10 +- torch/_inductor/virtualized.py | 10 + 6 files changed, 415 insertions(+), 106 deletions(-) diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 9258a8b5915..d0ad8e3da11 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -2556,6 +2556,7 @@ class CPUReproTests(TestCase): self.common(fn, (x,)) assert metrics.generated_cpp_vec_kernel_count == 0 + @config.patch(fx_graph_cache=False) def test_outer_loop_fusion(self): def fn(x): max = torch.amax(x, dim=-1, keepdim=True) @@ -2567,8 +2568,47 @@ class CPUReproTests(TestCase): torch._dynamo.reset() metrics.reset() self.common(fn, (x,)) - assert len(metrics.cpp_outer_loop_fused_inner_counts) == 1 - assert metrics.cpp_outer_loop_fused_inner_counts[0] == 2 + self.assertEqual( + len(metrics.cpp_outer_loop_fused_inner_counts), + 1, + ) + self.assertEqual( + metrics.cpp_outer_loop_fused_inner_counts[0].inner_kernel_number, + 2, + ) + + @config.patch(fx_graph_cache=False) + def test_local_buffer_in_outer_loop_fusion(self): + def fn(x): + max = torch.nn.functional.softmax(x, dim=-1) + return x - max + + x = torch.randn(4, 12, 1023, 1022) + + with config.patch({"cpp.simdlen": None}): + torch._dynamo.reset() + metrics.reset() + self.common(fn, (x,)) + self.assertEqual( + len(metrics.cpp_outer_loop_fused_inner_counts), + 1, + ) + self.assertEqual( + metrics.cpp_outer_loop_fused_inner_counts[0].inner_kernel_number, + 3, + ) + self.assertEqual( + metrics.cpp_outer_loop_fused_inner_counts[0].local_buffer_number, + 1, + ) + # Check the number of global buffer allocation + torch._dynamo.reset() + metrics.reset() + _, code = run_and_get_cpp_code( + torch._dynamo.optimize("inductor")(fn), + x, + ) + self.assertEqual(code.count("empty_strided_cpu("), 3) def test_argmin(self): def fn(x): diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 12ae89fbdd9..9d207014408 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -7,6 +7,7 @@ import logging import math import re import sys +from collections import namedtuple from copy import copy, deepcopy from enum import Enum from typing import Any, cast, Dict, List, Optional, Sequence, Set, Tuple, Union @@ -69,6 +70,7 @@ from .cpp_utils import ( cexpr_index, DTYPE_TO_CPP, INDEX_TYPE, + LocalBufferContext, unify_mask_base_type, value_to_cpp, ) @@ -435,8 +437,6 @@ class OuterLoopFusedSchedulerNode(FusedSchedulerNode): loop_nest_list: List[LoopNestWithSplit] = [ kernel.loop_nest for kernel in cpp_kernel_proxy_list ] - metrics.cpp_outer_loop_fused_inner_counts.append(len(loop_nest_list)) - kernel_group = cpp_kernel_proxy_list[0].kernel_group def _merge_outer_fusion_loop_levels( @@ -1915,7 +1915,10 @@ class CppKernel(Kernel): threads = parallel_num_threads() assert self.call_ranges is not None kernels = loop_nest.get_kernels() - if any(isinstance(kernel, OuterLoopFusedKernel) for kernel in kernels): + has_outer_loop_kernel = any( + 
isinstance(kernel, OuterLoopFusedKernel) for kernel in kernels + ) + if has_outer_loop_kernel: assert len(kernels) == 1 assert isinstance(kernels[0], OuterLoopFusedKernel) par_depth = kernels[0].decide_parallel_depth( @@ -2045,6 +2048,31 @@ class CppKernel(Kernel): stack.enter_context(code.indent()) if loop_nest.root: + if ( + has_outer_loop_kernel + and isinstance(V.local_buffer_context, LocalBufferContext) + and V.local_buffer_context.local_buffers + ): + # Allocate local buffer + local_buffers = V.local_buffer_context.local_buffers + assert len(local_buffers.items()) == 1 + local_buffer = next(iter(local_buffers.items()))[1] + # For dynamic size, rename s to ks + local_buf_size = sympy_product( + [ + self.rename_indexing(size_val) + for size_val in local_buffer.get_layout().size + ] + ) + local_buf_dtype = DTYPE_TO_CPP[local_buffer.get_layout().dtype] + allocate = f"std::make_unique<{local_buf_dtype} []>({cexpr(local_buf_size)})" + code.splice( + f"std::unique_ptr<{local_buf_dtype} []> local_buffer = {allocate};" + ) + local_buffer_name = local_buffer.get_name() + code.splice( + f"{local_buf_dtype}* {local_buffer_name} = local_buffer.get();" + ) gen_loops(loop_nest.root) else: gen_kernel(loop_nest.kernel) @@ -3500,6 +3528,18 @@ class CppKernelProxy(CppKernel): return node.codegen(index_vars) fn_list = [functools.partial(fn, node) for node in nodes] + + if ( + isinstance(V.local_buffer_context, LocalBufferContext) + and V.local_buffer_context.local_buffers + ): + fn_list = [ + V.local_buffer_context.localize_function( + fn, + ) + for fn in fn_list + ] + var_sizes_list = [node.group[1] for node in nodes] self.codegen_functions(fn_list, var_sizes_list, vec_dtype) @@ -3807,6 +3847,159 @@ class CppScheduling(BaseScheduling): self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction() ) or self.can_fuse_vertical_outer_loop(node1, node2) + def codegen_outer_loop_node( + self, + node: OuterLoopFusedSchedulerNode, + ): + """ + Generate the code for the outer loop fused scheduler node. + 1. Codegen with fused outer loop: depends on the analysis of + the outer loop fused scheduler node, with or without the local buffer. + 2. If failed, fallback to standard codegen. + """ + kernel_group = self.kernel_group + generated_cpp_vec_kernel_count = metrics.generated_cpp_vec_kernel_count + cpp_kernel_proxy_list: List[CppKernelProxy] = [] + nodes_list: List[List[SchedulerNode]] = [] + assert isinstance(node, OuterLoopFusedSchedulerNode) + + def try_outer_loop_fusion_with_local_buf(node: OuterLoopFusedSchedulerNode): + """ + Codegen code with fused outer loop and local Buffer. 
+ """ + assert isinstance(node, OuterLoopFusedSchedulerNode) + cpp_kernel_proxy_list.clear() + nodes_list.clear() + + def get_call_ranges(node: BaseSchedulerNode): + assert isinstance(node, (SchedulerNode, FusedSchedulerNode)) + nodes: List[SchedulerNode] = node.get_nodes() # type: ignore[assignment] + _, (group, reduction_group) = max( + nodes, key=lambda x: int(x.is_reduction()) + ).group + call_ranges = tuple(group) + tuple(reduction_group) + return call_ranges + + LocalBuffer = namedtuple("LocalBuffer", ["local_buf", "global_buf"]) + local_buffers: List[LocalBuffer] = [] + if all( + len(get_call_ranges(_node)) == node.outer_loop_fusion_depth + 1 + for _node in node.get_outer_nodes() + ): + # Ref to the typical case of local buffer + # in https://github.com/pytorch/pytorch/blob/ + # 1115a25c36340554442f28f9570abd42f0aface2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159 + # where the buffer is with size of last dim and contiguous. + # Only support this typical case at first. + for scheduler_node in node.get_nodes(): + # all users inside same OuterLoopFusedSchedulerNode + if not scheduler_node.is_reduction() and all( + user.node in node.get_nodes() for user in scheduler_node.users + ): + global_buffer = scheduler_node.node + assert isinstance(global_buffer, ir.ComputedBuffer) + global_buffer_layout = global_buffer.get_layout() + size_offset = node.outer_loop_fusion_depth - len( + get_call_ranges(scheduler_node) + ) + + def is_all_write_read_contiguous(scheduler_node): + contiguous_index_expr = 0 + stride = 1 + for var, range in reversed( + scheduler_node._body.var_ranges.items() + ): + contiguous_index_expr += stride * var + stride *= range + write_index_expr = scheduler_node._body.writes_name2expr[ + scheduler_node.get_name() + ] + + def is_contiguous_index(x): + return x == contiguous_index_expr + + return is_contiguous_index(write_index_expr) and all( + is_contiguous_index( + user.node._body.reads_name2expr[ + scheduler_node.get_name() + ], + ) + for user in scheduler_node.users + ) + + if not ( + global_buffer_layout.is_contiguous() + and not scheduler_node.is_reduction() + and is_all_write_read_contiguous(scheduler_node) + ): + continue + # Local Buffer is a view of global buffer + local_buffer_layout = ir.FixedLayout( + global_buffer_layout.device, + global_buffer_layout.dtype, + global_buffer_layout.size[size_offset:], + global_buffer_layout.stride[size_offset:], + ) + local_buffers.append( + LocalBuffer( + local_buf=ir.Buffer( + "local_buffer_data", local_buffer_layout + ), + global_buf=global_buffer, + ) + ) + # At most 1 node with local buf for each OuterLoopFusedSchedulerNode + break + assert len(local_buffers) in [0, 1] + + with LocalBufferContext(kernel_group.args) as scope: + if len(local_buffers) > 0: + scope.add_local_buffer( + local_buffers[0].local_buf, local_buffers[0].global_buf + ) + for _node in node.get_outer_nodes(): + assert isinstance(_node, (FusedSchedulerNode, SchedulerNode)) + cpp_kernel_proxy = CppKernelProxy(kernel_group) + cpp_kernel_proxy.codegen_nodes(_node.get_nodes()) # type: ignore[arg-type] + cpp_kernel_proxy_list.append(cpp_kernel_proxy) + nodes_list.append(_node.get_nodes()) # type: ignore[arg-type] + + if not node.check_outer_fusion_loop_level_attr( + cpp_kernel_proxy_list, node.outer_loop_fusion_depth + ): + return False + metrics.cpp_outer_loop_fused_inner_counts.append( + metrics.CppOuterLoopFusedCount( + len(cpp_kernel_proxy_list), + local_buffer_number=len(local_buffers), + ) + ) + outer_fusion_cpp_kernel_proxy = 
node.merge_outer_fusion_kernels( + cpp_kernel_proxy_list, + ) + kernel_group.finalize_kernel( + outer_fusion_cpp_kernel_proxy, + [_node for _nodes in nodes_list for _node in _nodes], + ) + + return True + + if not try_outer_loop_fusion_with_local_buf(node): + # Reset generated_cpp_vec_kernel_count to codegen again + metrics.generated_cpp_vec_kernel_count = generated_cpp_vec_kernel_count + cpp_kernel_proxy_list.clear() + nodes_list.clear() + # Similar as comment in + # https://github.com/pytorch/pytorch/blob/469383755fe416eb1c41fa724762ad3eaecdff07/torch/_inductor/codegen/cpp.py#L3269-L3272 + # Kernels share the same global contexts like V.graph.wrapper_code, V.kernel.args. + with torch._inductor.config.patch(inplace_buffers=False): + for _node in node.get_outer_nodes(): + assert isinstance(_node, (FusedSchedulerNode, SchedulerNode)) + _nodes: List[SchedulerNode] = _node.get_nodes() # type: ignore[assignment] + cpp_kernel_proxy = CppKernelProxy(kernel_group) + cpp_kernel_proxy.codegen_nodes(_nodes) + kernel_group.finalize_kernel(cpp_kernel_proxy, _nodes) + def codegen_node( self, node: Union[OuterLoopFusedSchedulerNode, FusedSchedulerNode, SchedulerNode], @@ -3817,38 +4010,7 @@ class CppScheduling(BaseScheduling): kernel_group = self.kernel_group if isinstance(node, OuterLoopFusedSchedulerNode): - cpp_kernel_proxy_list: List[CppKernelProxy] = [] - nodes_list: List[List[SchedulerNode]] = [] - - for _node in node.get_outer_nodes(): - assert isinstance(_node, (FusedSchedulerNode, SchedulerNode)) - _nodes: List[SchedulerNode] = _node.get_nodes() # type: ignore[assignment] - cpp_kernel_proxy = CppKernelProxy(kernel_group) - cpp_kernel_proxy.codegen_nodes(_nodes) - - cpp_kernel_proxy_list.append(cpp_kernel_proxy) - nodes_list.append(_nodes) - - # Note that, in the future, when every kernel can be vectorized, - # the function select_tiling will be much easier, and we'll be able to lift - # check_outer_fusion_loop_level_attr to the fusion phase, - # avoiding grouping kernels at fusion time that "look like we'll be able to fuse them" - # but then we actually won't. 
- if node.check_outer_fusion_loop_level_attr( - cpp_kernel_proxy_list, node.outer_loop_fusion_depth - ): - # Merge the cpp_kernel_proxy_list into cpp_kernel_proxy - outer_fusion_cpp_kernel_proxy = node.merge_outer_fusion_kernels( - cpp_kernel_proxy_list, - ) - kernel_group.finalize_kernel( - outer_fusion_cpp_kernel_proxy, - [_node for _nodes in nodes_list for _node in _nodes], - ) - else: - # Fall back to standard loop codegen - for _kernel_proxy, _nodes in zip(cpp_kernel_proxy_list, nodes_list): - kernel_group.finalize_kernel(_kernel_proxy, _nodes) + self.codegen_outer_loop_node(node) else: nodes: List[SchedulerNode] = node.get_nodes() # type: ignore[assignment] cpp_kernel_proxy = CppKernelProxy(kernel_group) diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 58adcdac480..a27ed7b9e57 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -14,7 +14,7 @@ from ..select_algorithm import PartialRender from ..utils import sympy_index_symbol, sympy_index_symbol_with_prefix from ..virtualized import V from .cpp import CppKernel, CppKernelProxy, KernelGroup -from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferScope +from .cpp_utils import cexpr_index, DTYPE_TO_CPP, LocalBufferContext def parse_expr_with_index_symbols(expr): @@ -270,13 +270,11 @@ class CppTemplateKernel(CppKernel): if offsets: offsets = parse_expr_with_index_symbols(offsets) if epilogue_nodes: - with LocalBufferScope(self) as scope: + with LocalBufferContext(self.args) as scope: assert orig_src is not None if orig_src.get_name() != src.get_name(): - scope.add_local_buffer(src) - epilogue_nodes = scope.localize_buffer( - orig_src, src, epilogue_nodes - ) + scope.add_local_buffer(src, orig_src) + epilogue_nodes = scope.localize_nodes(epilogue_nodes) return self.store_pointwise_nodes( dst, epilogue_nodes, offsets, reindexers # type: ignore[arg-type] ) @@ -284,7 +282,7 @@ class CppTemplateKernel(CppKernel): if dst.get_name() != src.get_name(): # src is local copy = L.copy(dst, src).data.data - with LocalBufferScope(self) as scope: + with LocalBufferContext(self.args) as scope: scope.add_local_buffer(src) return self.store_pointwise_nodes(dst, [copy]) else: diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 886a2961aec..b27975c459f 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -4,7 +4,7 @@ import copy import math from collections import namedtuple -from typing import Dict, List, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from unittest.mock import patch import sympy @@ -12,11 +12,10 @@ import sympy import torch from torch.utils._sympy.symbol import symbol_is_type, SymT from .. 
import ir -from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix +from ..utils import IndentedBuffer, sympy_index_symbol_with_prefix, sympy_subs from ..virtualized import V -from .common import CSEVariable, ExprPrinter, Kernel - +from .common import CSEVariable, ExprPrinter, Kernel, KernelArgs DTYPE_TO_CPP = { torch.float32: "float", @@ -304,7 +303,88 @@ def value_to_cpp(value, cpp_type): return f"static_cast<{cpp_type}>({repr(value)})" -class LocalBufferScope: +def rewrite_index_for_function( + localize_buffer_handler: "LocalizeBufferHandler", + index: sympy.Expr, +): + # Local buffer at the inner dimensions + snode = V.graph.scheduler.name_to_node.get( + localize_buffer_handler.global_buf.get_name() + ) + assert snode is not None + scheduler_nodes = snode.get_nodes() + _, (group, reduction_group) = max( + scheduler_nodes, key=lambda x: int(x.is_reduction()) + ).group + call_ranges = tuple(group) + tuple(reduction_group) + indices_to_keep = [ + f"x{len(call_ranges) - (idx + 1)}" + for idx in range(len(localize_buffer_handler.local_buf.get_layout().size)) + ] + sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) # type: ignore[attr-defined] + replacements = {} + for x in sorted_symbols: + if x.name.startswith("x") and x.name not in indices_to_keep: # type: ignore[attr-defined] + # Only keep index used by local buffer + replacements[x] = sympy.core.numbers.Zero() + index = sympy_subs(index, replacements) # type: ignore[arg-type] + return index + + +def rewrite_index_for_nodes( + localize_buffer_handler: "LocalizeBufferHandler", + index: sympy.Expr, +): + used_vars = {s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX)} + index_vars = [] + for i in range(len(localize_buffer_handler.local_buf.get_size())): + var = sympy_index_symbol_with_prefix(SymT.INDEX, i) + index_vars.append(var if var in used_vars else 0) + index = localize_buffer_handler.local_buf.layout.make_indexer()(index_vars) + return index + + +class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined] + def __init__( + self, + inner, + global_buf: ir.Buffer, + local_buf: ir.Buffer, + rewrite_index: Callable[["LocalizeBufferHandler", sympy.Expr], sympy.Expr], + ): + super().__init__(inner) + self.global_buf = global_buf + self.local_buf = local_buf + self.rewrite_index = rewrite_index + + def localize(self, name: str, index: sympy.Expr): + if self.global_buf and name == self.global_buf.get_name(): + assert self.rewrite_index is not None + name = self.local_buf.get_name() + index = self.rewrite_index(self, index) + return name, index + + def load(self, name: str, index: sympy.Expr): + return self._inner.load(*self.localize(name, index)) + + def store(self, name, index, value, mode=None): + local_buffer_name, local_buffer_index = self.localize(name, index) + res = self._inner.store(local_buffer_name, local_buffer_index, value, mode) + if ( + self.global_buf + and name == self.global_buf.get_name() + and isinstance(V.kernel, Kernel) + ): + # Remove name of local buffer from Kernel.store_buffer_names + # local_buffer_name is added to Kernel.store_buffer_names in Kernel.CSEProxy.store. + V.kernel.store_buffer_names.discard(local_buffer_name) + return res + + def store_reduction(self, name, index, value): + return self._inner.store_reduction(*self.localize(name, index), value) + + +class LocalBufferContext: """ This class creates a context that helps to generate code involving Inductor IR with function local buffers. 
These buffers are constructed during the codegen process and @@ -314,10 +394,13 @@ class LocalBufferScope: these buffers without exposure to the outside world. """ - def __init__(self, kernel: Kernel): - self.kernel = kernel + def __init__(self, kernel_args: KernelArgs): + self.kernel_args = kernel_args self.exit_stack = contextlib.ExitStack() + # Map Local Buffer name to Local Buffer self.local_buffers: Dict[str, ir.Buffer] = {} + # Map Local Buffer name to Global Buffer + self.local_to_global: Dict[str, ir.Buffer] = {} def __enter__(self): self.exit_stack.__enter__() @@ -330,23 +413,26 @@ class LocalBufferScope: self.exit_stack.enter_context(patch.object(V.graph, "get_dtype", get_dtype)) - original_input = self.kernel.args.input + original_input = self.kernel_args.input def input(name): if name in self.local_buffers: return name return original_input(name) - self.exit_stack.enter_context(patch.object(self.kernel.args, "input", input)) + self.exit_stack.enter_context(patch.object(self.kernel_args, "input", input)) - original_output = self.kernel.args.output + original_output = self.kernel_args.output def output(name): if name in self.local_buffers: return name return original_output(name) - self.exit_stack.enter_context(patch.object(self.kernel.args, "output", output)) + self.exit_stack.enter_context(patch.object(self.kernel_args, "output", output)) + + # Set current LocalBufferContext into V + self.exit_stack.enter_context(V.set_local_buffer_context(self)) return self @@ -354,53 +440,64 @@ class LocalBufferScope: self.local_buffers.clear() self.exit_stack.__exit__(exc_type, exc_val, exc_tb) - def add_local_buffer(self, buffer: ir.Buffer): - assert buffer.get_name() not in self.local_buffers - self.local_buffers[buffer.get_name()] = buffer + def add_local_buffer( + self, local_buffer: ir.Buffer, global_buffer: Optional[ir.Buffer] = None + ): + assert local_buffer.get_name() not in self.local_buffers + self.local_buffers[local_buffer.get_name()] = local_buffer + if global_buffer: + self.local_to_global[local_buffer.get_name()] = global_buffer + V.graph.removed_buffers.add(global_buffer.get_name()) - def localize_buffer( - self, global_buf: ir.Buffer, local_buf: ir.Buffer, nodes: List[ir.IRNode] + def localize_function( + self, + fn: Callable[..., Any], + rewrite_index: Callable[ + ["LocalizeBufferHandler", sympy.Expr], sympy.Expr + ] = rewrite_index_for_function, + ): + local_buffers = list(self.local_buffers.values()) + global_buffers = list(self.local_to_global.values()) + local_buf = local_buffers[0] + global_buf = global_buffers[0] + + def inner(node, *index_vars): + with V.set_ops_handler( + LocalizeBufferHandler( + V.get_ops_handler(), + global_buf=global_buf, + local_buf=local_buf, + rewrite_index=rewrite_index, + ) + ): + return fn(node, *index_vars) + + return inner + + def localize_nodes( + self, + nodes: List[ir.IRNode], + rewrite_index: Callable[ + ["LocalizeBufferHandler", sympy.Expr], sympy.Expr + ] = rewrite_index_for_nodes, ) -> List[ir.IRNode]: """ - Localizes the buffer `global_buf` to `local_buf` in the given `nodes` and returns - a new list of IR nodes that work on `local_buf` instead of `global_buf`, i.e., all - the loads and stores are redirected to `local_buf`. This helps the fused loops to - work on smaller-sized local buffers for better data locality. 
+ Given `local_buf` and `global_buf` registered in current `LocalBufferContext` + though the method of `add_local_buffer`, localizes the `global_buf` to `local_buf` + for the given `nodes` and returns a new list of IR nodes that work on `local_buf` + instead of `global_buf`, i.e., all the loads and stores are redirected to + `local_buf`. This helps the fused loops to work on smaller-sized local buffers + for better data locality. - The `local_buf` should already be registered in the local scope and the data access - is assumed to be contiguous with the same order as the `global_buf`. + The the data access of `local_buf` is assumed to be contiguous with the + same order as the `global_buf`. """ - assert local_buf.get_name() in self.local_buffers - assert len(global_buf.get_size()) == len(local_buf.get_size()) + local_buffers = list(self.local_buffers.values()) + global_buffers = list(self.local_to_global.values()) + assert len(global_buffers[0].get_size()) == len(local_buffers[0].get_size()) assert len(nodes) > 0 - class LocalizeBufferHandler(V.WrapperHandler): # type: ignore[name-defined] - def __init__(self, inner): - super().__init__(inner) - - def localize(self, name: str, index: sympy.Expr): - if name == global_buf.get_name(): - name = local_buf.get_name() - used_vars = { - s for s in index.free_symbols if symbol_is_type(s, SymT.INDEX) - } - index_vars = [] - for i in range(len(local_buf.get_size())): - var = sympy_index_symbol_with_prefix(SymT.INDEX, i) - index_vars.append(var if var in used_vars else 0) - index = local_buf.layout.make_indexer()(index_vars) - return name, index - - def load(self, name: str, index: sympy.Expr): - return self._inner.load(*self.localize(name, index)) - - def store(self, name, index, value, mode=None): - return self._inner.store(*self.localize(name, index), value, mode) - - def store_reduction(self, name, index, value): - return self._inner.store_reduction(*self.localize(name, index), value) - - def wrap_inner_fn_for_node(node: ir.IRNode, inner_fn_wrapper): + def wrap_inner_fn_for_node(node: ir.IRNode): loops = node.data if isinstance(node, ir.ComputedBuffer) else node assert isinstance(loops, ir.Loops) new_loops = copy.copy(loops) @@ -411,17 +508,13 @@ class LocalBufferScope: else: new_node = new_loops # type: ignore[assignment] - new_loops.inner_fn = inner_fn_wrapper(new_loops.inner_fn) + new_loops.inner_fn = self.localize_function( + new_loops.inner_fn, + rewrite_index, + ) return new_node - def inner_fn_wrapper(inner_fn): - def inner(index): - with V.set_ops_handler(LocalizeBufferHandler(V.get_ops_handler())): - return inner_fn(index) - - return inner - - return [wrap_inner_fn_for_node(node, inner_fn_wrapper) for node in nodes] + return [wrap_inner_fn_for_node(node) for node in nodes] def unify_mask_base_type( diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py index fc7d0e6a7ab..5c86667e744 100644 --- a/torch/_inductor/metrics.py +++ b/torch/_inductor/metrics.py @@ -41,9 +41,15 @@ ir_nodes_pre_fusion = 0 # counters for tracking to_dtype inserted cpp_to_dtype_count = 0 + +@dataclasses.dataclass +class CppOuterLoopFusedCount: + inner_kernel_number: int + local_buffer_number: int = 0 + + # The length counts the number of outer loop fusions. -# Each element counts the number of inner kernels in each outer loop fusion. 
-cpp_outer_loop_fused_inner_counts: List[int] = [] +cpp_outer_loop_fused_inner_counts: List[CppOuterLoopFusedCount] = [] num_comprehensive_padding = 0 num_matches_for_scatter_upon_const_tensor = 0 diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index ac8d3c64014..51ff55a00b7 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -72,6 +72,7 @@ from .ops_handler import ( # noqa: F401 if TYPE_CHECKING: import torch + from torch._inductor.codegen.cpp_utils import LocalBufferContext from torch._inductor.debug import DebugContext from torch._inductor.graph import GraphLowering from torch._inductor.ir import InterpreterShim @@ -162,6 +163,9 @@ _debug: Virtualized[DebugContext] = Virtualized("debug", NullHandler) _interpreter: Virtualized[InterpreterShim] = Virtualized("interpreter", NullHandler) _aot_compilation: Virtualized[bool] = Virtualized("aot_compilation", NullHandler) _current_node: Virtualized[torch.fx.Node] = Virtualized("current_node", NullHandler) +_local_buffer_context: Virtualized[LocalBufferContext] = Virtualized( + "local_buffer_context", NullHandler +) class OpsValue: @@ -306,6 +310,8 @@ class _V: get_aot_compilation: Callable[[], Any] = _aot_compilation._get_handler set_current_node: Callable[[Any], Any] = _current_node._set_handler get_current_node: Callable[[], Any] = _current_node._get_handler + set_local_buffer_context: Callable[[Any], Any] = _local_buffer_context._set_handler + get_local_buffer_context: Callable[[], Any] = _local_buffer_context._get_handler @property def ops(self) -> OpsHandler[Any]: @@ -348,5 +354,9 @@ class _V: def current_node(self): return _current_node._get_handler() + @property + def local_buffer_context(self): + return _local_buffer_context._get_handler() + V = _V()
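As a side note on the index rewriting added in `cpp_utils.py`: `rewrite_index_for_function` keeps only the index symbols that address the local buffer's innermost dimensions and zeroes out the rest, which is valid because only the contiguous last-dimension case is supported for now. Below is a standalone sketch of that substitution using plain sympy; the shape and symbol names are illustrative, and the real code operates on `ir.Buffer` layouts and `SymT.INDEX` symbols via `sympy_subs`.

```python
# Standalone illustration of the symbol-zeroing substitution performed by
# rewrite_index_for_function: for a [4, 12, 1023, 1022] global buffer and a
# [1022] local buffer, only the innermost index symbol (x3) is kept and all
# other loop-index symbols are replaced with 0.
import sympy

x0, x1, x2, x3 = sympy.symbols("x0 x1 x2 x3")

# Contiguous index into the global buffer of shape [4, 12, 1023, 1022].
global_index = ((x0 * 12 + x1) * 1023 + x2) * 1022 + x3

indices_to_keep = {x3}  # the local buffer covers only the last dimension
replacements = {
    s: sympy.Integer(0)
    for s in global_index.free_symbols
    if s not in indices_to_keep
}
local_index = global_index.subs(replacements)
print(local_index)  # x3
```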