Initial prototype for dynamic int inputs, allows users to run with `torch.compile(f)(DynamicInt(4))`, compiling dynamically and using the underlying hint at runtime.
Current behavior:
- Also works in eager (mostly by subclassing int), as scalar input to torch functions, or numpy/math/etc. For example, `x = DynamicInt(3); torch.randn(x); torch.add(y, z, alpha=x); np.arange(x)` all act as if x = 3.
- Behavior for arithmetic ops is to return new DynamicInts rather than static ints; `DynamicInt(3) * 2 = DynamicInt(6)`. This is via SymNode magic methods, but coverage might not be 100% - for example, I had to explicitly override floordiv to avoid int casting. This is not necessarily the case for non-magic method ops (e.g. `math.cos(x)`). The alternative here is to int cast on all operations, but I opted for this for dynamism propagation in non-compiled regions.
- Doesn't ban fullgraph=False; DynamicInt objects might be leaked back to the user, but I guess this is fine, because they can be casted to ints when needed?
- Dynamo only allocates one symbol per DynamicInt; specifying the same DynamicInt for multiple inputs leads to input deduplication, and a guard installed.
- We don't raise on int specialization (in allowlist/maybe_mark_dynamic style) - but an easy change if needed.
- DynamicInts as nn.Module attributes are handled.
- We don't guard on the DynamicInt id, e.g. users can do the following without recompiling (maybe we should guard?)
```python
x = DynamicInt(4)
f(x)
f(1)
f(DynamicInt(3)) # same as f(3)
```
Follow-up work:
- Specifying shape constraints, either at the int-level, e.g.
```python
DynamicInt(64, name="s0", constraints=["s0 % 32 == 0", "s0 <= 1024"]
```
or at the compilation level, e.g. something like
```python
s0 = DynamicInt(64, name="s0")
s1 = DynamicInt(128, name="s1")
with some_compiler_config.dynamic_int_constraints(["s1 == 2*s0", "s0 % 32 == 0"]):
f(s0, s1)
```
This should subsume the need for specifying derived SymInts?
- SymFloat support - currently it seems backed floats are specialized by the tensorify float pass, and there's no handling in inductor.
- Propagating dynamism in tensor constructors, e.g. `x = DynamicInt(4); torch.randn(x)` could annotate `_dynamo_dynamic_indices`.
Differential Revision: D81698719
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162194
Approved by: https://github.com/bobrenjc93
Fix static `py::object`s with `py::gil_safe_call_once_and_store`.
The following code will leak a `py::object` which will call its destructor when shutdown the program. The destructor will call `Py_DECREF(obj.m_ptr)` which may raise a segmentation fault.
```c++
void func() {
static py::object obj = py::module_::import("foo").attr("bar");
...
}
```
The correct code is to use raw pointers rather than the instance.
```c++
void func() {
static py::object* obj_ptr = new py::object{py::module_::import("foo").attr("bar")};
py::object obj = *obj_ptr;
...
}
```
This PR uses the `py::gil_safe_call_once_and_store` function from `pybind11`, which can run arbitrary initialization code only once under the Python GIL thread safely.
```c++
void func() {
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object> storage;
py::object obj = storage
.call_once_and_store_result(
[]() -> py::object {
return py::module_::import("foo").attr("bar");
}
)
.get_stored();
...
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130341
Approved by: https://github.com/ezyang, https://github.com/malfet
Fix static `py::object`s with `py::gil_safe_call_once_and_store`.
The following code will leak a `py::object` which will call its destructor when shutdown the program. The destructor will call `Py_DECREF(obj.m_ptr)` which may raise a segmentation fault.
```c++
void func() {
static py::object obj = py::module_::import("foo").attr("bar");
...
}
```
The correct code is to use raw pointers rather than the instance.
```c++
void func() {
static py::object* obj_ptr = new py::object{py::module_::import("foo").attr("bar")};
py::object obj = *obj_ptr;
...
}
```
This PR uses the `py::gil_safe_call_once_and_store` function from `pybind11`, which can run arbitrary initialization code only once under the Python GIL thread safely.
```c++
void func() {
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object> storage;
py::object obj = storage
.call_once_and_store_result(
[]() -> py::object {
return py::module_::import("foo").attr("bar");
}
)
.get_stored();
...
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130341
Approved by: https://github.com/ezyang
This refactor was prompted by challenges handling mixed int/float
operations in C++. A previous version of this patch
added overloads for each permutation of int/float and was unwieldy
https://github.com/pytorch/pytorch/pull/87722/ This PR takes a different
approach.
The general outline of the patch is to combine the C++ types SymIntNode
and SymFloatNode into a single type, SymNode. This is type erased; we
no longer know statically at C++ if we have an int/float and have to test
it with the is_int()/is_float() virtual methods. This has a number of
knock on effects.
- We no longer have C++ classes to bind to Python. Instead, we take an
entirely new approach to our Python API, where we have a SymInt/SymFloat
class defined entirely in Python, which hold a SymNode (which corresponds
to the C++ SymNode). However, SymNode is not pybind11-bound; instead,
it lives as-is in Python, and is wrapped into C++ SymNode using PythonSymNode
when it goes into C++. This implies a userland rename.
In principle, it is also possible for the canonical implementation of SymNode
to be written in C++, and then bound to Python with pybind11 (we have
this code, although it is commented out.) However, I did not implement
this as we currently have no C++ implementations of SymNode.
Because we do return SymInt/SymFloat from C++ bindings, the C++ binding
code needs to know how to find these classes. Currently, this is done
just by manually importing torch and getting the attributes.
- Because SymInt/SymFloat are easy Python wrappers, __sym_dispatch__ now
takes SymInt/SymFloat, rather than SymNode, bringing it in line with how
__torch_dispatch__ works.
Some miscellaneous improvements:
- SymInt now has a constructor that takes SymNode. Note that this
constructor is ambiguous if you pass in a subclass of SymNode,
so an explicit downcast is necessary. This means toSymFloat/toSymInt
are no more. This is a mild optimization as it means rvalue reference
works automatically.
- We uniformly use the caster for c10::SymInt/SymFloat, rather than
going the long way via the SymIntNode/SymFloatNode.
- Removed some unnecessary toSymInt/toSymFloat calls in normalize_*
functions, pretty sure this doesn't do anything.
- guard_int is now a free function, since to guard on an int you cannot
assume the method exists. A function can handle both int and SymInt
inputs.
- We clean up the magic method definition code for SymInt/SymFloat/SymNode.
ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets
plain methods; this is to help avoid confusion between the two types.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87817
Approved by: https://github.com/albanD, https://github.com/anjali411