[ez] add docblock and comments to simd.split_and_set_ranges (#156717)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156717
Approved by: https://github.com/BoyuanFeng
ghstack dependencies: #156445
This commit is contained in:
bobrenjc93 2025-06-25 14:49:27 +00:00 committed by PyTorch MergeBot
parent 204db27a0c
commit 451b525bf0

View File

@ -755,13 +755,35 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
def split_and_set_ranges(
self, lengths: Sequence[Sequence[sympy.Expr]]
) -> list[list[sympy.Expr]]:
"""
Split and set iteration ranges for the kernel based on the provided lengths.
This method maps the kernel's tiling structure to the node's iteration space,
handling both pointwise and reduction dimensions appropriately.
Args:
lengths: A sequence of sequences of symbolic expressions representing
the sizes of different dimensions for each node.
Returns:
A list of lists of symbolic expressions representing the mapped
iteration variables for each dimension.
"""
# Create a dictionary mapping each range tree prefix to its total number of elements
tiling = {rt.prefix: rt.numel for rt in self.range_trees}
# If we're not inside a reduction loop, set all reduction dimensions to 1
# This effectively disables reduction dimensions when not needed
if not self.inside_reduction:
for prefix in tiling:
if prefix_is_reduction(prefix):
tiling[prefix] = sympy.S.One
# Extract the values from the tiling dictionary to create groups
groups = [*tiling.values()]
# Map the kernel's group structure to the node's sizes and set the ranges
# using the set_ranges method, returning the resulting iteration variables
return self.map_kernel_groups_to_node_sizes(groups, lengths, self.set_ranges)
@classmethod