mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
The big idea is to add `create_unbacked_symfloat` and `create_unbacked_symint` to ShapeEnv, allowing you to allocate symbolic floats/ints corresponding to data you don't know about at compile time. Then, instead of immediately erroring out when you try to call local_scalar_dense on a FakeTensor, we instead create a fresh symint/symfloat and return that. There a bunch of odds and ends that need to be handled: * A number of `numel` calls converted to `sym_numel` * When we finally return from item(), we need to ensure we actually produce a SymInt/SymFloat when appropriate. The previous binding code assumed that you would have to get a normal Python item. I add a pybind11 binding for Scalar (to PyObject only) and refactor the code to use that. There is some trickiness where you are NOT allowed to go through c10::SymInt if there isn't actually any SymInt involved. See comment. * One of our unit tests tripped an implicit data dependent access which occurs when you pass a Tensor as an argument to a sizes parameter. This is also converted to support symbolic shapes * We now support tracking bare SymInt/SymFloat returns in proxy tensor mode (this was already in symbolic-shapes branch) * Whenever we allocate an unbacked symint, we record the stack trace it was allocated at. These get printed when you attempt data dependent access on the symint (e.g., you try to guard on it) * Subtlety: unbacked symints are not necessarily > 1. I added a test for this. These unbacked symints are not very useful right now as you will almost always immediately raise an error later when you try to guard on them. The next logical step is adding an assertion refinement system that lets ShapeEnv learn facts about unbacked symints so it can do a better job eliding guards that are unnecessary. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/90624 Approved by: https://github.com/Skylion007, https://github.com/voznesenskym
118 lines
3.5 KiB
C++
118 lines
3.5 KiB
C++
#include <torch/csrc/utils/pybind.h>
|
|
#include <torch/csrc/utils/python_arg_parser.h>
|
|
#include <torch/csrc/utils/python_symnode.h>
|
|
|
|
namespace pybind11 {
|
|
namespace detail {
|
|
|
|
// Convert a Python object to c10::SymInt.  Accepts either a torch.SymInt
// (symbolic path) or any Python object implementing __index__ (constant path).
bool type_caster<c10::SymInt>::load(py::handle src, bool) {
  if (!torch::is_symint(src)) {
    // Constant path: plain Python integers (anything index-able) become a
    // backed SymInt holding the concrete value.
    auto* obj = src.ptr();
    if (!THPUtils_checkIndex(obj)) {
      return false;
    }
    value = c10::SymInt{THPUtils_unpackIndex(obj)};
    return true;
  }

  auto node = src.attr("node");
  if (py::isinstance<c10::SymNodeImpl>(node)) {
    // The SymInt's node already wraps a C++ SymNodeImpl; reuse it directly
    // instead of adding a Python indirection layer.
    value = c10::SymInt(py::cast<c10::SymNode>(node));
  } else {
    // Pure-Python node: bridge it into C++ via PythonSymNodeImpl.
    value = c10::SymInt(static_cast<c10::SymNode>(
        c10::make_intrusive<torch::impl::PythonSymNodeImpl>(node)));
  }
  return true;
}
|
|
|
|
// Convert a c10::SymInt to a Python object: a plain int when the value is
// backed (hinted), otherwise a torch.SymInt wrapping the underlying node.
py::handle type_caster<c10::SymInt>::cast(
    c10::SymInt si,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (!si.is_symbolic()) {
    // Backed case: surface the concrete integer directly.
    return py::cast(si.as_int_unchecked()).release();
  }

  auto* py_node =
      dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
  if (py_node != nullptr) {
    // Python-backed node: unwrap the stored Python object rather than
    // re-wrapping the C++ shim around it.
    return torch::get_symint_class()(py_node->getPyObj()).release();
  }

  // C++-backed node: wrap it so Python sees a torch.SymInt.
  auto inner = py::cast(si.toSymNodeImpl());
  if (!inner) {
    throw python_error();
  }
  return torch::get_symint_class()(inner).release();
}
|
|
|
|
// Convert a Python object to c10::SymFloat.  Accepts either a torch.SymFloat
// (symbolic path) or a Python float/number (constant path).
bool type_caster<c10::SymFloat>::load(py::handle src, bool) {
  if (torch::is_symfloat(src)) {
    auto node = src.attr("node");
    // Mirror the SymInt loader: if the node already wraps a C++ SymNodeImpl,
    // reuse it directly instead of adding a PythonSymNodeImpl indirection
    // (which would bounce every subsequent op back through Python).
    if (py::isinstance<c10::SymNodeImpl>(node)) {
      value = c10::SymFloat(py::cast<c10::SymNode>(node));
      return true;
    }
    // Pure-Python node: bridge it into C++ via PythonSymNodeImpl.
    value = c10::SymFloat(static_cast<c10::SymNode>(
        c10::make_intrusive<torch::impl::PythonSymNodeImpl>(node)));
    return true;
  }

  // Constant path: plain Python doubles become a backed SymFloat.
  auto raw_obj = src.ptr();
  if (THPUtils_checkDouble(raw_obj)) {
    value = c10::SymFloat{THPUtils_unpackDouble(raw_obj)};
    return true;
  }
  return false;
}
|
|
|
|
// Convert a c10::SymFloat to a Python object: a plain float when the value is
// backed, otherwise a torch.SymFloat wrapping the underlying node.
py::handle type_caster<c10::SymFloat>::cast(
    c10::SymFloat si,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (!si.is_symbolic()) {
    // Backed case: surface the concrete double directly.
    return py::cast(si.as_float_unchecked()).release();
  }

  auto* py_node =
      dynamic_cast<torch::impl::PythonSymNodeImpl*>(si.toSymNodeImpl().get());
  if (py_node) {
    // Python-backed node: unwrap the stored Python object.
    return torch::get_symfloat_class()(py_node->getPyObj()).release();
  }

  // C++-backed node: wrap it so Python sees a torch.SymFloat.  This mirrors
  // SymInt::cast and replaces the previous hard assert (the old TODO about
  // generalizing to C++-backed classes).
  auto inner = py::cast(si.toSymNodeImpl());
  if (!inner) {
    throw python_error();
  }
  return torch::get_symfloat_class()(inner).release();
}
|
|
|
|
// Loading Python -> c10::Scalar is deliberately unimplemented: only the
// Scalar -> Python direction (cast, below) is needed for returning item()
// results.  Any caller reaching this trips an internal assert.
bool type_caster<c10::Scalar>::load(py::handle src, bool) {
  TORCH_INTERNAL_ASSERT(
      0, "pybind11 loading for c10::Scalar NYI (file a bug if you need it)");
}
|
|
|
|
// Convert a c10::Scalar to the natural Python object for its tag:
// int / SymInt, float / SymFloat, bool, or complex.
py::handle type_caster<c10::Scalar>::cast(
    const c10::Scalar& scalar,
    return_value_policy /* policy */,
    handle /* parent */) {
  if (scalar.isIntegral(/*includeBool*/ false)) {
    // We have to be careful here; we cannot unconditionally route through
    // SymInt because integer data from Tensors can easily be MIN_INT or
    // very negative, which conflicts with the allocated range.
    return scalar.isSymbolic() ? py::cast(scalar.toSymInt()).release()
                               : py::cast(scalar.toLong()).release();
  }
  if (scalar.isFloatingPoint()) {
    // The symbolic split isn't strictly necessary here (no range hazard for
    // floats), but it keeps symmetry with the integral case.
    return scalar.isSymbolic() ? py::cast(scalar.toSymFloat()).release()
                               : py::cast(scalar.toDouble()).release();
  }
  if (scalar.isBoolean()) {
    return py::cast(scalar.toBool()).release();
  }
  if (scalar.isComplex()) {
    return py::cast(scalar.toComplexDouble()).release();
  }
  TORCH_INTERNAL_ASSERT(0, "unrecognized scalar type ", scalar.type());
}
|
|
|
|
} // namespace detail
|
|
} // namespace pybind11
|