pytorch/c10/core/impl/PyInterpreter.cpp
Edward Z. Yang 3b6588ab74 Consistent compute numel/contiguous strategy with SymInts (#85858)
Previously, our handling for contiguity was inconsistent in the following ways:

- is_strides_like 2d/3d and is_non_overlapping_and_dense were always computed
  based on sizes_and_strides_, even if you had symbolic ints
- Furthermore, even if you set a custom policy for strides, these quantities
  were not overridable by subclasses
- Moreover, we didn't even store these fields on ExtraMeta
- We duplicated implementations of compute_contiguous (plain, channels last,
  channels last 3d)
- We inconsistently called refresh_numel()/refresh_contiguous(), versus
  recomputing them ourselves

This refactor establishes a consistent strategy for all of the boolean fields
and for numel computation.  After this refactor:

- All layout boolean fields are interposable via strides policy
  and can be overridden from Python; you will never access a garbage field
- All layout boolean fields are on ExtraMeta
- You can always call refresh_numel/contiguous, no matter if your Tensor is
  contiguous or not
- The numel/layout boolean fields are always populated consistently with
  the sizes/strides fields (either on Tensor or ExtraMeta), even if you
  have a custom policy
- There is only one implementation of the actual computation logic (see the
  sketch below)
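
To illustrate the last point, here is a minimal sketch of the idea (not the
actual shared helper in c10; compute_contiguous_sketch is an invented name)
showing how a single templated routine can serve both concrete int64_t shapes
and a SymInt-like type, with the zero-size edge case handled in a simplified
way:

    // Hypothetical sketch: one implementation of the row-major contiguity
    // check, parameterized over the integer type, so plain and symbolic
    // shapes share the same logic instead of duplicating it.
    #include <cstdint>
    #include <vector>

    template <typename T>
    bool compute_contiguous_sketch(
        const std::vector<T>& sizes,
        const std::vector<T>& strides) {
      // A tensor with any zero-sized dimension has no elements and is
      // treated as trivially contiguous here.
      for (const auto& s : sizes) {
        if (s == 0) {
          return true;
        }
      }
      T expected_stride = 1;
      // Walk from the innermost dimension outward: a dense row-major layout
      // must have strides[d] equal to the product of all inner sizes.
      for (int64_t d = static_cast<int64_t>(sizes.size()) - 1; d >= 0; d--) {
        if (sizes[d] == 1) {
          continue; // size-1 dimensions impose no stride constraint
        }
        if (strides[d] != expected_stride) {
          return false;
        }
        expected_stride = expected_stride * sizes[d];
      }
      return true;
    }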

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: [D39907696](https://our.internmc.facebook.com/intern/diff/D39907696)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85858
Approved by: https://github.com/albanD
2022-09-30 21:26:34 +00:00

#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/impl/PyInterpreter.h>
namespace c10 {
namespace impl {
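
// No-op vtable installed by PyInterpreter::disarm() below once the
// corresponding Python interpreter has died: calls that can be safely
// ignored (name, decref, GPU tracing) become harmless no-ops, while anything
// that would need a live interpreter asserts via PANIC.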
struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
  std::string name() const override {
    return "<unloaded interpreter>";
  }

  void decref(PyObject* pyobj, bool is_tensor) const override {} // do nothing

#define PANIC(m)              \
  TORCH_INTERNAL_ASSERT(      \
      0,                      \
      "attempted to call " #m \
      " on a Tensor with nontrivial PyObject after corresponding interpreter died")

  c10::intrusive_ptr<TensorImpl> detach(const TensorImpl* self) const override {
    PANIC(detach);
  }

  void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
      const override {
    PANIC(dispatch);
  }

  void python_dispatcher(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet,
      torch::jit::Stack* stack) const override {
    PANIC(python_dispatcher);
  }

  bool is_contiguous(const TensorImpl* self, at::MemoryFormat) const override {
    PANIC(is_contiguous);
  }
  bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
      const override {
    PANIC(is_strides_like);
  }
  bool is_non_overlapping_and_dense(const TensorImpl* self) const override {
    PANIC(is_non_overlapping_and_dense);
  }
  c10::Device device(const TensorImpl* self) const override {
    PANIC(device);
  }
  int64_t dim(const TensorImpl* self) const override {
    PANIC(dim);
  }
  c10::IntArrayRef strides(const TensorImpl* self) const override {
    PANIC(strides);
  }
  c10::IntArrayRef sizes(const TensorImpl* self) const override {
    PANIC(sizes);
  }
  c10::SymIntArrayRef sym_sizes(const TensorImpl* self) const override {
    PANIC(sym_sizes);
  }
  c10::Layout layout(const TensorImpl* self) const override {
    PANIC(layout);
  }
  c10::SymInt sym_numel(const TensorImpl* self) const override {
    PANIC(sym_numel);
  }
  c10::SymIntArrayRef sym_strides(const TensorImpl* self) const override {
    PANIC(sym_strides);
  }
  c10::SymInt sym_storage_offset(const TensorImpl* self) const override {
    PANIC(sym_storage_offset);
  }

  // Just swallow the event, don't do anything
  void trace_gpu_event_creation(uintptr_t event) const override {}
  void trace_gpu_event_deletion(uintptr_t event) const override {}
  void trace_gpu_event_record(uintptr_t event, uintptr_t stream)
      const override {}
  void trace_gpu_event_wait(uintptr_t event, uintptr_t stream) const override {}
  void trace_gpu_memory_allocation(uintptr_t ptr) const override {}
  void trace_gpu_memory_deallocation(uintptr_t ptr) const override {}
  void trace_gpu_stream_creation(uintptr_t stream) const override {}
  void trace_gpu_device_synchronization() const override {}
  void trace_gpu_stream_synchronization(uintptr_t stream) const override {}
  void trace_gpu_event_synchronization(uintptr_t event) const override {}
};
void PyInterpreter::disarm() noexcept {
  // Intentionally leaked
  static PyInterpreterVTable* noop_vtable = new NoopPyInterpreterVTable();
  vtable_ = noop_vtable;
}
} // namespace impl
} // namespace c10
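
For context, here is a minimal, self-contained sketch of the vtable-swap
pattern that PyInterpreter::disarm() relies on. It is not PyTorch code: every
type and function name below is invented for illustration, and plain assert
stands in for TORCH_INTERNAL_ASSERT.

    #include <cassert>
    #include <iostream>
    #include <string>

    struct VTable {
      virtual ~VTable() = default;
      virtual std::string name() const = 0;
      virtual int dim() const = 0;
    };

    struct RealVTable final : VTable {
      std::string name() const override { return "live interpreter"; }
      int dim() const override { return 4; }
    };

    struct NoopVTable final : VTable {
      std::string name() const override { return "<unloaded interpreter>"; }
      int dim() const override {
        // Mirrors PANIC: fail loudly instead of touching a dead backend.
        assert(false && "attempted to call dim after the interpreter died");
        return 0;
      }
    };

    class Interpreter {
     public:
      explicit Interpreter(VTable* vtable) : vtable_(vtable) {}
      const VTable* operator->() const { return vtable_; }
      void disarm() noexcept {
        // Intentionally leaked (as in disarm() above) so the pointer stays
        // valid no matter how late in shutdown it is used.
        static VTable* noop = new NoopVTable();
        vtable_ = noop;
      }
     private:
      VTable* vtable_;
    };

    int main() {
      RealVTable real;
      Interpreter interp(&real);
      std::cout << interp->name() << " dim=" << interp->dim() << "\n";
      interp.disarm();
      std::cout << interp->name() << "\n"; // still safe to call after disarm
      // interp->dim();  // would assert: the backing interpreter is gone
    }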