pytorch/torch/csrc/utils/python_symnode.h
Pian Pawakapan 4c007073e6 [dynamic shapes] DynamicInts prototype (#162194)
Initial prototype for dynamic int inputs, allows users to run with `torch.compile(f)(DynamicInt(4))`, compiling dynamically and using the underlying hint at runtime.

Current behavior:
- Also works in eager (mostly by subclassing int), as scalar input to torch functions, or numpy/math/etc. For example, `x = DynamicInt(3); torch.randn(x); torch.add(y, z, alpha=x); np.arange(x)` all act as if x = 3.
- Behavior for arithmetic ops is to return new DynamicInts rather than static ints; `DynamicInt(3) * 2 = DynamicInt(6)`. This is via SymNode magic methods, but coverage might not be 100% - for example, I had to explicitly override floordiv to avoid int casting. This is not necessarily the case for non-magic method ops (e.g. `math.cos(x)`). The alternative here is to int cast on all operations, but I opted for this for dynamism propagation in non-compiled regions.
- Doesn't ban fullgraph=False; DynamicInt objects might be leaked back to the user, but I guess this is fine, because they can be cast to ints when needed?
- Dynamo only allocates one symbol per DynamicInt; specifying the same DynamicInt for multiple inputs leads to input deduplication, and a guard installed.
- We don't raise on int specialization (in allowlist/maybe_mark_dynamic style) - but an easy change if needed.
- DynamicInts as nn.Module attributes are handled.
- We don't guard on the DynamicInt id, e.g. users can do the following without recompiling (maybe we should guard?)
```python
x = DynamicInt(4)
f(x)
f(1)
f(DynamicInt(3))  # same as f(3)
```

Follow-up work:
- Specifying shape constraints, either at the int-level, e.g.
```python
DynamicInt(64, name="s0", constraints=["s0 % 32 == 0", "s0 <= 1024"])
```
or at the compilation level, e.g. something like
```python
s0 = DynamicInt(64, name="s0")
s1 = DynamicInt(128, name="s1")
with some_compiler_config.dynamic_int_constraints(["s1 == 2*s0", "s0 % 32 == 0"]):
    f(s0, s1)
```
This should subsume the need for specifying derived SymInts?
- SymFloat support - currently it seems backed floats are specialized by the tensorify float pass, and there's no handling in inductor.
- Propagating dynamism in tensor constructors, e.g. `x = DynamicInt(4); torch.randn(x)` could annotate `_dynamo_dynamic_indices`.

Differential Revision: D81698719

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162194
Approved by: https://github.com/bobrenjc93
2025-09-18 23:26:28 +00:00

333 lines
10 KiB
C++

#pragma once
#include <c10/core/SafePyObject.h>
#include <c10/core/SymNodeImpl.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/utils/pybind.h>
namespace torch {
// Accessors for the Python-side classes (torch.SymInt, torch.SymFloat,
// torch.SymBool, and the DynamicInt prototype) used by the is_* predicates
// below. Definitions live in the corresponding .cpp; the returned handles are
// borrowed (non-owning).
TORCH_PYTHON_API py::handle get_symint_class();
TORCH_PYTHON_API py::handle get_symfloat_class();
TORCH_PYTHON_API py::handle get_symbool_class();
TORCH_PYTHON_API py::handle get_dynint_class();
// NB: These functions must not be called too early, otherwise torch not setup.
// Alternate design is to have torch "register" the object to us
// True iff `obj` is an instance of the Python torch.SymInt class.
inline bool is_symint(py::handle obj) {
  const auto symint_cls = get_symint_class();
  return py::isinstance(obj, symint_cls);
}
// True iff `obj` is an instance of the Python torch.SymFloat class.
inline bool is_symfloat(py::handle obj) {
  const auto symfloat_cls = get_symfloat_class();
  return py::isinstance(obj, symfloat_cls);
}
// True iff `obj` is an instance of the Python torch.SymBool class.
inline bool is_symbool(py::handle obj) {
  const auto symbool_cls = get_symbool_class();
  return py::isinstance(obj, symbool_cls);
}
// True iff `obj` is an instance of the Python DynamicInt class.
inline bool is_dynint(py::handle obj) {
  const auto dynint_cls = get_dynint_class();
  return py::isinstance(obj, dynint_cls);
}
namespace impl {
// This c10::SymNodeImpl simply backends to a Python object that
// implements the API. The Python object is the source of truth,
// this is just an adapter so C++ calls can get to the object.
class PythonSymNodeImpl : public c10::SymNodeImpl {
public:
// Takes ownership of `pyobj` (a Python object implementing the SymNode
// protocol). SafePyObject ties the decref to the owning interpreter so the
// object is released safely even from non-Python threads.
PythonSymNodeImpl(py::object pyobj) : c10::SymNodeImpl() {
pyobj_ = std::make_shared<c10::SafePyObject>(
pyobj.release().ptr(), getPyInterpreter());
}
// wrap_int/wrap_float/wrap_bool: build a new SymNode from a plain C++ scalar
// by delegating to the same-named method on the Python object. Each acquires
// the GIL before touching Python state.
c10::SymNode wrap_int(int64_t num) override {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("wrap_int")(num);
return c10::make_intrusive<PythonSymNodeImpl>(std::move(r));
}
c10::SymNode wrap_float(double num) override {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("wrap_float")(num);
return c10::make_intrusive<PythonSymNodeImpl>(std::move(r));
}
c10::SymNode wrap_bool(bool num) override {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr("wrap_bool")(num);
return c10::make_intrusive<PythonSymNodeImpl>(std::move(r));
}
// Generates overrides for layout predicates taking (sizes, strides); each
// forwards to the Python method whose name matches the C++ method name.
#define TORCH_SYMNODE_SIZES_STRIDES(n) \
c10::SymNode n( \
c10::ArrayRef<c10::SymNode> sizes, c10::ArrayRef<c10::SymNode> strides) \
override { \
py::gil_scoped_acquire acquire; \
auto r = getPyObj().attr(#n)(sizes, strides); \
return c10::make_intrusive<PythonSymNodeImpl>(std::move(r)); \
}
// clang-format off
TORCH_SYMNODE_SIZES_STRIDES(is_contiguous)
TORCH_SYMNODE_SIZES_STRIDES(is_channels_last_contiguous_2d)
TORCH_SYMNODE_SIZES_STRIDES(is_channels_last_contiguous_3d)
TORCH_SYMNODE_SIZES_STRIDES(is_channels_last_strides_2d)
TORCH_SYMNODE_SIZES_STRIDES(is_channels_last_strides_3d)
TORCH_SYMNODE_SIZES_STRIDES(is_non_overlapping_and_dense)
// clang-format on
#undef TORCH_SYMNODE_SIZES_STRIDES
// Boolean queries: the Python method returns a Python bool; compare the
// returned object against the Py_True singleton by identity.
bool bool_() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("bool_")().is(py::handle(Py_True));
}
bool is_int() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("is_int")().is(py::handle(Py_True));
}
bool is_float() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("is_float")().is(py::handle(Py_True));
}
bool is_bool() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("is_bool")().is(py::handle(Py_True));
}
bool is_nested_int() const override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("is_nested_int")().is(py::handle(Py_True));
}
bool has_hint() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("has_hint")().is(py::handle(Py_True));
}
// Guard-style evaluations: delegate to the Python implementation, passing
// the C++ call site (file, line) so the Python side can attribute the guard.
int64_t guard_int(const char* file, int64_t line) override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("guard_int")(file, line).cast<int64_t>();
}
double guard_float(const char* file, int64_t line) override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("guard_float")(file, line).cast<double>();
}
bool guard_bool(const char* file, int64_t line) override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("guard_bool")(file, line).cast<bool>();
}
bool expect_true(const char* file, int64_t line) override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("expect_true")(file, line).cast<bool>();
}
bool expect_size(const char* file, int64_t line) override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("expect_size")(file, line).cast<bool>();
}
bool guard_size_oblivious(const char* file, int64_t line) override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("guard_size_oblivious")(file, line).cast<bool>();
}
bool guard_or_false(const char* file, int64_t line) override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("guard_or_false")(file, line).cast<bool>();
}
bool statically_known_true(const char* file, int64_t line) override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("statically_known_true")(file, line).cast<bool>();
}
bool guard_or_true(const char* file, int64_t line) override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("guard_or_true")(file, line).cast<bool>();
}
int64_t int_() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("int_")().cast<int64_t>();
}
// Returns the constant integer value if the Python node reports one,
// std::nullopt otherwise (Python None maps to nullopt).
std::optional<int64_t> maybe_as_int() override {
py::gil_scoped_acquire acquire;
const auto& r = getPyObj().attr("maybe_as_int")();
if (r.is_none()) {
return std::nullopt;
} else {
return r.cast<int64_t>();
}
}
std::string str() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("str")().cast<std::string>();
}
std::string _graph_repr() override {
py::gil_scoped_acquire acquire;
return getPyObj().attr("_graph_repr")().cast<std::string>();
}
// Ternary dispatch helper (used by sym_ite): both operands must themselves
// be PythonSymNodeImpl, enforced by the TORCH_CHECKs after the dynamic_cast.
c10::SymNode dispatch_sym_ite_(
const char* fname,
const c10::SymNode& other,
const c10::SymNode& third) {
auto pother = dynamic_cast<PythonSymNodeImpl*>(other.get());
auto pthird = dynamic_cast<PythonSymNodeImpl*>(third.get());
TORCH_CHECK(pother);
TORCH_CHECK(pthird);
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr(fname)(pother->getPyObj(), pthird->getPyObj());
return c10::make_intrusive<PythonSymNodeImpl>(r);
}
// Binary dispatch helper: forwards to the Python method named `fname` with
// the other node's Python object as argument. `other` must also be a
// PythonSymNodeImpl.
c10::SymNode dispatch_common_(const char* fname, const c10::SymNode& other) {
auto pother = dynamic_cast<PythonSymNodeImpl*>(other.get());
TORCH_CHECK(pother);
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr(fname)(pother->getPyObj());
return c10::make_intrusive<PythonSymNodeImpl>(r);
}
// Unary dispatch helper: forwards to the zero-argument Python method `fname`.
c10::SymNode dispatch_common_(const char* fname) {
py::gil_scoped_acquire acquire;
auto r = getPyObj().attr(fname)();
return c10::make_intrusive<PythonSymNodeImpl>(r);
}
// Operator overrides below rely on __func__ matching the Python method name
// exactly, so each C++ method dispatches to its Python namesake.
c10::SymNode add(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode sub(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode mul(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode truediv(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode float_truediv(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode int_truediv(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode pow(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode float_pow(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode pow_by_natural(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode floordiv(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode int_floordiv(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode mod(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode eq(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode ne(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode gt(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode lt(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode le(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode ge(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode sym_min(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode sym_max(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode sym_and(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode sym_or(const c10::SymNode& other) override {
return dispatch_common_(__func__, other);
}
c10::SymNode sym_ite(const c10::SymNode& other, const c10::SymNode& third)
override {
return dispatch_sym_ite_(__func__, other, third);
}
c10::SymNode sym_not() override {
return dispatch_common_(__func__);
}
c10::SymNode ceil() override {
return dispatch_common_(__func__);
}
c10::SymNode floor() override {
return dispatch_common_(__func__);
}
c10::SymNode neg() override {
return dispatch_common_(__func__);
}
c10::SymNode clone() override {
return dispatch_common_(__func__);
}
c10::SymNode sym_float() override {
return dispatch_common_(__func__);
}
// Borrows the wrapped Python object (non-owning handle). Callers in this
// class acquire the GIL before invoking methods on the returned handle.
py::handle getPyObj() const {
return py::handle(pyobj_->ptr(getPyInterpreter()));
}
// Owning reference to the Python SymNode object; shared so copies of this
// node can alias the same underlying Python state.
std::shared_ptr<c10::SafePyObject> pyobj_ = nullptr;
};
} // namespace impl
} // namespace torch