mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add support for sym_ite (#111440)
This PR supports sym_ite. This is useful for converting SymBool to SymInt in e.g. #109916. Internally, it uses sympy.Piecewise. We cannot use sympy.ITE because it expects the arguments and output all to be boolean type but we want return SymInt type when converting a SymBool to SymInt. So we use sympy.Piecewise to denote the symbolic relationship. Note that this pr uses the range analysis for sympy.Piecewise implemented in https://github.com/pytorch/pytorch/blob/main/torch/utils/_sympy/value_ranges.py. Test Plan: See added test. Pull Request resolved: https://github.com/pytorch/pytorch/pull/111440 Approved by: https://github.com/ezyang
This commit is contained in:
parent
09040f6fbb
commit
f3d02d9ae6
|
|
@ -97,6 +97,9 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
|
|||
virtual SymNode sym_not() {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
virtual SymNode sym_ite(const SymNode& then_val, const SymNode& else_val) {
|
||||
TORCH_CHECK(false, "NYI");
|
||||
};
|
||||
// NB: self is ignored here, only the arguments are used
|
||||
virtual SymNode is_contiguous(
|
||||
ArrayRef<SymNode> sizes,
|
||||
|
|
|
|||
|
|
@ -967,6 +967,7 @@ coverage_ignore_functions = [
|
|||
"parallel_and",
|
||||
"parallel_or",
|
||||
"sym_sqrt",
|
||||
"sym_ite",
|
||||
"sympy_is_channels_last_contiguous_2d",
|
||||
"sympy_is_channels_last_contiguous_3d",
|
||||
"sympy_is_channels_last_strides_2d",
|
||||
|
|
|
|||
|
|
@ -706,6 +706,7 @@ Symbolic Numbers
|
|||
sym_max
|
||||
sym_min
|
||||
sym_not
|
||||
sym_ite
|
||||
|
||||
Export Path
|
||||
-------------
|
||||
|
|
|
|||
|
|
@ -423,6 +423,50 @@ class TestPySymInt(TestCase):
|
|||
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
||||
self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""")
|
||||
|
||||
def test_sym_ite(self):
|
||||
shape_env = ShapeEnv()
|
||||
t = create_symint(shape_env, 5)
|
||||
f = create_symint(shape_env, 4)
|
||||
b1 = True
|
||||
r1 = torch.sym_ite(b1, t, f)
|
||||
self.assertTrue(r1 is t)
|
||||
b2 = False
|
||||
r2 = torch.sym_ite(b2, t, f)
|
||||
self.assertTrue(r2 is f)
|
||||
b3 = t == 5
|
||||
r3 = torch.sym_ite(b3, t, f)
|
||||
self.assertEqual(len(shape_env.guards), 0)
|
||||
self.assertEqual(r3, 5)
|
||||
self.assertEqual(type(t), type(r3))
|
||||
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(Piecewise((s0, Eq(s0, 5)), (s1, True)), 5)""")
|
||||
b4 = f == 5
|
||||
r4 = torch.sym_ite(b4, t, f)
|
||||
self.assertEqual(len(shape_env.guards), 1)
|
||||
self.assertEqual(r4, 4)
|
||||
self.assertEqual(type(f), type(r4))
|
||||
self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(Piecewise((s0, Eq(s1, 5)), (s1, True)), 4)""")
|
||||
|
||||
def test_tracing_sym_ite(self):
|
||||
def f(x):
|
||||
b = x.shape[0] == 5
|
||||
ret = torch.sym_ite(b, x.shape[0], x.shape[1])
|
||||
return ret
|
||||
|
||||
gm = make_fx(f, tracing_mode="symbolic")(torch.ones(4, 5))
|
||||
self.assertEqual(len(gm.shape_env.guards), 0)
|
||||
self.assertExpectedInline(gm.code.strip(), """\
|
||||
def forward(self, x_1):
|
||||
sym_size = torch.ops.aten.sym_size(x_1, 0)
|
||||
eq = sym_size == 5
|
||||
sym_size_1 = torch.ops.aten.sym_size(x_1, 1); x_1 = None
|
||||
sym_ite = torch.sym_ite(eq, sym_size, sym_size_1); eq = sym_size = sym_size_1 = None
|
||||
return sym_ite""")
|
||||
r1 = gm(torch.ones(4, 5))
|
||||
self.assertIsInstance(r1, int)
|
||||
self.assertEqual(r1, 5)
|
||||
r2 = gm(torch.ones(5, 4))
|
||||
self.assertIsInstance(r2, int)
|
||||
self.assertEqual(r2, 5)
|
||||
|
||||
def test_int_conversion(self):
|
||||
shape_env = ShapeEnv()
|
||||
|
|
@ -684,7 +728,8 @@ class TestSymNumberMagicMethods(TestCase):
|
|||
|
||||
@parametrize("fn", list(symbolic_shapes.magic_methods.keys()))
|
||||
def test_bool_method(self, fn):
|
||||
if fn not in symbolic_shapes.bool_magic_methods:
|
||||
# sym_ite has its own tests
|
||||
if fn not in symbolic_shapes.bool_magic_methods or fn == "sym_ite":
|
||||
self.skipTest(f"{fn} is non-bool")
|
||||
|
||||
is_unary_fn = fn in symbolic_shapes.unary_magic_methods
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ __all__ = [
|
|||
'set_float32_matmul_precision', 'get_float32_matmul_precision',
|
||||
'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
|
||||
'SymBool', 'sym_not', 'unravel_index',
|
||||
'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap',
|
||||
'sym_int', 'sym_float', 'sym_max', 'sym_min', 'sym_ite', 'compile', 'vmap',
|
||||
'export', 'autocast', 'cond',
|
||||
]
|
||||
|
||||
|
|
@ -390,6 +390,9 @@ class SymBool:
|
|||
def __sym_not__(self) -> "SymBool":
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __sym_ite__(self, then_val, else_val):
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
def __eq__(self, other) -> builtins.bool:
|
||||
raise AssertionError("type stub not overridden")
|
||||
|
||||
|
|
@ -456,6 +459,12 @@ def sym_min(a, b):
|
|||
return b.__sym_min__(a)
|
||||
return builtins.min(a, b) # type: ignore[operator]
|
||||
|
||||
def sym_ite(b, t, f):
|
||||
assert isinstance(b, (SymBool, builtins.bool)) and type(t) == type(f)
|
||||
if isinstance(b, SymBool):
|
||||
return b.__sym_ite__(t, f)
|
||||
return t if b else f
|
||||
|
||||
# Check to see if we can load C extensions, and if not provide some guidance
|
||||
# on what the problem might be.
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -145,6 +145,19 @@ class PythonSymNodeImpl : public c10::SymNodeImpl {
|
|||
return getPyObj().attr("str")().cast<std::string>();
|
||||
}
|
||||
|
||||
c10::SymNode dispatch_sym_ite_(
|
||||
const char* fname,
|
||||
const c10::SymNode& other,
|
||||
const c10::SymNode& third) {
|
||||
auto pother = dynamic_cast<PythonSymNodeImpl*>(other.get());
|
||||
auto pthird = dynamic_cast<PythonSymNodeImpl*>(third.get());
|
||||
TORCH_CHECK(pother);
|
||||
TORCH_CHECK(pthird);
|
||||
py::gil_scoped_acquire acquire;
|
||||
auto r = getPyObj().attr(fname)(pother->getPyObj(), pthird->getPyObj());
|
||||
return c10::make_intrusive<PythonSymNodeImpl>(r);
|
||||
}
|
||||
|
||||
c10::SymNode dispatch_common_(const char* fname, const c10::SymNode& other) {
|
||||
auto pother = dynamic_cast<PythonSymNodeImpl*>(other.get());
|
||||
TORCH_CHECK(pother);
|
||||
|
|
@ -226,6 +239,11 @@ class PythonSymNodeImpl : public c10::SymNodeImpl {
|
|||
return dispatch_common_(__func__, other);
|
||||
}
|
||||
|
||||
c10::SymNode sym_ite(const c10::SymNode& other, const c10::SymNode& third)
|
||||
override {
|
||||
return dispatch_sym_ite_(__func__, other, third);
|
||||
}
|
||||
|
||||
c10::SymNode sym_not() override {
|
||||
return dispatch_common_(__func__);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ from torch import ( # noqa: F401
|
|||
sym_max,
|
||||
sym_min,
|
||||
sym_not,
|
||||
sym_ite,
|
||||
SymBool,
|
||||
SymFloat,
|
||||
SymInt,
|
||||
|
|
@ -962,6 +963,9 @@ class SymNode:
|
|||
def sym_max(self, other) -> "SymNode": # noqa: F811
|
||||
return self._sym_max(other) # type: ignore[attr-defined]
|
||||
|
||||
def sym_ite(self, then_val, else_val) -> "SymNode":
|
||||
return self._sym_ite(then_val, else_val)
|
||||
|
||||
def sym_sqrt(self) -> "SymNode": # noqa: F811
|
||||
return self._sym_sqrt() # type: ignore[attr-defined]
|
||||
|
||||
|
|
@ -1181,6 +1185,7 @@ magic_methods = {
|
|||
'neg': lambda a: -a,
|
||||
'sym_min': lambda a, b: sympy.Min(a, b),
|
||||
'sym_max': lambda a, b: sympy.Max(a, b),
|
||||
'sym_ite': lambda a, t, f: sympy.Piecewise((t, a), (f, True)),
|
||||
'sym_sqrt': lambda a: sympy.sqrt(a),
|
||||
'abs': lambda a: sympy.Abs(a),
|
||||
}
|
||||
|
|
@ -1318,13 +1323,13 @@ unary_magic_methods = {
|
|||
|
||||
# Most methods are only registered on SymInt and SymFloat
|
||||
# Some methods are only be registered on SymBool
|
||||
only_bool_magic_methods = {"and", "or", "sym_not"}
|
||||
only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"}
|
||||
# Methods that are also on SymBool, in addition to on SymInt and SymFloat
|
||||
also_bool_magic_methods = {"eq"}
|
||||
bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods
|
||||
|
||||
magic_methods_on_math = {"ceil", "floor"}
|
||||
magic_methods_on_submodule = {"sym_float", "sym_sqrt", "sym_min", "sym_max", "sym_not"}
|
||||
magic_methods_on_submodule = {"sym_float", "sym_sqrt", "sym_min", "sym_max", "sym_not", "sym_ite"}
|
||||
magic_methods_on_operator_with_trailing_underscore = {"and", "or"}
|
||||
|
||||
def method_to_operator(method):
|
||||
|
|
@ -1463,6 +1468,36 @@ def _make_node_magic(method, func):
|
|||
|
||||
if method in unary_magic_methods:
|
||||
setattr(SymNode, f"_{method_attr}", unary_magic_impl)
|
||||
elif method == "sym_ite":
|
||||
|
||||
def sym_ite_impl(pred_node, then_node, else_node):
|
||||
out_hint = then_node.hint if pred_node.hint else else_node.hint
|
||||
if SYM_FUNCTION_MODE:
|
||||
return to_node(
|
||||
pred_node,
|
||||
_handle_sym_dispatch(
|
||||
sym_ite,
|
||||
(wrap_node(pred_node), wrap_node(then_node), wrap_node(else_node)), {}
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
out = func(pred_node.expr, then_node.expr, else_node.expr)
|
||||
except Exception:
|
||||
log.warning("failed to eval %s(%s, %s, %s)", method, pred_node.expr, then_node.expr, else_node.expr)
|
||||
raise
|
||||
|
||||
out = safe_expand(out)
|
||||
fx_node, _ = pred_node.shape_env.create_fx_call_function(
|
||||
sym_ite,
|
||||
(
|
||||
pred_node.fx_node,
|
||||
then_node.fx_node,
|
||||
else_node.fx_node
|
||||
)
|
||||
)
|
||||
return SymNode(out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node)
|
||||
setattr(SymNode, f"_{method_attr}", sym_ite_impl)
|
||||
else:
|
||||
setattr(SymNode, f"_{method_attr}", binary_magic_impl)
|
||||
|
||||
|
|
@ -1602,6 +1637,19 @@ def _make_user_magic(method, user_type):
|
|||
|
||||
if method in unary_magic_methods:
|
||||
setattr(user_type, f"__{method}__", unary_magic_impl)
|
||||
elif method == "sym_ite":
|
||||
|
||||
def sym_ite_magic_impl(pred, then_val, else_val):
|
||||
pred_node = pred.node
|
||||
then_node = to_node(pred_node, then_val)
|
||||
else_node = to_node(pred_node, else_val)
|
||||
if then_node is NotImplemented or else_node is NotImplemented:
|
||||
return NotImplemented
|
||||
assert isinstance(then_node, SymNode) and isinstance(else_node, SymNode) and then_node.pytype == else_node.pytype
|
||||
ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node))
|
||||
return get_constant(ret) if ret.node.is_constant() else ret
|
||||
|
||||
setattr(user_type, f"__{method}__", sym_ite_magic_impl)
|
||||
else:
|
||||
setattr(user_type, f"__{method}__", binary_magic_impl)
|
||||
if method in reflectable_magic_methods:
|
||||
|
|
|
|||
|
|
@ -276,6 +276,7 @@ try:
|
|||
torch.sym_float: lift(ops.to_real),
|
||||
torch.sym_max: lift(ops.max),
|
||||
torch.sym_min: lift(ops.min),
|
||||
torch.sym_ite: lift(lambda b, t, f: t if b else f),
|
||||
sym_sqrt: lift(ops.sqrt),
|
||||
# Not lifted because we only use this function as a
|
||||
# marker for adding the expression as validator input.
|
||||
|
|
|
|||
|
|
@ -221,6 +221,7 @@ def get_ignored_functions() -> Set[Callable]:
|
|||
torch.sym_max,
|
||||
torch.sym_min,
|
||||
torch.sym_not,
|
||||
torch.sym_ite,
|
||||
torch.sym_constrain_range,
|
||||
torch.sym_constrain_range_for_size,
|
||||
torch.tril_indices,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user