pytorch/c10/core/impl/PyInterpreter.cpp
Edward Z. Yang 74877943b8 Don't invoke mode as overloaded argument in torch dispatch (#80992)
I noticed that in some situations torch dispatch modes were being
invoked with a mode active, which isn't supposed to happen (we
disable modes before calling into the user mode). I also noticed that
I was getting a warning that I had a deprecated non-static definition of
torch dispatch on an argument even though there wasn't any.

It turns out this is because modes were part of the overloaded arguments
list in the Python fallback kernel for torch dispatch.  This is wrong;
instead we should rely on the actual dispatching function to consult
modes.  This makes the code simpler.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80992
Approved by: https://github.com/zou3519
2022-07-06 23:45:59 +00:00

95 lines
2.9 KiB
C++

#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/impl/PyInterpreter.h>
namespace c10 {
namespace impl {

// Fails an internal assertion explaining that a hook was reached on a Tensor
// whose PyObject belongs to an interpreter that has already died.
//
// `what` must be a string literal. It is spliced by adjacent string-literal
// concatenation between "attempted to " and the fixed " Tensor with ..."
// suffix, so every message below is exactly one literal string. This is a
// macro rather than a helper function so that TORCH_INTERNAL_ASSERT is still
// expanded directly inside each hook. Whatever call-site details the assertion
// records therefore point at that hook. #undef'd at the end of this namespace.
#define C10_IMPL_DEAD_INTERPRETER_ASSERT(what) \
  TORCH_INTERNAL_ASSERT(                       \
      0,                                       \
      "attempted to " what                     \
      " Tensor with nontrivial PyObject after corresponding interpreter died")

// --- Hooks that stay usable after disarm() -------------------------------

// Reports a fixed placeholder name instead of asserting.
static std::string noop_name_fn(const PyInterpreter* /*self*/) {
  return "<unloaded interpreter>";
}

// Deliberately does nothing (rather than asserting): the request is dropped.
static void noop_decref_fn(const PyInterpreter*, PyObject*, bool) {}

// --- Hooks that assert if reached after disarm() -------------------------

static c10::intrusive_ptr<TensorImpl> noop_detach_fn(
    const PyInterpreter*,
    const TensorImpl*) {
  C10_IMPL_DEAD_INTERPRETER_ASSERT("detach (shallow_copy_and_detach)");
}

static void noop_dispatch_fn(
    const PyInterpreter*,
    const c10::OperatorHandle& /*op*/,
    torch::jit::Stack* /*stack*/) {
  C10_IMPL_DEAD_INTERPRETER_ASSERT(
      "dispatch (__torch_dispatch__) an operator on");
}

// Tensor metadata queries; each one names the accessor that was attempted.

static bool noop_is_contiguous_fn(const PyInterpreter*, const TensorImpl*) {
  C10_IMPL_DEAD_INTERPRETER_ASSERT("call `is_contiguous` on");
}

static c10::Device noop_device_fn(const PyInterpreter*, const TensorImpl*) {
  C10_IMPL_DEAD_INTERPRETER_ASSERT("call `device` on");
}

static int64_t noop_dim_fn(const PyInterpreter*, const TensorImpl*) {
  C10_IMPL_DEAD_INTERPRETER_ASSERT("call `dim` on");
}

static c10::IntArrayRef noop_strides_fn(
    const PyInterpreter*,
    const TensorImpl*) {
  C10_IMPL_DEAD_INTERPRETER_ASSERT("call `strides` on");
}

static c10::IntArrayRef noop_sizes_fn(const PyInterpreter*, const TensorImpl*) {
  C10_IMPL_DEAD_INTERPRETER_ASSERT("call `sizes` on");
}

static c10::SymIntArrayRef noop_sym_sizes_fn(
    const PyInterpreter*,
    const TensorImpl*) {
  C10_IMPL_DEAD_INTERPRETER_ASSERT("call `sym_sizes` on");
}

static c10::Layout noop_layout_fn(const PyInterpreter*, const TensorImpl*) {
  C10_IMPL_DEAD_INTERPRETER_ASSERT("call `layout` on");
}

#undef C10_IMPL_DEAD_INTERPRETER_ASSERT

// Points every hook of this PyInterpreter at its noop_* stand-in above.
// Afterwards, only name_fn_ (placeholder name) and decref_fn_ (no-op) do
// anything useful; every other hook fails an internal assertion.
// NOTE(review): presumably invoked when the owning interpreter shuts down,
// as the assertion messages suggest; confirm against callers.
void PyInterpreter::disarm() noexcept {
  // Still callable.
  name_fn_ = &noop_name_fn;
  decref_fn_ = &noop_decref_fn;

  // Assert on use.
  detach_fn_ = &noop_detach_fn;
  dispatch_fn_ = &noop_dispatch_fn;
  is_contiguous_fn_ = &noop_is_contiguous_fn;
  device_fn_ = &noop_device_fn;
  dim_fn_ = &noop_dim_fn;
  strides_fn_ = &noop_strides_fn;
  sizes_fn_ = &noop_sizes_fn;
  sym_sizes_fn_ = &noop_sym_sizes_fn;
  layout_fn_ = &noop_layout_fn;
}

} // namespace impl
} // namespace c10