mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This pr uses the coalescing information in generating a tiling. The previous tiling heuristic would have each dependency generate a tiling. Then, we sum up the score for each generated tiling, preferring any 2d tiling over the default. The new tiling heuristics scores each tiling by its global coalesced memory. This gives both a potentially better tiling (especially for more complicated, 3d patterns) as well as information we can use in generating block sizes. In triton heuristics, for generating 3d tiled reductions, we take the same total block size that the 2d reduction would use, then distribute the block according to whichever block coalesces the most memory. The motivating kernel is in https://github.com/pytorch/pytorch/issues/149982 which is a 32 element reduction. A smaller version of it is [here](https://gist.github.com/eellison/0fa9396f5479eb4dba09756e3bf6ff2a). We need to run this kernel once in the forward per linear layer on a contiguous tensor, and once in the backward on a transposed tensor. While the contiguous kernel has coalesced accesses, and is performant on master, the transposed version accesses uncoalesced memory on main and is ~2.8x slower. See, this [full log](https://gist.github.com/eellison/fa644bfd9d0ae11dadb62e17a5d48a83) from the above repro. Now, with this PR, it is only ~1.15x slower. See the [updated log](https://gist.github.com/eellison/0b2b653309494d28cf7b48929a022075). Pull Request resolved: https://github.com/pytorch/pytorch/pull/153751 Approved by: https://github.com/jansel ghstack dependencies: #153723, #153730, #153748
756 lines
25 KiB
Python
756 lines
25 KiB
Python
import dataclasses
|
|
import functools
|
|
import itertools
|
|
import sys
|
|
from collections import Counter, defaultdict
|
|
from collections.abc import Iterable, Iterator
|
|
from typing import Callable, Literal, Optional, overload, TYPE_CHECKING, TypeVar, Union
|
|
|
|
import sympy
|
|
|
|
import torch
|
|
from torch._inductor import config
|
|
from torch._inductor.dependencies import index_vars_no_squeeze
|
|
from torch._inductor.utils import sympy_product, sympy_subs
|
|
from torch.utils._ordered_set import OrderedSet
|
|
from torch.utils._sympy.functions import Identity
|
|
from torch.utils._sympy.solve import try_solve
|
|
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
|
|
|
from .virtualized import V
|
|
|
|
|
|
T = TypeVar("T")
|
|
U = TypeVar("U")
|
|
|
|
|
|
Split = tuple[sympy.Expr, ...]
|
|
VarsAndRanges = tuple[list[sympy.Symbol], list[sympy.Expr]]
|
|
|
|
|
|
loop_tiling_log = torch._logging.getArtifactLogger(__name__, "loop_tiling")
|
|
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from torch._inductor.scheduler import FusedSchedulerNode, SchedulerNode
|
|
|
|
|
|
def solve_for_zero(expr: sympy.Expr) -> Optional[sympy.Expr]:
|
|
"""
|
|
Given an expr with a single free symbol, solve for a constant relation that would make
|
|
this expression 0.
|
|
"""
|
|
if expr.is_constant():
|
|
return None
|
|
elif isinstance(expr, FloorDiv):
|
|
return None
|
|
|
|
assert len(expr.free_symbols) == 1
|
|
free_symbol = next(iter(expr.free_symbols))
|
|
if isinstance(expr, ModularIndexing):
|
|
out = try_solve(sympy.Eq(expr.args[0], expr.args[2]), free_symbol)
|
|
else:
|
|
out = try_solve(sympy.Eq(expr, 0), free_symbol)
|
|
if not out or not out[1].is_constant():
|
|
return None
|
|
return out[1]
|
|
|
|
|
|
def solve_for_tiling(expr: sympy.Expr) -> Optional[sympy.Expr]:
|
|
"""
|
|
Giving an expr with a single free symbol, try to find a tiling that would
|
|
make the expression coalesced with respect to that symbol.
|
|
|
|
Tiling an expression `x` by `y` means that the expression will now be indexed
|
|
by both the original (x) and by (x * y). So we are looking for a
|
|
multiplicative factor that will make ((x + 1) * y) - (x * y) == 1.
|
|
|
|
To simplify things for sympy, we'll try just x * y == 1, check x(1) and x(0).
|
|
"""
|
|
|
|
if len(expr.free_symbols) == 0:
|
|
return None
|
|
|
|
free_symbol = next(iter(expr.free_symbols))
|
|
|
|
def _solve_simple_expr(expr: sympy.Expr) -> Optional[sympy.Expr]:
|
|
assert not expr.has(ModularIndexing) and not expr.has(FloorDiv)
|
|
if len(expr.free_symbols) != 1:
|
|
return None
|
|
|
|
out = try_solve(sympy.Eq(expr, 1), free_symbol)
|
|
if not out or not out[1].is_constant():
|
|
return None
|
|
return out[1]
|
|
|
|
# Sympy solving is very limited with ModularIndexing and FloorDiv,
|
|
# but good otherwise.
|
|
if not expr.has(ModularIndexing) and not expr.has(FloorDiv):
|
|
return _solve_simple_expr(expr)
|
|
|
|
required_values = []
|
|
eq_1_expressions = []
|
|
|
|
# very piecemeal solution if ModularIndexing or FloorDiv involved.
|
|
# Look for terms we'll try to make 0, and then other terms we'll try to make 1.
|
|
# Expand as needed.
|
|
for arg in sympy.Add.make_args(expr):
|
|
# Try to make mul terms 0
|
|
if isinstance(arg, sympy.Mul):
|
|
seen = False
|
|
# TODO - only need one of these to be solvable to zero
|
|
#
|
|
for mul_arg in arg.args:
|
|
out = solve_for_zero(mul_arg)
|
|
if out is None:
|
|
continue
|
|
|
|
assert out.is_constant()
|
|
seen = True
|
|
required_values.append(out)
|
|
|
|
if not seen:
|
|
return None
|
|
else:
|
|
eq_1_expressions.append(arg)
|
|
|
|
if not eq_1_expressions:
|
|
return None
|
|
|
|
eq_1_expr = sum(eq_1_expressions)
|
|
|
|
def indexing_div_rep(
|
|
x: sympy.Expr,
|
|
y: sympy.Expr,
|
|
z: Optional[sympy.Expr] = None,
|
|
) -> sympy.Expr:
|
|
return x / y
|
|
|
|
# For the purposes of tiling/coalesced access, approximate ModularIndexing and FloorDiv
|
|
# then check later
|
|
eq_1_expr_simplified = eq_1_expr.replace(ModularIndexing, indexing_div_rep).replace(
|
|
FloorDiv, indexing_div_rep
|
|
)
|
|
|
|
out = _solve_simple_expr(eq_1_expr_simplified)
|
|
# since we approximated FloorDiv/ModularIndexing, double check here
|
|
if not out or not (sympy_subs(eq_1_expr, {free_symbol: out})) == 1:
|
|
return None
|
|
|
|
required_values.append(out)
|
|
|
|
if len(OrderedSet(required_values)) == 1:
|
|
return required_values[0]
|
|
|
|
return None
|
|
|
|
|
|
def find_coalesced_var(
|
|
index: sympy.Expr, var_ranges: dict[sympy.Expr, int]
|
|
) -> Optional[sympy.Expr]:
|
|
"""
|
|
Try to find the symbol which coalesces this index
|
|
"""
|
|
top_level_terms = sympy.Add.make_args(index)
|
|
for v in var_ranges:
|
|
if v in top_level_terms:
|
|
return v
|
|
|
|
# Approximate analysis by evaluating at 1 and 0
|
|
variables: dict[sympy.Symbol, int] = {}
|
|
for v in index.free_symbols:
|
|
if v in var_ranges:
|
|
variables[v] = 0
|
|
else:
|
|
variables[v] = get_hint(v)
|
|
|
|
zero_index = sympy_subs(index, variables)
|
|
for v in var_ranges.keys():
|
|
variables[v] = 1
|
|
try:
|
|
new_val = sympy_subs(index, variables)
|
|
except ZeroDivisionError:
|
|
loop_tiling_log.info("zero division error %s %s", index, variables)
|
|
continue
|
|
if new_val - zero_index == 1:
|
|
variables[v] = 2
|
|
# in some more complex expressions, 0->1 will be coalesced,
|
|
# but not 1->2
|
|
if (sympy_subs(index, variables) - new_val) == 1:
|
|
return v
|
|
variables[v] = 0
|
|
|
|
return None
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class FusedNormalizedReadsWrites:
|
|
"""
|
|
Normalized reads and writes for nodes in the same FusedSchedulerNode.
|
|
"""
|
|
|
|
index_vars: OrderedSet[sympy.Symbol]
|
|
reduce_vars: OrderedSet[sympy.Symbol]
|
|
reads: dict[sympy.Expr, OrderedSet[str]]
|
|
writes: dict[sympy.Expr, OrderedSet[str]]
|
|
var_ranges: dict[sympy.Symbol, int]
|
|
|
|
|
|
@overload
|
|
def get_pw_red_splits(
|
|
n: "SchedulerNode",
|
|
pointwise_numel: sympy.Expr,
|
|
red_numel: sympy.Expr,
|
|
none_if_not_divisible: Literal[True],
|
|
) -> Optional[tuple[VarsAndRanges, VarsAndRanges]]: ...
|
|
|
|
|
|
@overload
|
|
def get_pw_red_splits(
|
|
n: "SchedulerNode",
|
|
pointwise_numel: sympy.Expr,
|
|
red_numel: sympy.Expr,
|
|
none_if_not_divisible: Literal[False] = False,
|
|
) -> tuple[VarsAndRanges, VarsAndRanges]: ...
|
|
|
|
|
|
def get_pw_red_splits(
|
|
n: "SchedulerNode",
|
|
pointwise_numel: sympy.Expr,
|
|
red_numel: sympy.Expr,
|
|
none_if_not_divisible: bool = False,
|
|
) -> Optional[tuple[VarsAndRanges, VarsAndRanges]]:
|
|
if n.is_reduction() or sympy_product(n._body.sizes[0]) == pointwise_numel:
|
|
return (
|
|
(n._body.iter_vars, n._body.sizes[0]),
|
|
(n._body.reduce_vars, n._body.sizes[1]),
|
|
) # type: ignore[return-value]
|
|
|
|
assert sympy_product(n._body.sizes[0]) == pointwise_numel * red_numel # type: ignore[operator]
|
|
i = len(n._body.sizes[0]) - 1
|
|
prod = 1
|
|
while i >= 0:
|
|
prod *= n._body.sizes[0][i]
|
|
if prod == red_numel:
|
|
break
|
|
i -= 1
|
|
|
|
if i >= 0:
|
|
pw_splits = n._body.sizes[0][0:i]
|
|
iter_vars = n._body.iter_vars[0:i]
|
|
|
|
red_splits = n._body.sizes[0][i:]
|
|
red_vars = n._body.iter_vars[i:]
|
|
return (iter_vars, pw_splits), (red_vars, red_splits) # type: ignore[return-value]
|
|
|
|
if none_if_not_divisible:
|
|
return None
|
|
else:
|
|
return (
|
|
(n._body.iter_vars, n._body.sizes[0]),
|
|
(n._body.reduce_vars, n._body.sizes[1]),
|
|
) # type: ignore[return-value]
|
|
|
|
|
|
class NodeSplitGetter:
|
|
"""
|
|
Finds a Pointwise, Reduction Split that compatible with all nodes in a SchedulerNode.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
node: Union["FusedSchedulerNode", "SchedulerNode"],
|
|
):
|
|
self.node = node
|
|
self.pointwise_numel: sympy.Expr = node.group[1][0]
|
|
self.red_numel: sympy.Expr = node.group[1][1]
|
|
|
|
self.pw_split_options: dict[int, OrderedSet[Split]] = defaultdict(OrderedSet)
|
|
|
|
self.reduction_split: Split = ()
|
|
self.all_node_sizes: OrderedSet[tuple[Split, Split]] = OrderedSet()
|
|
|
|
fused_group = node.group[1]
|
|
for n in reversed(node.get_nodes()):
|
|
if not isinstance(n, torch._inductor.scheduler.SchedulerNode):
|
|
continue
|
|
|
|
# if we can't split the pw ranges into a (pw, red) split,
|
|
# dont add as a split option, but do make sure we check that this size
|
|
# is splittable
|
|
maybe_splits = get_pw_red_splits(
|
|
n, self.pointwise_numel, self.red_numel, none_if_not_divisible=True
|
|
)
|
|
if maybe_splits is None:
|
|
self.all_node_sizes.add(n._body.sizes)
|
|
continue
|
|
|
|
(_, n_pw_splits), (_, n_red_splits) = maybe_splits
|
|
|
|
# fill in reduction size
|
|
n_pw_splits, n_red_splits = (
|
|
torch._inductor.codegen.simd.SIMDKernel.prepare_split_iteration_lengths(
|
|
fused_group, (n_pw_splits, n_red_splits), self.red_numel
|
|
)
|
|
)
|
|
|
|
self.pw_split_options[len(n_pw_splits)].add(tuple(n_pw_splits))
|
|
|
|
# initially, we are just going to do a single reduction split since
|
|
# reduction tiling is off by default. even if we miss a reduction split,
|
|
# we can recover it in the split var analysis.
|
|
# TODO: an earlier version fo this code tried to iteratively try the maximum number
|
|
# of split vars, by iterating over both pointwise and reduction. but not worth
|
|
# the complexity yet.
|
|
|
|
if n_red_splits != ():
|
|
self.reduction_split = (sympy_product(n_red_splits),)
|
|
|
|
n_size = (tuple(n_pw_splits), tuple(n_red_splits))
|
|
self.all_node_sizes.add(n_size)
|
|
|
|
self.seen_pw_splits: OrderedSet[Split] = OrderedSet()
|
|
|
|
def get_node_splits(self) -> tuple[Split, Split]:
|
|
"""
|
|
Get a compatible pointwise, reduction split of the node
|
|
"""
|
|
|
|
if len(self.all_node_sizes) == 1:
|
|
return next(iter(self.all_node_sizes))
|
|
|
|
max_pw_split = max(self.pw_split_options.keys())
|
|
for pw_split_len in range(max_pw_split, 0, -1):
|
|
for pw_split in self.pw_split_options[pw_split_len]:
|
|
if out := self.try_split(pw_split, self.reduction_split):
|
|
return out
|
|
|
|
# combine dims for next round
|
|
for pw_split in self.pw_split_options[pw_split_len]:
|
|
for i in range(len(pw_split) - 1):
|
|
new_split = tuple(
|
|
pw_split[0:i]
|
|
+ (sympy_product(pw_split[i : i + 2]),)
|
|
+ pw_split[i + 2 :]
|
|
)
|
|
self.pw_split_options[len(new_split)].add(new_split)
|
|
|
|
# if for whatever reason we couldnt split above, return default split
|
|
return ((self.pointwise_numel,), (self.red_numel,))
|
|
|
|
def try_split(self, pw: Split, red: Split) -> Optional[tuple[Split, Split]]:
|
|
"""
|
|
See if this split is compatible, and potentially returning a longer split
|
|
than the input.
|
|
"""
|
|
|
|
from torch._inductor.codegen.simd import CantSplit, SIMDKernel
|
|
|
|
if pw in self.seen_pw_splits:
|
|
return None
|
|
self.seen_pw_splits.add(pw)
|
|
|
|
for n_pw, n_red in self.all_node_sizes:
|
|
try:
|
|
groups = pw + red
|
|
lengths = (n_pw, n_red)
|
|
splits, getters = SIMDKernel._split_iteration_ranges(groups, lengths)
|
|
except CantSplit:
|
|
return None
|
|
|
|
assert len(getters) == 2
|
|
pw_group_splits = splits[: len(pw)]
|
|
# if we had to divide a variable into two to do this split,
|
|
# then lets try the larger, induced split.
|
|
# e.g. splitting (12, 2) into (2, 12) will split the first var into:
|
|
# (2, 6) and produce an overall split of (2, 6, 2)
|
|
flattened_pw_splits = tuple(itertools.chain.from_iterable(pw_group_splits))
|
|
if flattened_pw_splits != pw:
|
|
if out := self.try_split(flattened_pw_splits, red):
|
|
return out
|
|
|
|
return pw, red
|
|
|
|
|
|
if sys.version_info >= (3, 10):
|
|
# On Python 3.10+ we can use zip(strict=True)
|
|
zip_equal = functools.partial(zip, strict=True)
|
|
else:
|
|
# Fallback for older versions
|
|
def zip_equal(it1: Iterable[T], it2: Iterable[U]) -> Iterator[tuple[T, U]]:
|
|
"""
|
|
Zip two iterables, raising ValueError if their lengths differ.
|
|
"""
|
|
if len(it1) != len(it2):
|
|
raise ValueError(f"Lengths differ: {len(it1)} != {len(it2)}")
|
|
return zip(it1, it2)
|
|
|
|
|
|
def apply_var_mapping(
|
|
iter_vars: list[sympy.Symbol],
|
|
red_vars: list[sympy.Symbol],
|
|
norm_pw_vars: list[sympy.Symbol],
|
|
norm_red_vars: list[sympy.Symbol],
|
|
new_ranges: list[list[sympy.Expr]],
|
|
return_getters_groups: list[list[Callable[[list[sympy.Expr]], sympy.Expr]]],
|
|
) -> dict[sympy.Symbol, sympy.Expr]:
|
|
"""Maps original variables to expressions using normalized variables."""
|
|
|
|
# the output of split_iteration_range is a new_ranges, return_getters_groups
|
|
# new_ranges is a flattened list of ranges corresponding to the new pw and red vars
|
|
# for example, taking in pw vars of range (6, 6) to normalized range [36],
|
|
# new_ranges would be [[6, 6]]
|
|
# There is a return_getter callable for each input iter_var and red_vars.
|
|
# if you flatten out all of the ranges, and create a variable for each index,
|
|
# then applying the flattening vars to the callables in return_getters_groups
|
|
# gives you the mapping from input vars -> flattened vars.
|
|
# From there, we can compute the output, normalized variables.
|
|
# For instance [6, 6] corresponding to flat vars v0, v1 will be
|
|
# v0 + 6 * v1
|
|
|
|
# Create flattened iteration variables
|
|
num_vars = sum(len(s) for s in new_ranges)
|
|
flat_vars = sympy.symbols(f"v_0:{num_vars}")
|
|
count = 0
|
|
|
|
if len(iter_vars) == 0 and len(red_vars) == 0:
|
|
return {}
|
|
|
|
assert len(new_ranges) == len(norm_pw_vars + norm_red_vars)
|
|
apply_groups = []
|
|
for group in return_getters_groups:
|
|
apply_groups.append([g(flat_vars) for g in group])
|
|
|
|
iter_vars_to_flat_vars = {}
|
|
for i, (group, var_group) in enumerate(
|
|
zip_equal(apply_groups, ((iter_vars, red_vars)))
|
|
):
|
|
# if the node has sizes (p0, 1) and the fused node is (p0, r0)
|
|
# the reduction var gets filled in for split_iteration_range
|
|
if len(group) != len(var_group):
|
|
assert i == 1
|
|
assert len(var_group) == 0
|
|
continue
|
|
|
|
iter_vars_to_flat_vars.update({v: g for g, v in zip(group, var_group)})
|
|
|
|
count = 0
|
|
flat_vars_to_new_vars = {}
|
|
for new_range, new_var in zip_equal(new_ranges, norm_pw_vars + norm_red_vars):
|
|
range_vars = []
|
|
for i in range(len(new_range)):
|
|
range_vars.append(flat_vars[count])
|
|
count += 1
|
|
|
|
prod = 1
|
|
for i in range(len(new_range) - 1, -1, -1):
|
|
flat_vars_to_new_vars[range_vars[i]] = new_var * prod
|
|
prod = new_range[i] * prod
|
|
|
|
return {
|
|
k: sympy_subs(v, flat_vars_to_new_vars)
|
|
for k, v in iter_vars_to_flat_vars.items()
|
|
}
|
|
|
|
|
|
def extract_normalized_read_writes(
|
|
node: Union["FusedSchedulerNode", "SchedulerNode"],
|
|
) -> Optional[FusedNormalizedReadsWrites]:
|
|
"""Extracts index variables, reduce variables, read/write expressions, and variable ranges from a fused node."""
|
|
reads: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)
|
|
writes: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)
|
|
|
|
all_output_names = node.get_buffer_names()
|
|
op_names = node.get_operation_names()
|
|
outputs: OrderedSet[str] = OrderedSet()
|
|
removed_buffers: OrderedSet[str] = OrderedSet()
|
|
for buf_name in all_output_names:
|
|
if V.graph.scheduler.can_buffer_be_removed_through_fusion(buf_name, op_names):
|
|
removed_buffers.add(buf_name)
|
|
else:
|
|
outputs.add(buf_name)
|
|
|
|
inputs = OrderedSet(
|
|
dep.name for dep in node.read_writes.reads if dep.name not in removed_buffers
|
|
)
|
|
|
|
pointwise_numel: sympy.Expr = node.group[1][0]
|
|
red_numel: sympy.Expr = node.group[1][1]
|
|
|
|
# TODO - a few dynamic shapes issues to resolve
|
|
if any(
|
|
(isinstance(var, sympy.Expr) and not var.is_constant())
|
|
for var in (pointwise_numel, red_numel)
|
|
):
|
|
return None
|
|
|
|
pw_splits, red_splits = NodeSplitGetter(node).get_node_splits()
|
|
|
|
# lets use different prefix (`n`) to distinguish
|
|
(norm_pw_vars, norm_red_vars), ranges = index_vars_no_squeeze(
|
|
pw_splits, red_splits, prefix="n"
|
|
)
|
|
node = node
|
|
|
|
for n in list(node.get_nodes()):
|
|
if not isinstance(n, torch._inductor.scheduler.SchedulerNode):
|
|
continue
|
|
|
|
body = n._body
|
|
|
|
# TODO - not handled well. indirect loads will not be coalesced,
|
|
# need to account for that in analysis.
|
|
if body.indirect_vars:
|
|
return None
|
|
|
|
n_reads: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)
|
|
n_writes: dict[sympy.Expr, OrderedSet[str]] = defaultdict(OrderedSet)
|
|
|
|
# TODO - will the names for all the inputs/outputs accurately
|
|
# reflect mutation, or do I need to remap with mutation_real_name
|
|
for inp in inputs:
|
|
for expr in body.get_all_read_expr(inp):
|
|
n_reads[expr].add(inp)
|
|
|
|
for out in outputs:
|
|
for expr in body.get_all_write_expr(out):
|
|
n_writes[expr].add(out)
|
|
|
|
if not n_reads and not n_writes:
|
|
continue
|
|
|
|
(iter_vars, n_pw_splits), (red_vars, n_red_splits) = get_pw_red_splits(
|
|
n, pointwise_numel, red_numel
|
|
)
|
|
|
|
groups = pw_splits + red_splits
|
|
lengths = (n_pw_splits, (n_red_splits))
|
|
lengths = (
|
|
torch._inductor.codegen.simd.SIMDKernel.prepare_split_iteration_lengths(
|
|
groups, lengths, red_numel
|
|
)
|
|
)
|
|
new_ranges, return_getters_groups = (
|
|
torch._inductor.codegen.simd.SIMDKernel._split_iteration_ranges(
|
|
groups, lengths
|
|
)
|
|
)
|
|
var_map = apply_var_mapping(
|
|
iter_vars,
|
|
red_vars,
|
|
norm_pw_vars,
|
|
norm_red_vars,
|
|
new_ranges,
|
|
return_getters_groups,
|
|
)
|
|
|
|
# We create Identity sympy.Functions to prevent expansion to int64,
|
|
# unwrap for tiling analysis.
|
|
def remove_identity(expr: sympy.Expr) -> sympy.Expr:
|
|
return expr.replace(Identity, lambda x: x)
|
|
|
|
n_reads_new = {
|
|
sympy_subs(remove_identity(read), var_map): v for read, v in n_reads.items()
|
|
}
|
|
n_writes_new = {
|
|
sympy_subs(remove_identity(write), var_map): v
|
|
for write, v in n_writes.items()
|
|
}
|
|
|
|
for expr, buf_names in n_reads_new.items():
|
|
reads[expr] |= buf_names
|
|
|
|
for expr, buf_names in n_writes_new.items():
|
|
writes[expr] |= buf_names
|
|
|
|
reads = {
|
|
V.graph.sizevars.simplify_with_ranges(r, ranges): v for r, v in reads.items()
|
|
}
|
|
writes = {
|
|
V.graph.sizevars.simplify_with_ranges(w, ranges): v for w, v in writes.items()
|
|
}
|
|
|
|
fused_out = FusedNormalizedReadsWrites(
|
|
norm_pw_vars, # type: ignore[arg-type]
|
|
norm_red_vars, # type: ignore[arg-type]
|
|
reads,
|
|
writes,
|
|
ranges,
|
|
)
|
|
loop_tiling_log.info("Normalized Fused reads: %s", fused_out)
|
|
return fused_out
|
|
|
|
|
|
def get_score(addr: sympy.Expr, var_ranges: dict[sympy.Symbol, int]) -> int:
|
|
"""
|
|
Score addr according to its approximate size
|
|
"""
|
|
|
|
# TODO - deduplicate with candidate_tilings
|
|
var_sizes = []
|
|
for v in addr.free_symbols:
|
|
v_size = var_ranges.get(v, None)
|
|
# TODO - reason about indirect vars
|
|
if not symbol_is_type(v, SymT.INDIRECT) and v_size is not None:
|
|
var_sizes.append(v_size)
|
|
from .virtualized import V
|
|
|
|
return V.graph.sizevars.atomically_apply_size_hint(
|
|
sympy_product(var_sizes), fallback=config.unbacked_symint_fallback
|
|
)
|
|
|
|
|
|
def get_hint(v: Union[sympy.Expr, int]) -> int:
|
|
if isinstance(v, int):
|
|
return v
|
|
else:
|
|
return V.graph.sizevars.size_hint(v, fallback=config.unbacked_symint_fallback)
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class VarTiling:
|
|
"""
|
|
Tiling of a var by `tiling_factor` that yields additional coalesced mem accesses by `benefit_score`
|
|
"""
|
|
|
|
var: sympy.Symbol
|
|
tiling_factor: int
|
|
score: int
|
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
|
|
class CoalesceVarAnalysis:
|
|
coalesced_by_var: dict[sympy.Expr, int]
|
|
|
|
norm_read_writes: FusedNormalizedReadsWrites
|
|
|
|
suggested_split: Optional[VarTiling] = None
|
|
|
|
|
|
def analyze_memory_coalescing(
|
|
fused_node: Union["FusedSchedulerNode", "SchedulerNode"],
|
|
) -> Optional[CoalesceVarAnalysis]:
|
|
"""
|
|
Find variables that coalesce the reads and writes and score the total size.
|
|
|
|
If uncoalesced memory expressions are found, look for additionally tiling of variables
|
|
which will coalesce memory accesses.
|
|
|
|
For instance - for the following expression:
|
|
|
|
(32*p0) // 2048
|
|
|
|
Tiling p0 by 64 will make this expression coalesced.
|
|
"""
|
|
|
|
norm_read_writes = extract_normalized_read_writes(fused_node)
|
|
|
|
if norm_read_writes is None:
|
|
return None
|
|
|
|
reads = norm_read_writes.reads
|
|
writes = norm_read_writes.writes
|
|
var_ranges = norm_read_writes.var_ranges
|
|
|
|
coalesced_by_var: dict[sympy.Symbol, int] = Counter()
|
|
uncoalesced_addrs: dict[sympy.Expr, int] = Counter()
|
|
|
|
for memory_expr, buf_names in itertools.chain(reads.items(), writes.items()):
|
|
# skip memory deps with indirect vars - todo: better handling
|
|
indirect_expr = bool(
|
|
memory_expr.free_symbols - norm_read_writes.var_ranges.keys()
|
|
)
|
|
|
|
if indirect_expr:
|
|
continue
|
|
|
|
size = get_score(memory_expr, var_ranges)
|
|
if size == 0:
|
|
continue
|
|
|
|
maybe_coalesced_var = find_coalesced_var(memory_expr, var_ranges)
|
|
|
|
byte_multipler = 0
|
|
for buf_name in buf_names:
|
|
if buf := V.graph.try_get_buffer(buf_name):
|
|
byte_multipler += buf.dtype.itemsize
|
|
|
|
if maybe_coalesced_var:
|
|
coalesced_by_var[maybe_coalesced_var] += size * byte_multipler
|
|
else:
|
|
uncoalesced_addrs[memory_expr] += size * byte_multipler
|
|
|
|
if not uncoalesced_addrs:
|
|
return CoalesceVarAnalysis(
|
|
coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes
|
|
)
|
|
|
|
# map from var -> tiling -> total_score
|
|
tiling_scores: dict[sympy.Expr, dict[int, int]] = defaultdict(Counter)
|
|
|
|
for uncoalesced_expr, addr_score in uncoalesced_addrs.items():
|
|
expr_subs = dict.fromkeys(uncoalesced_expr.free_symbols, 0)
|
|
for v in uncoalesced_expr.free_symbols:
|
|
# skip non iter/reduce var variables
|
|
if v not in var_ranges:
|
|
continue
|
|
# skip small addrs
|
|
if addr_score == 0:
|
|
continue
|
|
del expr_subs[v]
|
|
single_var_expr = sympy_subs(uncoalesced_expr, expr_subs)
|
|
expr_subs[v] = 0
|
|
tiling_factor = solve_for_tiling(single_var_expr)
|
|
if (
|
|
tiling_factor is None
|
|
or not tiling_factor.is_constant()
|
|
or not tiling_factor.is_integer
|
|
):
|
|
continue
|
|
|
|
tiling_factor = int(tiling_factor)
|
|
if not V.graph.sizevars.statically_known_lt(tiling_factor, var_ranges[v]):
|
|
continue
|
|
|
|
# TODO - if a var is in the middle, such as [n0, n1, n2]
|
|
# n1 can can be split beyond range
|
|
|
|
MIN_TILING_BLOCK = 8
|
|
if not all(
|
|
V.graph.sizevars.statically_known_lt(MIN_TILING_BLOCK, block)
|
|
for block in (tiling_factor, var_ranges[v] // tiling_factor)
|
|
):
|
|
continue
|
|
|
|
tiling_scores[v][tiling_factor] += addr_score
|
|
|
|
if len(tiling_scores) == 0:
|
|
return CoalesceVarAnalysis(
|
|
coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes
|
|
)
|
|
|
|
best_tiling: Optional[tuple[sympy.Expr, int]] = None
|
|
best_tiling_score = 0
|
|
|
|
for var, tiling_counter in tiling_scores.items():
|
|
for tile, tile_score in tiling_counter.items():
|
|
if tile_score > best_tiling_score:
|
|
best_tiling = (var, tile)
|
|
best_tiling_score = tile_score
|
|
|
|
if best_tiling is None:
|
|
return CoalesceVarAnalysis(
|
|
coalesced_by_var=coalesced_by_var, norm_read_writes=norm_read_writes
|
|
)
|
|
|
|
# TODO - for strictly pointwise fusions,
|
|
# we can consider just swizzling the var if the var we are going to tile
|
|
# does not coalesce a significant portion of global reads
|
|
# TODO - could also prefer index var splits to reduction, better tested
|
|
return CoalesceVarAnalysis(
|
|
coalesced_by_var=coalesced_by_var,
|
|
norm_read_writes=norm_read_writes,
|
|
suggested_split=VarTiling(best_tiling[0], best_tiling[1], best_tiling_score),
|
|
)
|