We have known for a while that we should, in principle, support SymBool as a separate concept from SymInt and SymFloat (in particular, every distinct numeric type should get its own API). However, recent work with unbacked SymInts in, e.g., https://github.com/pytorch/pytorch/pull/90985 has made this a priority to implement. The essential problem is that our logic for computing the contiguity of tensors branches on the passed-in input sizes, and this causes us to require guards when constructing tensors from unbacked SymInts. Morally, this should not be a big deal, because we only really care about the regular (non-channels-last) contiguity of the tensor, which should be guaranteed since most people aren't calling `empty_strided` on the tensor. However, because we store a bool (not a SymBool, which doesn't exist prior to this PR) on TensorImpl, we are forced to *immediately* compute these values, even if the value ends up not being used at all. In particular, even when a user allocates a contiguous tensor, we still must compute channels-last contiguity (as some contiguous tensors are also channels-last contiguous, but others are not).

This PR implements SymBool, and makes TensorImpl use SymBool to store the contiguity information in ExtraMeta. There are a number of knock-on effects, which I now discuss below.

* I introduce a new C++ type SymBool, analogous to SymInt and SymFloat. This type supports logical and, logical or and logical negation. I support the bitwise operations on this class (but not the conventional logic operators) to make it clear that logical operations on SymBool are NOT short-circuiting. I also, for now, do NOT support implicit conversion of SymBool to bool (which would create a guard in that case). This does not matter too much in practice, as in this PR I did not modify the equality operations (e.g., `==` on SymInt) to return SymBool, so all preexisting implicit guards did not need to be changed. I also introduced symbolic comparison functions `sym_eq`, etc. on SymInt to make it possible to create SymBool. The current implementation of the comparison functions makes it unfortunately easy to accidentally introduce guards when you do not mean to (as both `s0 == s1` and `s0.sym_eq(s1)` are valid spellings of the equality operation); in the short term, I intend to prevent excess guarding in this situation by unit testing; in the long term, making the equality operators return SymBool is probably the correct fix. (A toy model of these conventions is sketched at the end of this message.)

* ~~I modify TensorImpl to store SymBool for the `is_contiguous` fields and friends on `ExtraMeta`. In practice, this essentially meant reverting most of the changes from https://github.com/pytorch/pytorch/pull/85936. In particular, the fields on ExtraMeta are no longer strongly typed; at the time I was particularly concerned about the giant lambda I was using as the setter getting a desynchronized argument order, but now that I have individual setters for each field, the only "big list" of boolean arguments is in the constructor of ExtraMeta, which seems like an acceptable risk. The semantics of TensorImpl are now that we guard only when you actually attempt to access the contiguity of the tensor via, e.g., `is_contiguous`. By and large, the contiguity calculation in the implementations now needs to be duplicated (as the boolean version can short-circuit, but the SymBool version cannot); you should carefully review the duplicated new implementations. I typically use the `identity` template to disambiguate which version of the function I need, and rely on overloading to allow for implementation sharing. The changes to the `compute_` functions are particularly interesting; for most of the functions, I preserved their original non-symbolic implementation, and then introduced a new symbolic implementation that is branch-free (making use of our new SymBool operations). However, `compute_non_overlapping_and_dense` is special; see the next bullet.~~ This appears to cause performance problems, so I am leaving it to a follow-up PR.

* (Update: the Python-side pieces for this are still in this PR, but they are not wired up until later PRs.) While the contiguity calculations are relatively easy to write in a branch-free way, `compute_non_overlapping_and_dense` is not: it involves a sort on the strides. While in principle we could still make it go through by using a data-oblivious sorting network, this seems like too much complication for a field that is likely never used (because typically, it will be obvious that a tensor is non-overlapping and dense, because the tensor is contiguous). So we take a different approach: instead of trying to trace through the logical computation of non-overlapping and dense, we introduce a new opaque operator, IsNonOverlappingAndDenseIndicator, which represents all of the compute that would have been done here. This function returns an integer 0 if `is_non_overlapping_and_dense` would have returned `False`, and an integer 1 otherwise; it returns integers rather than booleans for technical reasons (Sympy does not easily allow defining custom functions that return booleans). The function only knows how to evaluate itself if all of its arguments are integers; otherwise it is left unevaluated. This means we can always guard on it (as `size_hint` will always be able to evaluate through it), but otherwise its insides are left a black box; see the second sketch at the end of this message. We typically do NOT expect this custom function to show up in actual boolean expressions, because we will typically shortcut it due to the tensor being contiguous. It's possible we should apply this treatment to all of the other `compute_` operations; more investigation is necessary. As a technical note, because this operator takes a pair of lists of SymInts, we need to support converting `ArrayRef<SymNode>` to Python, and I also unpack the pair of lists into a single flat list because I don't know if Sympy operations can validly take lists of Sympy expressions as inputs. See, for example, `_make_node_sizes_strides`.

* On the Python side, we also introduce a SymBool class, and update SymNode to track bool as a valid pytype. There is some subtlety here: bool is a subclass of int, so one has to be careful about `isinstance` checks (in fact, in most cases I replaced `isinstance(x, int)` with `type(x) is int` for expressly this reason). Additionally, unlike C++, I do NOT define bitwise inverse on SymBool, because it does not do the correct thing when run on booleans, e.g., `~True` is `-2`. (For that matter, it doesn't do the right thing in C++ either, but at least in principle the compiler can warn you about it with `-Wbool-operation`, so the rule in C++ is simple: only use the logical operations if the types are statically known to be SymBool.) Alas, logical negation is not overridable, so we have to introduce `sym_not`, which must be used in place of `not` whenever a SymBool can turn up. To avoid confusion with `__not__`, which might imply that `operator.__not__` is acceptable to use (it isn't), our magic method is called `__sym_not__`. The other bitwise operators, `&` and `|`, do the right thing with booleans and are acceptable to use.

* There is some annoyance working with booleans in Sympy. Unlike int and float, booleans live in their own algebra, and they support fewer operations than regular numbers. In particular, `sympy.expand` does not work on them. To get around this, I introduce `safe_expand`, which only calls expand on operations that are known to be expandable.

TODO: this PR appears to greatly regress performance of symbolic reasoning. In particular, `python test/functorch/test_aotdispatch.py -k max_pool2d` performs really poorly with these changes. Needs investigation.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92149
Approved by: https://github.com/albanD, https://github.com/Skylion007
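
For a feel of these conventions, here is a minimal toy sketch (hypothetical `ToySymBool` class and expression representation; the real types are c10::SymBool and torch.SymBool): `&`/`|` build symbolic expressions eagerly with no short-circuiting, and `sym_not` stands in for the non-overridable `not`.

```python
class ToySymBool:
    """Hypothetical stand-in for torch.SymBool; expressions are tuples."""

    def __init__(self, expr):
        self.expr = expr

    # & and | do the right thing on plain bools and are safe to overload.
    # Note they do NOT short-circuit: both operands are always evaluated,
    # which is exactly what the branch-free computations rely on.
    def __and__(self, other):
        return ToySymBool(("and", self.expr, other.expr))

    def __or__(self, other):
        return ToySymBool(("or", self.expr, other.expr))

    # `not` is not overridable, and `~` is wrong on bools (~True == -2),
    # so negation goes through a named magic method instead.
    def __sym_not__(self):
        return ToySymBool(("not", self.expr))


def sym_not(a):
    """Use in place of `not` wherever a SymBool may turn up."""
    if hasattr(a, "__sym_not__"):
        return a.__sym_not__()
    return not a


# Build a symbolic expression without ever forcing a concrete bool (no
# guard), then negate it symbolically.
x = ToySymBool("size0 == 1") | ToySymBool("stride0 == 1")
y = sym_not(x & ToySymBool("size1 == 1"))
print(y.expr)  # ('not', ('and', ('or', ...), 'size1 == 1'))
```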
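And a sketch of the opaque-indicator idea (hypothetical implementation, including the density check itself; the real operator lives in PyTorch's symbolic shapes machinery): a Sympy function that evaluates only when every argument is a concrete integer, and otherwise stays an unevaluated black box inside expressions.

```python
import sympy


class IsNonOverlappingAndDenseIndicator(sympy.Function):
    """Returns 1 if the (flattened) sizes/strides are non-overlapping and
    dense, 0 otherwise; integer-valued because Sympy does not easily allow
    custom functions that return booleans."""

    is_integer = True

    @classmethod
    def eval(cls, *args):
        # Evaluate only when all arguments are concrete integers (as they
        # are under size_hint); otherwise return None to stay unevaluated.
        if not all(isinstance(a, sympy.Integer) for a in args):
            return None
        n = len(args) // 2  # sizes and strides arrive as one flat list
        sizes, strides = args[:n], args[n:]
        # Walk dims from smallest stride outward; dense means each stride
        # equals the product of the sizes of the dims inside it.
        expected = 1
        for stride, size in sorted(zip(strides, sizes)):
            if size == 1:
                continue  # size-1 dims are dense with any stride
            if stride != expected:
                return sympy.Integer(0)
            expected *= size
        return sympy.Integer(1)


s = sympy.Symbol("s", positive=True, integer=True)
print(IsNonOverlappingAndDenseIndicator(2, s, s, 1))  # stays unevaluated
print(IsNonOverlappingAndDenseIndicator(2, 3, 3, 1))  # contiguous 2x3 -> 1
```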
359 lines
11 KiB
C++
#pragma once

#include <torch/csrc/python_headers.h>

#include <ATen/core/Tensor.h>
#include <ATen/core/jit_type_base.h>
#include <c10/util/irange.h>
#include <c10/util/variant.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <torch/csrc/Device.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/MemoryFormat.h>
#include <torch/csrc/utils/tensor_memoryformats.h>

#include <stdexcept>
#include <utility>

namespace py = pybind11;

// This makes intrusive_ptr available as a custom pybind11 holder type; see
// https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#custom-smart-pointers
PYBIND11_DECLARE_HOLDER_TYPE(T, c10::intrusive_ptr<T>, true);

PYBIND11_DECLARE_HOLDER_TYPE(T, c10::SingletonOrSharedTypePtr<T>);
PYBIND11_DECLARE_HOLDER_TYPE(T, c10::SingletonTypePtr<T>, true);
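
// Example of binding a class with the intrusive_ptr holder declared above
// (illustrative sketch; assumes a py::module_ `m` and a hypothetical
// c10::intrusive_ptr_target subclass MyClass):
//
//   py::class_<MyClass, c10::intrusive_ptr<MyClass>>(m, "MyClass")
//       .def(py::init<>());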

namespace pybind11 {
namespace detail {

// torch.Tensor <-> at::Tensor conversions (without unwrapping)
template <>
struct TORCH_PYTHON_API type_caster<at::Tensor> {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  PYBIND11_TYPE_CASTER(at::Tensor, _("torch.Tensor"));

  bool load(handle src, bool);

  static handle cast(
      const at::Tensor& src,
      return_value_policy /* policy */,
      handle /* parent */);
};
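
// With the caster above in scope, functions taking or returning at::Tensor
// can be bound directly (illustrative sketch; assumes a py::module_ `m`):
//
//   m.def("relu_copy", [](const at::Tensor& t) { return t.relu(); });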

// torch._StorageBase <-> at::Storage
template <>
struct type_caster<at::Storage> {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  PYBIND11_TYPE_CASTER(at::Storage, _("torch.StorageBase"));

  bool load(handle src, bool) {
    PyObject* obj = src.ptr();
    if (torch::isStorage(obj)) {
      value = torch::createStorage(obj);
      return true;
    }
    return false;
  }

  static handle cast(
      const at::Storage& src,
      return_value_policy /* policy */,
      handle /* parent */) {
    return handle(torch::createPyObject(src));
  }
};

template <>
struct type_caster<at::Generator> {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  PYBIND11_TYPE_CASTER(at::Generator, _("torch.Generator"));

  bool load(handle src, bool) {
    PyObject* obj = src.ptr();
    if (THPGenerator_Check(obj)) {
      value = reinterpret_cast<THPGenerator*>(obj)->cdata;
      return true;
    }
    return false;
  }

  static handle cast(
      const at::Generator& src,
      return_value_policy /* policy */,
      handle /* parent */) {
    return handle(THPGenerator_Wrap(src));
  }
};

template <>
struct TORCH_PYTHON_API type_caster<at::IntArrayRef> {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  PYBIND11_TYPE_CASTER(at::IntArrayRef, _("Tuple[int, ...]"));

  bool load(handle src, bool);
  static handle cast(
      at::IntArrayRef src,
      return_value_policy /* policy */,
      handle /* parent */);

 private:
  std::vector<int64_t> v_value;
};

template <>
struct TORCH_PYTHON_API type_caster<at::SymIntArrayRef> {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  PYBIND11_TYPE_CASTER(at::SymIntArrayRef, _("List[int]"));

  bool load(handle src, bool);
  static handle cast(
      at::SymIntArrayRef src,
      return_value_policy /* policy */,
      handle /* parent */);

 private:
  std::vector<c10::SymInt> v_value;
};

template <>
struct TORCH_PYTHON_API type_caster<at::ArrayRef<c10::SymNode>> {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  PYBIND11_TYPE_CASTER(at::ArrayRef<c10::SymNode>, _("List[SymNode]"));

  bool load(handle src, bool);
  static handle cast(
      at::ArrayRef<c10::SymNode> src,
      return_value_policy /* policy */,
      handle /* parent */);

 private:
  std::vector<c10::SymNode> v_value;
};

template <>
struct type_caster<at::MemoryFormat> {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  PYBIND11_TYPE_CASTER(at::MemoryFormat, _("torch.memory_format"));

  bool load(handle src, bool) {
    PyObject* obj = src.ptr();
    if (THPMemoryFormat_Check(obj)) {
      value = reinterpret_cast<THPMemoryFormat*>(obj)->memory_format;
      return true;
    }
    return false;
  }
  static handle cast(
      at::MemoryFormat src,
      return_value_policy /* policy */,
      handle /* parent */) {
    return handle(torch::utils::getTHPMemoryFormat(src));
  }
};

template <>
struct type_caster<at::Device> {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  PYBIND11_TYPE_CASTER(at::Device, _("torch.device"));

  // PYBIND11_TYPE_CASTER defines a member field called value. Since at::Device
  // cannot be default-initialized, we provide this constructor to explicitly
  // initialize that field. The value doesn't matter as it will be overwritten
  // after a successful call to load.
  type_caster() : value(c10::kCPU) {}

  bool load(handle src, bool) {
    PyObject* obj = src.ptr();
    if (THPDevice_Check(obj)) {
      value = reinterpret_cast<THPDevice*>(obj)->device;
      return true;
    }
    return false;
  }

  static handle cast(
      const at::Device& src,
      return_value_policy /* policy */,
      handle /* parent */) {
    return handle(THPDevice_New(src));
  }
};

template <>
struct type_caster<c10::DispatchKey>
    : public type_caster_base<c10::DispatchKey> {
  using base = type_caster_base<c10::DispatchKey>;
  c10::DispatchKey tmp;

 public:
  bool load(handle src, bool convert) {
    if (base::load(src, convert)) {
      return true;
    } else if (py::isinstance(
                   src, py::module_::import("builtins").attr("str"))) {
      tmp = c10::parseDispatchKey(py::cast<std::string>(src));
      value = &tmp;
      return true;
    }
    return false;
  }

  static handle cast(
      c10::DispatchKey src,
      return_value_policy policy,
      handle parent) {
    return base::cast(src, policy, parent);
  }
};
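
// The caster above also accepts Python strings, parsed with
// c10::parseDispatchKey (illustrative sketch of a hypothetical call site):
//
//   auto k = py::cast<c10::DispatchKey>(py::str("CPU"));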

template <>
struct TORCH_PYTHON_API type_caster<c10::Scalar> {
 public:
  PYBIND11_TYPE_CASTER(
      c10::Scalar,
      _("Union[Number, torch.SymInt, torch.SymFloat]"));
  bool load(py::handle src, bool);

  static py::handle cast(
      const c10::Scalar& si,
      return_value_policy /* policy */,
      handle /* parent */);
};

template <>
struct TORCH_PYTHON_API type_caster<c10::SymInt> {
 public:
  PYBIND11_TYPE_CASTER(c10::SymInt, _("Union[int, torch.SymInt]"));
  bool load(py::handle src, bool);

  static py::handle cast(
      c10::SymInt si,
      return_value_policy /* policy */,
      handle /* parent */);
};

template <>
struct TORCH_PYTHON_API type_caster<c10::SymFloat> {
 public:
  PYBIND11_TYPE_CASTER(c10::SymFloat, _("float"));
  bool load(py::handle src, bool);

  static py::handle cast(
      c10::SymFloat si,
      return_value_policy /* policy */,
      handle /* parent */);
};

template <typename T>
struct type_caster<c10::complex<T>> {
 public:
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  PYBIND11_TYPE_CASTER(c10::complex<T>, _("complex"));

  bool load(handle src, bool) {
    PyObject* obj = src.ptr();

    // Adapted from `THPUtils_unpackComplexDouble`
    Py_complex py_complex = PyComplex_AsCComplex(obj);
    if (py_complex.real == -1.0 && PyErr_Occurred()) {
      return false;
    }

    // Python's Complex is always double precision.
    value = c10::complex<double>(py_complex.real, py_complex.imag);
    return true;
  }

  static handle cast(
      const c10::complex<T>& complex,
      return_value_policy /* policy */,
      handle /* parent */) {
    // Python only knows double precision complex.
    return handle(PyComplex_FromDoubles(complex.real(), complex.imag()));
  }
};

// Pybind11 bindings for our optional and variant types.
// http://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers
template <typename T>
struct type_caster<c10::optional<T>> : optional_caster<c10::optional<T>> {};

template <typename... Ts>
struct C10_MPARK_VISIBILITY_HIDDEN type_caster<c10::variant<Ts...>>
    : variant_caster<c10::variant<Ts...>> {};
} // namespace detail
} // namespace pybind11

namespace torch {
namespace impl {

// Use this function if you have a C++ object that is used from both C++
// and Python contexts, and you need the GIL to be released when you
// destruct it in the Python context.
//
// This function is a valid shared_ptr destructor and can be used to
// conveniently allocate a shared_ptr to an object whose destructor will be
// run without the GIL. Pass it as the second argument to shared_ptr, e.g.,
//
//   shared_ptr<T>(new T(), destroy_without_gil<T>)
//
// Attaching the GIL release logic to the holder pointer rather than the
// actual destructor of T is helpful when T is Python-agnostic and
// shouldn't refer to the Python API.
//
// Note there are limitations to the correctness of code that makes use of
// this. In particular, if a shared_ptr is constructed from C++ code without
// this destructor and then passed to pybind11, pybind11 will happily take
// ownership of the shared_ptr (and be willing to destruct it from a context
// where it is holding the GIL). unique_ptr with a type-branded deleter is
// less prone to this problem, because a stock-deleter unique_ptr is not
// convertible with it. I plan to mitigate this problem by adding DEBUG-only
// asserts to the true C++ destructors that the GIL is not held (using a
// virtual call to get to the Python interpreter); alternately, we could use
// a virtual call to simply ensure we release the GIL in the C++ destructor;
// however, this is a layering violation (why is code that is ostensibly
// Python-agnostic calling into the GIL?).
//
// Adapted from
// https://github.com/pybind/pybind11/issues/1446#issuecomment-406341510
template <typename T>
inline void destroy_without_gil(T* ptr) {
  // Because the ownership of a shared_ptr is diffuse, it's not possible to
  // necessarily predict whether or not the last reference to an object will
  // be destructed from Python or C++. This means that in the destructor here,
  // we don't necessarily know if we actually have the GIL or not; in fact,
  // we don't even know if the Python interpreter still exists! Thus, we have
  // to test for it before releasing the GIL.
  //
  // PyGILState_Check is hopefully self-explanatory. But Py_IsInitialized or
  // _PyIsFinalizing? Both get set at the same time during the Python
  // destruction process:
  // https://github.com/python/cpython/blob/d92513390a1a0da781bb08c284136f4d7abea36d/Python/pylifecycle.c#L1716-L1717
  // so the operative question is whether or not you want to release the GIL
  // after finalization has completed (and there is just no Python
  // interpreter). Clearly there is no need to release the GIL in that state,
  // so we want Py_IsInitialized.
  if (Py_IsInitialized() && PyGILState_Check()) {
    pybind11::gil_scoped_release nogil;
    delete ptr;
  } else {
    delete ptr;
  }
}

} // namespace impl
} // namespace torch