At a high level, the idea behind this PR is:

* Make it clearer what the promotion and int/float rules for the various sympy operations are. Operators that were previously polymorphic over int/float are now split into separate operators for clarity. We never do mixed int/float addition/multiplication etc. in sympy; instead, we always promote to the appropriate operator. (However, equality is currently not done correctly.)
* Enforce strict typing on ValueRanges: if you have a ValueRange for a float, the lower and upper bounds MUST be floats, and likewise for integers.

The story begins in **torch/utils/_sympy/functions.py**. Here, I make some changes to how we represent certain operations in sympy expressions:

* FloorDiv now only supports integer inputs; to do float floor division, do a truediv and then a trunc. Additionally, we remove the "divide out addition by gcd" optimization, because sympy gcd works over fields and is willing to generate rationals (and rationals are bad for ValueRange strict typing).
* ModularIndexing, LShift and RShift now assert that they are given integer inputs.
* Mod only supports integer inputs; eventually we will support FloatMod (left for later work, when we build out sympy support for floating-point operations). Unfortunately, I couldn't assert integer inputs here, because of a bad interaction with sympy's inequality solver, which is used by the offline solver.
* TrueDiv is split into FloatTrueDiv and IntTrueDiv. This allows us to eventually generate accurate code for Python-semantics IntTrueDiv, which is written in a special way so that it preserves precision even when the inputs are >= 2**53, beyond what you get by first coercing the integers to floats and then doing true division (see the sketch after this list).
* Trunc is split into TruncToFloat and TruncToInt.
* Round is updated to return a float, not an int, making it consistent with the round op handler in Inductor. To get Python-style conversion to int, we call TruncToInt on the result.
* RoundDecimal is updated to consistently return a float.
* Add ToFloat for explicit coercion to float (required so we can enforce strict ValueRanges typing).

In **torch/__init__.py**, we modify SymInt and SymFloat to call into new bindings that route to these refined sympy operations. We also modify `torch.sym_min` and `torch.sym_max` to have promotion semantics (if one argument is a float, the result is always a float). This makes them inconsistent with builtins.min/max, but makes type analysis possible without runtime information.

We also need to introduce some new op handlers in **torch/_inductor/ops_handler.py**:

* `to_int` for truncation to int64, directly corresponding to TruncToInt; this could be implemented with trunc and a dtype cast, but a dedicated handler is more convenient for roundtripping in sympy.
* `int_truediv` for Python-style integer true division, which has higher precision than casting to floats and then running `truediv`.

These changes have consequences.
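As an illustration of the IntTrueDiv point above, here is a minimal Python sketch (illustrative; the variable names are mine, not the PR's) of how coercing to float first loses precision once the integers exceed 2**53:

```python
# Python's int / int is correctly rounded from the *exact* rational result,
# while float(a) / float(b) first rounds each operand to a double.
a = 2**53 + 1  # not exactly representable as a double
b = 3

exact = a / b                  # 3002399751580331.0 (a / b divides exactly)
coerced = float(a) / float(b)  # 3002399751580330.5 (a was rounded to 2**53 first)

assert exact != coerced
```

The exact division rounds the true rational result once, while the coercing version rounds twice (once per operand, once for the division); avoiding that double rounding is what IntTrueDiv is for.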
First, we need to make some administrative changes:

* Actually wire up these sympy functions from SymInt/SymFloat in **torch/fx/experimental/sym_node.py**, including the new promotion rules (promote2).
* Add support for the new sympy functions in **torch/utils/_sympy/interp.py** and **torch/utils/_sympy/reference.py**.
  * In particular, in torch.utils._sympy.reference, we have a strong preference NOT to do nontrivial compute; instead, everything in the ops handler should map to a single sympy function.
  * TODO: I chose to roundtrip mod back to our Mod function, but I think I'm going to have to deal with the C/Python inconsistency here to fix the tests.
* Add printer support for the sympy functions in **torch/_inductor/codegen/common.py**, **torch/_inductor/codegen/cpp_utils.py** and **torch/_inductor/codegen/triton.py**. `int_truediv` and mixed-precision equality are currently not implemented soundly, so we will lose precision in codegen for large values. TODO: The additions here are not exhaustive yet.
* Update the ValueRanges logic to use the new sympy functions in **torch/utils/_sympy/value_ranges.py**. In general, we prefer to use the new sympy function rather than roll things by hand, which is what was previously done for many VR analysis functions.

In **torch/fx/experimental/symbolic_shapes.py** we need to make some symbolic reasoning adjustments:

* Avoid generating rational subexpressions by removing the simplification of `x // y` into `floor(x / y)`. This simplification triggers a further simplification rule, `(x + y) / c --> x / c + y / c`, which is bad because `x / c` is now a rational number (see the sketch at the end of this description).
* `_assert_bound_is_rational` is no more; we no longer generate rational bounds.
* Don't intersect non-int value ranges with the `int_range`.
* Support more sympy functions for guard SYMPY_INTERP.
* Assert that the type of a value range is consistent with the type of its variable.

The new asserts uncovered necessary bug fixes:

* **torch/_inductor/codegen/cpp.py**, **torch/_inductor/select_algorithm.py**, **torch/_inductor/sizevars.py** - ensure a Wild/Symbol manually allocated in Inductor is marked `is_integer` so it is accepted when building expressions.
* **torch/_inductor/utils.py** - make sure you actually pass sympy.Expr to these functions.
* **torch/_inductor/ir.py** - make_contiguous_strides_for takes int/SymInt, not sympy.Expr!
* **torch/export/dynamic_shapes.py** - don't use infinity to represent int ranges; instead use sys.maxsize - 1.

Because of the removal of some symbolic reasoning that produced rationals, some of our symbolic reasoning has gotten worse and we are unable to simplify some guards. Check the TODO at **test/test_proxy_tensor.py**.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126905
Approved by: https://github.com/xadupre, https://github.com/lezcano
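To make the rational-subexpression hazard concrete, here is a minimal sympy sketch (illustrative, not code from the PR): once `x // y` is rewritten as `floor(x / y)`, sympy's automatic distribution of division over addition produces rational-valued terms like `x/2`, even when `x` is an integer symbol:

```python
import sympy

x = sympy.Symbol("x", integer=True)

# sympy automatically distributes the division over the addition, so the
# rewritten floor-division exposes a rational-valued subexpression x/2.
e = (x + 2) / 2
print(e)               # x/2 + 1  -- x/2 is rational-valued
print(sympy.floor(e))  # floor(x/2) + 1
```

A ValueRanges analysis with strict int typing has no sound way to bound `x/2` as an integer expression, which is why the `x // y --> floor(x / y)` simplification had to go.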
309 lines · 9.2 KiB · C++
#pragma once

#include <c10/core/SafePyObject.h>
#include <c10/core/SymNodeImpl.h>

#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/utils/pybind.h>

namespace torch {

TORCH_PYTHON_API py::handle get_symint_class();
TORCH_PYTHON_API py::handle get_symfloat_class();
TORCH_PYTHON_API py::handle get_symbool_class();

// NB: These functions must not be called too early, otherwise torch is not
// set up yet. An alternate design would be to have torch "register" the
// object with us.
inline bool is_symint(py::handle obj) {
  return py::isinstance(obj, get_symint_class());
}
inline bool is_symfloat(py::handle obj) {
  return py::isinstance(obj, get_symfloat_class());
}
inline bool is_symbool(py::handle obj) {
  return py::isinstance(obj, get_symbool_class());
}

namespace impl {

// This c10::SymNodeImpl is backed by a Python object that implements the
// SymNode API. The Python object is the source of truth; this class is just
// an adapter so that C++ calls can reach the object.
class PythonSymNodeImpl : public c10::SymNodeImpl {
 public:
  PythonSymNodeImpl(py::object pyobj) : c10::SymNodeImpl() {
    pyobj_ = std::make_shared<c10::SafePyObject>(
        pyobj.release().ptr(), getPyInterpreter());
  }

  // Wrap a plain C++ constant into a fresh SymNode by delegating to the
  // Python object's wrap_* method of the matching type.
  c10::SymNode wrap_int(int64_t num) override {
    py::gil_scoped_acquire acquire;
    auto r = getPyObj().attr("wrap_int")(num);
    return c10::make_intrusive<PythonSymNodeImpl>(std::move(r));
  }

  c10::SymNode wrap_float(double num) override {
    py::gil_scoped_acquire acquire;
    auto r = getPyObj().attr("wrap_float")(num);
    return c10::make_intrusive<PythonSymNodeImpl>(std::move(r));
  }

  c10::SymNode wrap_bool(bool num) override {
    py::gil_scoped_acquire acquire;
    auto r = getPyObj().attr("wrap_bool")(num);
    return c10::make_intrusive<PythonSymNodeImpl>(std::move(r));
  }

#define TORCH_SYMNODE_SIZES_STRIDES(n)                                        \
  c10::SymNode n(                                                             \
      c10::ArrayRef<c10::SymNode> sizes, c10::ArrayRef<c10::SymNode> strides) \
      override {                                                              \
    py::gil_scoped_acquire acquire;                                           \
    auto r = getPyObj().attr(#n)(sizes, strides);                             \
    return c10::make_intrusive<PythonSymNodeImpl>(std::move(r));              \
  }

  // clang-format off
  TORCH_SYMNODE_SIZES_STRIDES(is_contiguous)
  TORCH_SYMNODE_SIZES_STRIDES(is_channels_last_contiguous_2d)
  TORCH_SYMNODE_SIZES_STRIDES(is_channels_last_contiguous_3d)
  TORCH_SYMNODE_SIZES_STRIDES(is_channels_last_strides_2d)
  TORCH_SYMNODE_SIZES_STRIDES(is_channels_last_strides_3d)
  TORCH_SYMNODE_SIZES_STRIDES(is_non_overlapping_and_dense)
  // clang-format on

#undef TORCH_SYMNODE_SIZES_STRIDES

  bool bool_() override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("bool_")().is(py::handle(Py_True));
  }

  bool is_int() override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("is_int")().is(py::handle(Py_True));
  }

  bool is_float() override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("is_float")().is(py::handle(Py_True));
  }

  bool is_bool() override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("is_bool")().is(py::handle(Py_True));
  }

  bool is_nested_int() const override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("is_nested_int")().is(py::handle(Py_True));
  }

  bool has_hint() override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("has_hint")().is(py::handle(Py_True));
  }

  int64_t guard_int(const char* file, int64_t line) override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("guard_int")(file, line).cast<int64_t>();
  }

  double guard_float(const char* file, int64_t line) override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("guard_float")(file, line).cast<double>();
  }

  bool guard_bool(const char* file, int64_t line) override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("guard_bool")(file, line).cast<bool>();
  }

  bool expect_true(const char* file, int64_t line) override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("expect_true")(file, line).cast<bool>();
  }

  bool expect_size(const char* file, int64_t line) override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("expect_size")(file, line).cast<bool>();
  }

  bool guard_size_oblivious(const char* file, int64_t line) override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("guard_size_oblivious")(file, line).cast<bool>();
  }

  int64_t int_() override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("int_")().cast<int64_t>();
  }

  std::optional<int64_t> maybe_as_int() override {
    py::gil_scoped_acquire acquire;
    const auto& r = getPyObj().attr("maybe_as_int")();
    if (r.is_none()) {
      return c10::nullopt;
    } else {
      return r.cast<int64_t>();
    }
  }

  std::string str() override {
    py::gil_scoped_acquire acquire;
    return getPyObj().attr("str")().cast<std::string>();
  }

  // Dispatch helpers: forward the call to the Python method of the same name
  // (callers pass __func__ as fname).
  c10::SymNode dispatch_sym_ite_(
      const char* fname,
      const c10::SymNode& other,
      const c10::SymNode& third) {
    auto pother = dynamic_cast<PythonSymNodeImpl*>(other.get());
    auto pthird = dynamic_cast<PythonSymNodeImpl*>(third.get());
    TORCH_CHECK(pother);
    TORCH_CHECK(pthird);
    py::gil_scoped_acquire acquire;
    auto r = getPyObj().attr(fname)(pother->getPyObj(), pthird->getPyObj());
    return c10::make_intrusive<PythonSymNodeImpl>(r);
  }

  c10::SymNode dispatch_common_(const char* fname, const c10::SymNode& other) {
    auto pother = dynamic_cast<PythonSymNodeImpl*>(other.get());
    TORCH_CHECK(pother);
    py::gil_scoped_acquire acquire;
    auto r = getPyObj().attr(fname)(pother->getPyObj());
    return c10::make_intrusive<PythonSymNodeImpl>(r);
  }

  c10::SymNode dispatch_common_(const char* fname) {
    py::gil_scoped_acquire acquire;
    auto r = getPyObj().attr(fname)();
    return c10::make_intrusive<PythonSymNodeImpl>(r);
  }

  c10::SymNode add(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode sub(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode mul(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode truediv(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode float_truediv(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode int_truediv(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode pow(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode float_pow(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode pow_by_natural(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode floordiv(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode int_floordiv(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode mod(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode eq(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode ne(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode gt(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode lt(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode le(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode ge(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode sym_min(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }
  c10::SymNode sym_max(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode sym_and(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode sym_or(const c10::SymNode& other) override {
    return dispatch_common_(__func__, other);
  }

  c10::SymNode sym_ite(const c10::SymNode& other, const c10::SymNode& third)
      override {
    return dispatch_sym_ite_(__func__, other, third);
  }

  c10::SymNode sym_not() override {
    return dispatch_common_(__func__);
  }

  c10::SymNode ceil() override {
    return dispatch_common_(__func__);
  }

  c10::SymNode floor() override {
    return dispatch_common_(__func__);
  }

  c10::SymNode neg() override {
    return dispatch_common_(__func__);
  }

  c10::SymNode clone() override {
    return dispatch_common_(__func__);
  }

  c10::SymNode sym_float() override {
    return dispatch_common_(__func__);
  }

  py::handle getPyObj() const {
    return py::handle(pyobj_->ptr(getPyInterpreter()));
  }
  std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
};

} // namespace impl
} // namespace torch
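For orientation, the adapter above forwards every C++ call to a same-named method on the wrapped Python object. Below is a minimal Python sketch of that protocol; the class and its contents are hypothetical stand-ins (the real implementation is torch.fx.experimental.sym_node.SymNode), shown only to make the shape of the API the C++ side expects concrete:

```python
# Hypothetical stand-in for the Python object that PythonSymNodeImpl wraps.
# Each method name matches an attr() lookup performed by the C++ adapter.
class FakeSymNode:
    def __init__(self, value):
        self.value = value

    def is_int(self):
        return isinstance(self.value, int)

    def is_float(self):
        return isinstance(self.value, float)

    def wrap_int(self, num):
        # Called by the C++ wrap_int override to build a new node.
        return FakeSymNode(num)

    def add(self, other):
        # dispatch_common_("add", other) lands here with the peer's Python object.
        return FakeSymNode(self.value + other.value)

    def str(self):
        return str(self.value)
```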