pytorch/torch/_inductor/bounds.py
Peter Bell 02196c21ac [inductor] Parameterize ir.Scan on combine_fn (#109132)
This replaces `tl.cumsum` and `tl.cumprod` with calls to `tl.associative_scan`
where the combine function is generated from inductor IR.

So before we had:
```python
@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 20
    rnumel = 30
    RBLOCK: tl.constexpr = 32
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (30*x0)), rmask & xmask, other=0).to(tl.float32)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp2 = tl.where(rmask & xmask, tmp1, 0)
    tmp3 = tl.cumsum(tmp2, 1)
    tl.store(out_ptr0 + (r1 + (30*x0)), tmp3, rmask & xmask)
```

Now we have:
```python
@triton.jit
def _triton_helper_fn0(arg0, arg1):
    tmp0 = arg0 + arg1
    return tmp0

@triton.jit
def triton_(in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 20
    rnumel = 30
    RBLOCK: tl.constexpr = 32
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    rmask = rindex < rnumel
    r1 = rindex
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (r1 + (30*x0)), rmask & xmask, other=0).to(tl.float32)
    tmp1 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
    tmp2 = tl.where(rmask & xmask, tmp1, 0)
    tmp3 = tl.associative_scan(tmp2, 1, _triton_helper_fn0)
    tl.store(out_ptr0 + (r1 + (30*x0)), tmp3, rmask & xmask)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109132
Approved by: https://github.com/lezcano
2023-12-12 16:30:50 +00:00

125 lines
5.1 KiB
Python

import operator
from functools import partial
from typing import Any, Callable, Dict
from sympy import Expr
import torch
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
from .ir import InterpreterShim, LoopBody, LoopBodyBlock
from .utils import cache_on_self, dominated_nodes
from .virtualized import V
class BoundVars:
"""
Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.run()
It exposes the ranges of the nodes in the `bounds` variable
Note. A current limitation of this analysis is that it just works on a per-loop basis.
We should be able to propagate the bounds between across the whole graph. This may benefit
the case a bounded variable is returned by a kernel and fed into another.
"""
def __init__(self, loop_body: LoopBody) -> None:
self.loop_body = loop_body
self.replacement_vals = {
k: ValueRanges(0, v - 1)
if (isinstance(v, int) or v.is_number)
else bound_sympy(v)
for k, v in loop_body.var_ranges.items()
}
# avoid computing these values, pessimistically assume that they are unbounded
self.unbounded_vars = dominated_nodes(
node
for node in self.loop_body.get_nodes()
if node.target in ["load", "reduction", operator.getitem]
or "masked_subblock" in node.target
)
# To access this variable call `get_bounds()`
self._bounds: Dict[torch.fx.Node, ValueRanges] = {}
@cache_on_self
def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges]:
submodules = self.swap_submodules(self.loop_body.submodules)
# Initialize the environment with the unbounded variables
for node in self.unbounded_vars:
# we need to evaluate masked_subblock to recurse, and we need to set indirect values
if not isinstance(node.target, str) or (
"masked_subblock" not in node.target
and "set_indirect" not in node.target
):
self._bounds[node] = ValueRanges.unknown()
with V.set_ops_handler(ValueRangeAnalysis()):
interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules)
interpreter.run(V.get_ops_handler(), initial_env=self._bounds)
return self._bounds
def swap_submodules(
self, submodules: Dict[str, Callable[..., Any]]
) -> Dict[str, Callable[..., ValueRanges]]:
result: Dict[str, Callable[..., ValueRanges]] = {}
for key in submodules.keys():
if key == "get_index":
result[key] = self.get_index
elif "masked_subblock" in key:
subblock = self.loop_body.subblocks[key]
# The result within the lambda will reference to the final
# set of modules at the end of the for-loop as it stores a reference to it
# bind subblock in a function because python lambdas close over by reference
# moving the lambda out of make_fn would close over the reference to subblock,
# so all lambdas would have the same subblock reference that is the final
# subblock in the loop
def make_fn(subblock):
return lambda mask, value: self.masked_subblock(
subblock, self._bounds, mask, value, result
)
result[key] = make_fn(subblock)
elif "set_indirect" in key:
idx = int(key[len("set_indirect") :])
var = self.loop_body.indirect_vars[idx]
indirect = partial(self.set_indirect, var)
result[key] = indirect
else:
assert "scan" in key
result[key] = submodules[key]
return result
def masked_subblock(
self,
subblock: LoopBodyBlock,
env: Dict[torch.fx.Node, ValueRanges],
mask: Any,
value: Any,
submodules: Dict[str, Callable[..., Any]],
) -> ValueRanges:
interp = InterpreterShim(subblock.graph, submodules)
interp.run(V.get_ops_handler(), initial_env=env)
output = [node for node in subblock.graph.nodes if node.target == "output"]
assert len(output) == 1
# dont bother unioning with value since the load from buffer will be
# pessimistically assumed to be inf anyway
return interp.env[output[0]]
def set_indirect(self, old: Expr, new: ValueRanges) -> ValueRanges:
assert isinstance(new, ValueRanges)
self.replacement_vals[old] = new
return new
def get_index(self, name: Expr) -> ValueRanges:
expr = self.loop_body.indexing_exprs[name]
bound = self.replacement_vals.get(expr)
if bound is None:
bound = bound_sympy(expr, self.replacement_vals)
# The following assertion is true at the time of this writing
# We don't assert is as to not execute bound_sympy when bound is not None
# assert bound is None or bound == bound_sympy(expr, self.replacement_vals)
self.replacement_vals[name] = bound
return bound