pytorch/torch/utils/_sympy/reference.py
Pian Pawakapan 8a31c2aa84 [export] allow complex guards as runtime asserts (#127129)
With the current state of export's dynamic shapes, we struggle with guards and constraints that are beyond the current dynamic shapes language, expressed with dims and derived dims. While we can compile and guarantee correctness for guards within the current language (e.g. min/max ranges, linear relationships, integer divisibility) we struggle to dynamically compile guards which extend beyond that.

For these "complex" guards, we typically do one of the following: 1) raise a constraint violation error, along the lines of "not all values of <symbol> in the specified range satisfy <guard>", with or without suggested fixes, 2) specialize to the provided static values and suggest removing dynamism, or 3) fail compilation due to some arbitrary unsupported case. Previous [work](https://github.com/pytorch/pytorch/pull/124949) went towards resolving this by disabling forced specializations, so that incorrect inputs fail at runtime instead.

In this PR, relying on [hybrid backed-unbacked symints](https://github.com/pytorch/pytorch/issues/121749), [deferred runtime asserts](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/runtime_assert.py), and the function [_is_supported_equivalence()](d7de4c9d80/torch/fx/experimental/symbolic_shapes.py (L1824)), we add a flag `_allow_complex_guards_as_runtime_asserts` which allows the user to compile exported programs containing these guards and maintain dynamism, while adding correctness checks as runtime assertions in the graph.

Hybrid backed-unbacked symints allow us to easily bypass "implicit" guards emitted from computation - guards that we generally expect to be true. Popular examples revolve around reshapes:
```
# reshape
def forward(self, x, y):  # x: [s0, s1], y: [s2]
    return x.reshape([-1]) + y  # guard s0 * s1 = s2
```

This leads to the following exported program:

```
class GraphModule(torch.nn.Module):
    def forward(self, x: "f32[s0, s1]", y: "f32[s2]"):
        sym_size_int: "Sym(s2)" = torch.ops.aten.sym_size.int(y, 0)
        mul: "Sym(-s2)" = -1 * sym_size_int;  sym_size_int = None
        sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
        sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(x, 1)
        mul_1: "Sym(s0*s1)" = sym_size_int_1 * sym_size_int_2;  sym_size_int_1 = sym_size_int_2 = None
        add: "Sym(s0*s1 - s2)" = mul + mul_1;  mul = mul_1 = None
        eq: "Sym(Eq(s0*s1 - s2, 0))" = add == 0;  add = None
        _assert_scalar = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s0*s1 - s2, 0) on node 'eq'");  eq = None

        view: "f32[s0*s1]" = torch.ops.aten.view.default(x, [-1]);  x = None
        add_1: "f32[s0*s1]" = torch.ops.aten.add.Tensor(view, y);  view = y = None
        return (add_1,)
```
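
The chain of `mul`, `add`, `eq` and `_assert_scalar` nodes above amounts to a single integer check on the input shapes. A minimal plain-Python sketch of that check follows; the helper name `check_reshape_guard` is hypothetical and is not part of PyTorch, and the real check runs on symbolic sizes inside the graph.

```python
def check_reshape_guard(x_shape, y_shape):
    """Mirror the runtime assert Eq(s0*s1 - s2, 0) on concrete shapes.

    Hypothetical helper for illustration only; not a PyTorch API.
    """
    s0, s1 = x_shape
    (s2,) = y_shape
    # Same arithmetic as the mul / mul_1 / add / eq nodes in the exported graph.
    if s0 * s1 - s2 != 0:
        raise AssertionError(
            "Runtime assertion failed for expression Eq(s0*s1 - s2, 0): "
            f"s0={s0}, s1={s1}, s2={s2}"
        )
    return True
```

Calling it with `x_shape=(2, 3)` and `y_shape=(6,)` passes, while `y_shape=(5,)` raises, which is the behavior the exported program would show for mismatched inputs.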
Another case is symbol divisibility:
```
def forward(self, x):  # x: [s0, s1]
    return x.reshape([-1, x.shape[0] - 1])  # Eq(Mod(s0 * s1, s0 - 1), 0)
```
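
On concrete sizes, the divisibility guard reduces to a modulo test. The sketch below is a plain-Python illustration under that reading; `check_divisibility_guard` is a hypothetical name, and treating a zero divisor as a failed guard is a choice of this sketch rather than documented behavior.

```python
def check_divisibility_guard(x_shape):
    """Evaluate Eq(Mod(s0 * s1, s0 - 1), 0) on a concrete [s0, s1] shape.

    Hypothetical helper for illustration only; not a PyTorch API.
    """
    s0, s1 = x_shape
    divisor = s0 - 1
    if divisor == 0:
        # Assumption of this sketch: a zero divisor counts as a failed guard.
        return False
    return (s0 * s1) % divisor == 0
```

For example, a `[3, 4]` input satisfies the guard (12 is divisible by 2), whereas a `[4, 4]` input does not (16 is not divisible by 3).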

Applying deferred runtime asserts also helps dynamic compilation for "explicit" complex guards that typically cause problems for export. For example we can generate runtime asserts for not-equal guards, and complex conditions like the following:
```
class Foo(torch.nn.Module):
    def forward(self, x, y):
        # check that negation of first guard also shows up as runtime assertion
        if x.shape[0] == y.shape[0]:  # False
            return x + y
        elif x.shape[0] == y.shape[0] ** 3:  # False
            return x + 2, y + 3
        elif x.shape[0] ** 2 == y.shape[0] * 3:  # True
            return x * 2.0, y * 3.0
```
For the above graph we will generate 3 runtime assertions: the negation of the first 2, and the 3rd condition as a guard.
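
To make the three assertions concrete, here is a plain-Python sketch that evaluates them on the leading dimensions of `x` and `y`. The function name `foo_runtime_asserts` is hypothetical; the real assertions are emitted as graph nodes, not as a Python list.

```python
def foo_runtime_asserts(x_dim, y_dim):
    """Return the three checks implied by tracing the third branch of Foo.

    Hypothetical helper for illustration only; not a PyTorch API.
    """
    return [
        x_dim != y_dim,           # negation of the first condition
        x_dim != y_dim ** 3,      # negation of the second condition
        x_dim ** 2 == y_dim * 3,  # third condition, asserted as-is
    ]


# Inputs that follow the traced branch satisfy all three checks.
assert all(foo_runtime_asserts(6, 12))
# Inputs that would have taken the first branch fail the first check.
assert foo_runtime_asserts(3, 3)[0] is False
```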

One additional benefit here over the current state of exported programs is that this adds further correctness guarantees - previously with explicit complex guards, if compilation succeeded, the guards would be ignored at runtime, treated as given.

As shown above, the runtime asserts appear as math ops in the graph, generated by the sympy interpreter, resulting in an _assert_scalar call. There is an option to avoid adding these asserts into the graph, by setting `TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1`. This results in the "original" computation graph, with dynamism, and any incorrect inputs will fail on ops during runtime. Further work could go into prettifying the printer, so the majority of the graph isn't guard-related.
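
A minimal sketch of setting that variable from Python rather than from the shell. It assumes the value is read from the environment when torch's configuration is first imported, so the assignment has to come before `import torch`; that ordering requirement is an assumption of this sketch.

```python
import os

# Must run before torch is imported (assumption: the flag is read from the
# environment at import time).
os.environ["TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS"] = "1"

# import torch  # a real script would import torch and export after this point
```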

Ideally this PR would subsume and remove the recently added [_disable_forced_specializations](https://github.com/pytorch/pytorch/pull/124949) flag, but that flag still handles one additional case of specialization, namely single-variable equalities where the symbol is solvable for a concrete value (see this [PR](https://github.com/pytorch/pytorch/pull/126925)).

This PR doesn't change any behavior around data-dependent errors or unbacked symints yet; that could be future work.

NOTE: will take naming change suggestions for the flag :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127129
Approved by: https://github.com/avikchaudhuri
2024-05-29 17:15:25 +00:00


import math

import sympy

import torch
from torch.utils._sympy.functions import (
    OpaqueUnaryFn_exp,
    OpaqueUnaryFn_log,
    OpaqueUnaryFn_sqrt,
)


# The sympy interpretation of operators. It will also sometimes work with
# plain int/float, but if you do certain operations you will get out a
# sympy.Basic in the end. If you want the Python/FX traceable interpretation,
# check PythonReferenceAnalysis.
# NB: For magic methods this needs to use normal magic methods
# so that test_magic_methods works
class ReferenceAnalysis:
    @staticmethod
    def constant(c, dtype):
        return sympy.sympify(c)

    @staticmethod
    def or_(a, b):
        return a | b

    @staticmethod
    def and_(a, b):
        return a & b

    @staticmethod
    def eq(a, b):
        if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr):
            return sympy.Eq(a, b)
        return a == b

    @classmethod
    def ne(cls, a, b):
        return cls.not_(cls.eq(a, b))

    @staticmethod
    def lt(a, b):
        return a < b

    @staticmethod
    def gt(a, b):
        return a > b

    @staticmethod
    def le(a, b):
        return a <= b

    @staticmethod
    def ge(a, b):
        return a >= b

    @staticmethod
    def not_(a):
        assert not isinstance(a, bool)
        return ~a

    @staticmethod
    def reciprocal(x):
        return 1 / x

    @staticmethod
    def square(x):
        return x * x

    @staticmethod
    def mod(x, y):
        ret = abs(x) % abs(y)
        # without check:
        # tracing will fail trying to go through control-flow if x is Proxy()
        if isinstance(x, (int, sympy.Number)) and x < 0:
            ret *= -1
        return ret

    @staticmethod
    def abs(x):
        return abs(x)

    @staticmethod
    def neg(x):
        return -x

    @staticmethod
    def truediv(a, b):
        return a / b

    @staticmethod
    def div(a, b):
        return ReferenceAnalysis.truediv(a, b)

    @staticmethod
    def floordiv(a, b):
        if b == 0:
            return sympy.nan if a == 0 else sympy.zoo
        return a // b

    @staticmethod
    def truncdiv(a, b):
        result = a / b
        if result.is_finite:
            result = sympy.Integer(result)

        return result

    @staticmethod
    def add(a, b):
        return a + b

    @staticmethod
    def mul(a, b):
        return a * b

    @staticmethod
    def sub(a, b):
        return a - b

    @staticmethod
    def exp(x):
        return OpaqueUnaryFn_exp(x)

    @staticmethod
    def log(x):
        return OpaqueUnaryFn_log(x)

    @staticmethod
    def sqrt(x):
        return OpaqueUnaryFn_sqrt(x)

    @staticmethod
    def pow(a, b):
        return a**b

    @staticmethod
    def minimum(a, b):
        # Poorman's version of upcasting in Sympy
        # This won't do for sympy.Expr as the casting does nothing for those
        if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite:
            result_type = sympy.Float
        else:
            assert a.is_Integer
            assert b.is_Integer
            result_type = sympy.Integer
        return sympy.Min(result_type(a), result_type(b))

    @staticmethod
    def maximum(a, b):
        # Poorman's version of upcasting in Sympy
        # This won't do for sympy.Expr as the casting does nothing for those
        if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite:
            result_type = sympy.Float
        else:
            assert a.is_Integer
            assert b.is_Integer
            result_type = sympy.Integer
        return sympy.Max(result_type(a), result_type(b))

    @staticmethod
    def floor(x):
        return sympy.floor(x)

    @staticmethod
    def ceil(x):
        return sympy.ceiling(x)


# Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain
# Python types and is FX traceable. Inheritance here is purely for code
# sharing (TODO: considering splitting out a BaseReferenceAnalysis).
class PythonReferenceAnalysis(ReferenceAnalysis):
    @staticmethod
    def constant(c, dtype):
        if dtype is torch.int64:
            return int(c)
        elif dtype is torch.double:
            return float(c)
        elif dtype is torch.bool:
            return bool(c)
        else:
            raise AssertionError(f"unrecognized dtype {dtype}")

    @staticmethod
    def not_(a):
        return torch.sym_not(a)

    @staticmethod
    def floordiv(a, b):
        return a // b

    @staticmethod
    def truncdiv(a, b):
        return a / b

    @staticmethod
    def exp(x):
        raise AssertionError("exp is not valid shape sympy expr")

    @staticmethod
    def log(x):
        raise AssertionError("log is not valid shape sympy expr")

    @staticmethod
    def sqrt(x):
        return torch._sym_sqrt(x)  # type: ignore[attr-defined]

    @staticmethod
    def minimum(a, b):
        return torch.sym_min(a, b)

    @staticmethod
    def maximum(a, b):
        return torch.sym_max(a, b)

    @staticmethod
    def floor(x):
        return math.floor(x)

    @staticmethod
    def ceil(x):
        return math.ceil(x)
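
The sign handling in `ReferenceAnalysis.mod` is easy to misread, so here is a standalone plain-Python sketch of the same logic restricted to built-in ints. The name `reference_mod` is hypothetical and the function is a simplified copy for illustration, not the class method itself. On ints it yields a remainder whose sign follows the dividend, which differs from Python's `%` operator.

```python
import math


def reference_mod(x, y):
    """Plain-int copy of the mod logic above (illustration only)."""
    ret = abs(x) % abs(y)
    # The result takes the sign of the dividend x, as in C-style remainder.
    if isinstance(x, int) and x < 0:
        ret *= -1
    return ret


# Python's % follows the sign of the divisor instead.
assert reference_mod(-7, 3) == -1 and (-7 % 3) == 2
assert reference_mod(7, -3) == 1 and (7 % -3) == -2
# math.fmod uses the same truncated-remainder convention as reference_mod.
assert reference_mod(-7, 3) == math.fmod(-7, 3)
```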