mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
A longstanding confusion in the implementation of fake tensor and proxy tensor is what to do about torch.ops.aten.sym_sizes and related calls. In particular, when you have a tensor that (1) has symbolic shapes and (2) has a `__torch_dispatch__` call, previously, you would always get `__torch_dispatch__` calls for sizes/strides queries, *even if you didn't request it* via the dispatch kwargs in `make_wrapper_subclass`. The reason for this is that we were previously mixing several concepts: "I want to dispatch to Python", "I want to call a virtual method" and "I have dynamic shapes". A single boolean variable controlled all of these things, and so it was not possible to understand inside TensorImpl what the user had actually originally requested. In this PR, we track each of these concepts individually so that we can preserve user intent. Then, we combine these into a single "policy" variable that controls whether or not we can use the fastpath. For the policy to trigger, we only need one of the exceptional cases to be true. Billing of changes: * Rename `set_sizes_strides_policy` to `set_custom_sizes_strides`; in general, you cannot DIRECTLY set policy; you have to set it indirectly via the public functions. * Some helpers for sizes and strides, since it's more complicated (as it is an enum, rather than just bools as is the case for device and layout). `matches_python_custom` is used to test the Python dispatch user ask. `matches_policy` does the policy test (only used in the user-facing functions). * I reorganized the accessor methods so that they are more logical. This makes the diff bad, so I recommend reading the final code directly. 
* The default custom implementations now more reliably call their default() implementations * As bonus refactor, I devirtualized some functions that don't need to be virtual * `set_sym_sizes_and_strides` is renamed to `set_sizes_and_strides` to make it easier to use in template contexts; it optionally takes a storage offset now so you can set all three values at the same time. If you use the SymInt overload but there are no symbolic integers, we give you a normal resize. * This adds `sym_storage_offset` since we had that in the symbolic shapes branch and there's no reason not to put it in (and it reduces merge conflicts) Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/84641 Approved by: https://github.com/wconstab
79 lines
2.4 KiB
C++
#include <c10/core/SymIntArrayRef.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/impl/PyInterpreter.h>

namespace c10 {
namespace impl {

struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
|
|
std::string name() const override {
|
|
return "<unloaded interpreter>";
|
|
}
|
|
|
|
void decref(PyObject* pyobj, bool is_tensor) const override {} // do nothing
|
|
|
|
#define PANIC(m) \
|
|
TORCH_INTERNAL_ASSERT( \
|
|
0, \
|
|
"attempted to call " #m \
|
|
" on a Tensor with nontrivial PyObject after corresponding interpreter died")
|
|
|
|
c10::intrusive_ptr<TensorImpl> detach(const TensorImpl* self) const override {
|
|
PANIC(detach);
|
|
}
|
|
|
|
void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
|
|
const override {
|
|
PANIC(dispatch);
|
|
}
|
|
|
|
bool is_contiguous(const TensorImpl* self) const override {
|
|
PANIC(is_contiguous);
|
|
}
|
|
c10::Device device(const TensorImpl* self) const override {
|
|
PANIC(device);
|
|
}
|
|
int64_t dim(const TensorImpl* self) const override {
|
|
PANIC(dim);
|
|
}
|
|
c10::IntArrayRef strides(const TensorImpl* self) const override {
|
|
PANIC(strides);
|
|
}
|
|
c10::IntArrayRef sizes(const TensorImpl* self) const override {
|
|
PANIC(sizes);
|
|
}
|
|
c10::SymIntArrayRef sym_sizes(const TensorImpl* self) const override {
|
|
PANIC(sym_sizes);
|
|
}
|
|
c10::Layout layout(const TensorImpl* self) const override {
|
|
PANIC(layout);
|
|
}
|
|
c10::SymInt sym_numel(const TensorImpl* self) const override {
|
|
PANIC(sym_numel);
|
|
}
|
|
c10::SymIntArrayRef sym_strides(const TensorImpl* self) const override {
|
|
PANIC(sym_strides);
|
|
}
|
|
c10::SymInt sym_storage_offset(const TensorImpl* self) const override {
|
|
PANIC(sym_storage_offset);
|
|
}
|
|
|
|
// Just swallow the event, don't do anything
|
|
void trace_gpu_event_creation(uintptr_t event) const override {}
|
|
void trace_gpu_event_deletion(uintptr_t event) const override {}
|
|
void trace_gpu_event_record(uintptr_t event, uintptr_t stream)
|
|
const override {}
|
|
void trace_gpu_event_wait(uintptr_t event, uintptr_t stream) const override {}
|
|
void trace_gpu_memory_allocation(uintptr_t ptr) const override {}
|
|
void trace_gpu_memory_deallocation(uintptr_t ptr) const override {}
|
|
void trace_gpu_stream_creation(uintptr_t stream) const override {}
|
|
};
void PyInterpreter::disarm() noexcept {
|
|
static NoopPyInterpreterVTable noop_vtable;
|
|
vtable_ = &noop_vtable;
|
|
}
} // namespace impl
} // namespace c10