In the current state of export's dynamic shapes, we struggle with guards and constraints that lie beyond the dynamic shapes language, which is expressed with dims and derived dims. While we can compile and guarantee correctness for guards within that language (e.g. min/max ranges, linear relationships, integer divisibility), we struggle to dynamically compile guards that extend beyond it.
For these "complex" guards, we typically do either of the following: 1) raise a constraint violation error, along the lines of "not all values of <symbol> in the specified range satisfy <guard>", with or without suggested fixes, 2) specialize to the provided static values and suggest removing dynamism, or 3) fail compilation due to some arbitrary unsupported case. Previous [work](https://github.com/pytorch/pytorch/pull/124949) went towards resolving this by disabling forced specializations, instead allowing the user to fail at runtime with incorrect inputs.
In this PR, relying on [hybrid backed-unbacked symints](https://github.com/pytorch/pytorch/issues/121749), [deferred runtime asserts](https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/runtime_assert.py), and the function [_is_supported_equivalence()](d7de4c9d80/torch/fx/experimental/symbolic_shapes.py (L1824)), we add a flag `_allow_complex_guards_as_runtime_asserts` which allows the user to compile exported programs containing these guards and maintain dynamism, while adding correctness checks as runtime assertions in the graph.
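For a sense of how the flag is used, here is a minimal sketch, assuming it is plumbed through the private `_export` entry point in `torch.export._trace`; the module path, call signature, and example shapes below are assumptions for illustration, not part of this PR's text:
```
import torch
from torch.export import Dim
from torch.export._trace import _export  # private entry point; path assumed

class Reshape(torch.nn.Module):
    def forward(self, x, y):
        # emits the complex guard s0 * s1 = s2 (see the reshape example below)
        return x.reshape([-1]) + y

# With the flag on, export keeps s0, s1, s2 dynamic and installs a runtime
# assert for the guard instead of specializing or raising a constraint error.
ep = _export(
    Reshape(),
    (torch.randn(4, 8), torch.randn(32)),
    dynamic_shapes={"x": (Dim("s0"), Dim("s1")), "y": (Dim("s2"),)},
    _allow_complex_guards_as_runtime_asserts=True,
)
```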
Hybrid backed-unbacked symints allow us to easily bypass "implicit" guards emitted from computation, i.e. guards that we roughly expect to be true. Popular examples revolve around reshapes:
```
# reshape
def forward(self, x, y):  # x: [s0, s1], y: [s2]
    return x.reshape([-1]) + y  # guard s0 * s1 = s2
```
This leads to the following exported program:
```
class GraphModule(torch.nn.Module):
    def forward(self, x: "f32[s0, s1]", y: "f32[s2]"):
        sym_size_int: "Sym(s2)" = torch.ops.aten.sym_size.int(y, 0)
        mul: "Sym(-s2)" = -1 * sym_size_int;  sym_size_int = None
        sym_size_int_1: "Sym(s0)" = torch.ops.aten.sym_size.int(x, 0)
        sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(x, 1)
        mul_1: "Sym(s0*s1)" = sym_size_int_1 * sym_size_int_2;  sym_size_int_1 = sym_size_int_2 = None
        add: "Sym(s0*s1 - s2)" = mul + mul_1;  mul = mul_1 = None
        eq: "Sym(Eq(s0*s1 - s2, 0))" = add == 0;  add = None
        _assert_scalar = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(s0*s1 - s2, 0) on node 'eq'");  eq = None
        view: "f32[s0*s1]" = torch.ops.aten.view.default(x, [-1]);  x = None
        add_1: "f32[s0*s1]" = torch.ops.aten.add.Tensor(view, y);  view = y = None
        return (add_1,)
```
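As a usage sketch (reusing the hypothetical `ep` from the export call above), inputs that violate s0 * s1 = s2 now fail on the graph's assert rather than on some downstream op:
```
# valid: 4 * 8 == 32
ep.module()(torch.randn(4, 8), torch.randn(32))

# invalid: 4 * 8 != 33 -- trips the _assert_scalar node
try:
    ep.module()(torch.randn(4, 8), torch.randn(33))
except RuntimeError as e:
    print(e)  # Runtime assertion failed for expression Eq(s0*s1 - s2, 0) ...
```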
Another case is symbol divisibility:
```
def forward(self, x):  # x: [s0, s1]
    return x.reshape([-1, x.shape[0] - 1])  # Eq(Mod(s0 * s1, s0 - 1), 0)
```
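For illustration, the deferred runtime assert for this guard is equivalent to the following eager-mode check; this is a sketch, and the function name and message are ours, not the emitted graph code:
```
import torch

def check_reshape_divisibility(x: torch.Tensor) -> None:
    # mirrors Eq(Mod(s0 * s1, s0 - 1), 0): the flattened numel must
    # divide evenly into rows of width s0 - 1
    s0, s1 = x.shape
    assert (s0 * s1) % (s0 - 1) == 0, (
        f"cannot reshape ({s0}, {s1}) into (-1, {s0 - 1})"
    )
```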
Applying deferred runtime asserts also helps dynamic compilation for "explicit" complex guards that typically cause problems for export. For example, we can generate runtime asserts for not-equal guards and for complex conditions like the following:
```
class Foo(torch.nn.Module):
    def forward(self, x, y):
        # check that negation of first guard also shows up as runtime assertion
        if x.shape[0] == y.shape[0]:  # False
            return x + y
        elif x.shape[0] == y.shape[0] ** 3:  # False
            return x + 2, y + 3
        elif x.shape[0] ** 2 == y.shape[0] * 3:  # True
            return x * 2.0, y * 3.0
```
For the above program we generate 3 runtime assertions: the negations of the first 2 conditions, and the 3rd condition itself, since that is the branch actually taken.
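Concretely, with x: [s0] and y: [s1] (symbol names illustrative), the emitted checks correspond to the following sympy relations:
```
import sympy

s0, s1 = sympy.symbols("s0 s1", integer=True, positive=True)

asserts = [
    sympy.Ne(s0, s1),         # negation of the first (False) condition
    sympy.Ne(s0, s1**3),      # negation of the second (False) condition
    sympy.Eq(s0**2, 3 * s1),  # the third (True) condition, kept as a guard
]
```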
An additional benefit over the current state of exported programs is stronger correctness guarantees: previously, with explicit complex guards, if compilation succeeded the guards were ignored at runtime and treated as given.
As shown above, the runtime asserts appear as math ops in the graph, generated by the sympy interpreter and culminating in an _assert_scalar call. These asserts can be omitted from the graph by setting `TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1`, which yields the "original" computation graph, still dynamic; incorrect inputs will then fail on ops during runtime. Further work could go into prettifying the printer, so the majority of the graph isn't guard-related.
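The variable can be set in the shell or, as a sketch, from Python before export runs (assuming it is still unread at that point; exactly when it is consumed is not specified here):
```
import os

# Disable emission of runtime asserts into the exported graph; must be
# set before the export/compile machinery reads it.
os.environ["TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS"] = "1"
```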
Ideally this PR would subsume and remove the recently added [_disable_forced_specializations](https://github.com/pytorch/pytorch/pull/124949) flag, but that flag still handles one additional case of specialization: single-variable equalities where the symbol is solvable for a concrete value; see this [PR](https://github.com/pytorch/pytorch/pull/126925).
This PR doesn't change any behavior around data-dependent errors/unbacked symints yet; that could be future work.
NOTE: will take naming change suggestions for the flag :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127129
Approved by: https://github.com/avikchaudhuri
225 lines · 5.0 KiB · Python
```
import math

import sympy

import torch
from torch.utils._sympy.functions import (
    OpaqueUnaryFn_exp,
    OpaqueUnaryFn_log,
    OpaqueUnaryFn_sqrt,
)


# The sympy interpretation of operators. It will also sometimes work with
# plain int/float, but if you do certain operations you will get out a
# sympy.Basic in the end. If you want the Python/FX traceable interpretation,
# check PythonReferenceAnalysis.
# NB: For magic methods this needs to use normal magic methods
# so that test_magic_methods works
class ReferenceAnalysis:
    @staticmethod
    def constant(c, dtype):
        return sympy.sympify(c)

    @staticmethod
    def or_(a, b):
        return a | b

    @staticmethod
    def and_(a, b):
        return a & b

    @staticmethod
    def eq(a, b):
        if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr):
            return sympy.Eq(a, b)
        return a == b

    @classmethod
    def ne(cls, a, b):
        return cls.not_(cls.eq(a, b))

    @staticmethod
    def lt(a, b):
        return a < b

    @staticmethod
    def gt(a, b):
        return a > b

    @staticmethod
    def le(a, b):
        return a <= b

    @staticmethod
    def ge(a, b):
        return a >= b

    @staticmethod
    def not_(a):
        assert not isinstance(a, bool)
        return ~a

    @staticmethod
    def reciprocal(x):
        return 1 / x

    @staticmethod
    def square(x):
        return x * x

    @staticmethod
    def mod(x, y):
        ret = abs(x) % abs(y)
        # without this check, tracing will fail trying to go through
        # control-flow if x is a Proxy()
        if isinstance(x, (int, sympy.Number)) and x < 0:
            ret *= -1
        return ret

    @staticmethod
    def abs(x):
        return abs(x)

    @staticmethod
    def neg(x):
        return -x

    @staticmethod
    def truediv(a, b):
        return a / b

    @staticmethod
    def div(a, b):
        return ReferenceAnalysis.truediv(a, b)

    @staticmethod
    def floordiv(a, b):
        if b == 0:
            return sympy.nan if a == 0 else sympy.zoo
        return a // b

    @staticmethod
    def truncdiv(a, b):
        result = a / b
        if result.is_finite:
            result = sympy.Integer(result)

        return result

    @staticmethod
    def add(a, b):
        return a + b

    @staticmethod
    def mul(a, b):
        return a * b

    @staticmethod
    def sub(a, b):
        return a - b

    @staticmethod
    def exp(x):
        return OpaqueUnaryFn_exp(x)

    @staticmethod
    def log(x):
        return OpaqueUnaryFn_log(x)

    @staticmethod
    def sqrt(x):
        return OpaqueUnaryFn_sqrt(x)

    @staticmethod
    def pow(a, b):
        return a**b

    @staticmethod
    def minimum(a, b):
        # Poor man's version of upcasting in sympy.
        # This won't do for sympy.Expr, as the casting does nothing for those.
        if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite:
            result_type = sympy.Float
        else:
            assert a.is_Integer
            assert b.is_Integer
            result_type = sympy.Integer
        return sympy.Min(result_type(a), result_type(b))

    @staticmethod
    def maximum(a, b):
        # Poor man's version of upcasting in sympy.
        # This won't do for sympy.Expr, as the casting does nothing for those.
        if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite:
            result_type = sympy.Float
        else:
            assert a.is_Integer
            assert b.is_Integer
            result_type = sympy.Integer
        return sympy.Max(result_type(a), result_type(b))

    @staticmethod
    def floor(x):
        return sympy.floor(x)

    @staticmethod
    def ceil(x):
        return sympy.ceiling(x)


# Unlike ReferenceAnalysis, does NOT sympyify; instead, works with plain
# Python types and is FX traceable. Inheritance here is purely for code
# sharing (TODO: consider splitting out a BaseReferenceAnalysis).
class PythonReferenceAnalysis(ReferenceAnalysis):
    @staticmethod
    def constant(c, dtype):
        if dtype is torch.int64:
            return int(c)
        elif dtype is torch.double:
            return float(c)
        elif dtype is torch.bool:
            return bool(c)
        else:
            raise AssertionError(f"unrecognized dtype {dtype}")

    @staticmethod
    def not_(a):
        return torch.sym_not(a)

    @staticmethod
    def floordiv(a, b):
        return a // b

    @staticmethod
    def truncdiv(a, b):
        return a / b

    @staticmethod
    def exp(x):
        raise AssertionError("exp is not valid shape sympy expr")

    @staticmethod
    def log(x):
        raise AssertionError("log is not valid shape sympy expr")

    @staticmethod
    def sqrt(x):
        return torch._sym_sqrt(x)  # type: ignore[attr-defined]

    @staticmethod
    def minimum(a, b):
        return torch.sym_min(a, b)

    @staticmethod
    def maximum(a, b):
        return torch.sym_max(a, b)

    @staticmethod
    def floor(x):
        return math.floor(x)

    @staticmethod
    def ceil(x):
        return math.ceil(x)
```
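To make the split between the two analyses concrete, a short usage sketch (the symbol and variable names are ours):
```
import sympy
import torch

s0 = sympy.Symbol("s0", integer=True, positive=True)

# ReferenceAnalysis stays in sympy-land: eq() on a sympy.Expr builds a
# symbolic relation rather than evaluating to a Python bool.
expr = ReferenceAnalysis.eq(s0 * 2, 10)  # -> Eq(2*s0, 10)

# PythonReferenceAnalysis works on plain Python values and is FX traceable.
val = PythonReferenceAnalysis.constant(10, torch.int64)  # -> 10 (plain int)
flag = PythonReferenceAnalysis.eq(val, 10)               # -> True (plain bool)
```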