pytorch/torch/utils/_sympy/interp.py
ydwu4 6140facf00 Support SymBool input to torch.compile (#107850)
We could have SymBool inputs for torch.compile, e.g. in the following situation:
```
def f(x:torch.Tensor):
  pred = x.size(0) == 3
  torch.compile(f)(pred, x)

make_fx(f, tracing_mode="symbolic")(x)
```

The idea of this PR (credit to @ezyang) is to support SymBool by re-using the infra we've already had for SymInt so that we don't need to replicate a lot of stuff.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107850
Approved by: https://github.com/ezyang
ghstack dependencies: #107662
2023-09-14 21:34:31 +00:00

104 lines
3.0 KiB
Python

"""
This is a simple interpreter for Sympy expressions that dispatches to
classes following the torch._inductor.virtualized calling convention.
For directness, the interpreter takes the handler directly rather than
consulting the TLS. It does not use most of the methods on the full
handler; only those with corresponding Sympy expressions. To see an example
of a full handler, see torch.utils._sympy.value_ranges.ValueRangeAnalysis.
"""
import functools
from typing import Any, Dict, Union
import sympy
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
import torch
from .functions import CleanDiv, FloorDiv, Mod, ModularIndexing, Where
# TODO: Dedupe this with SYMPY_INTERP
@functools.lru_cache(None)
def handlers():
from torch.fx.experimental.symbolic_shapes import Pow, TrueDiv
# TODO add CeilDiv (it doesn't appear in the index_expr)
# TODO default to some decompositions if the interpreter doesn't have them
# like decomposing ModularIndexing or implementing Le(a,b) as Ge(b, a)
HANDLERS = {
sympy.Or: "or_",
sympy.And: "and_",
sympy.Eq: "eq",
sympy.Ne: "ne",
sympy.Lt: "lt",
sympy.Gt: "gt",
sympy.Le: "le",
sympy.Ge: "ge",
sympy.Not: "not_",
TrueDiv: "truediv",
FloorDiv: "floordiv",
CleanDiv: "div",
Where: "where",
sympy.Add: "add",
sympy.Mul: "mul",
Pow: "pow",
sympy.Pow: "pow",
Mod: "mod",
sympy.Mod: "mod",
sympy.Abs: "abs",
sympy.log: "log",
sympy.exp: "exp",
sympy.floor: "floor",
sympy.ceiling: "ceil",
sympy.Min: "minimum",
sympy.Max: "maximum",
ModularIndexing: "modular_indexing",
sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair",
sympy.Piecewise: "piecewise",
}
return HANDLERS
ASSOCIATIVE_OPS = {"minimum", "maximum", "mul", "add", "and_", "or_"}
def sympy_interp(
analysis, env: Dict[sympy.Symbol, Any], expr: Union[sympy.Expr, SympyBoolean]
):
# Handle base cases
dtype = None
if isinstance(expr, BooleanAtom):
dtype = torch.bool
elif isinstance(expr, sympy.Integer):
dtype = torch.int64
elif isinstance(expr, sympy.Number):
dtype = torch.double
if dtype is not None:
return analysis.constant(expr, dtype)
elif isinstance(expr, sympy.Symbol):
return env[expr]
# Special cases
if isinstance(expr, sympy.Pow) and isinstance(
expr.args[1], sympy.core.numbers.Half
):
return analysis.sqrt(sympy_interp(analysis, env, expr.args[0]))
# Recursive case
args = [sympy_interp(analysis, env, arg) for arg in expr.args] # type: ignore[arg-type]
handler_name = handlers()[expr.func]
handler = getattr(analysis, handler_name)
if handler_name in ASSOCIATIVE_OPS:
assert len(args) > 1
acc = handler(args[0], args[1])
for i in range(2, len(args)):
acc = handler(acc, args[i])
return acc
else:
return handler(*args)