mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Prior to this PR, we had a mish-mash of ways of getting unconventional sizes/strides behavior:

- In OSS (but not in fbcode), some methods are virtual and you can override them directly.
- There is an is_contiguous policy, a bitfield tag that, when set, makes is_contiguous() either error or hit a virtual method is_contiguous_custom(). Ordinarily is_contiguous() is virtual and you can just override it, but this works EVEN IF is_contiguous() is non-virtual (e.g., in fbcode).
- There is also a sizes policy, which is the same idea but for sizes.

This PR unifies these mechanisms, and in doing so eliminates the maybe-virtual/non-virtualness of the methods in question. The primary downside of this change is that it is BC-breaking (but the BC break is very easy to fix!).

The new scheme works like this: we have three levels of policy for sizes/strides (order matters).

- The Default policy is a conventional dense tensor, where we use all of the built-in fields to directly represent the sizes/strides/numel/contiguity of the tensor, and it is possible to bypass virtual calls entirely.
- The CustomStrides policy represents tensors which have a custom notion of strides (most typically, that they don't support them), shunting strides() and is_contiguous() to the virtual methods strides_custom() and is_contiguous_custom(). This INCLUDES handling for contiguity, since strides and contiguity typically go hand-in-hand (although the situation is murky with batched tensors). The default implementations of these functions raise errors saying the tensor doesn't support them.
- The CustomSizes policy represents tensors which have a custom notion of sizes (the two notable examples are nested tensor, which doesn't represent sizes in the conventional form, and XLA/LTC tensor, which synchronizes its sizes with an underlying compiler backend). This shunts sizes(), numel(), and dim() (along with everything from strides) to their _custom() variants.
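A minimal sketch of the three-level dispatch described above, using hypothetical names (SizesStridesPolicy, TensorImplSketch, and NoSizesTensor are illustrative stand-ins, not the actual TensorImpl API):

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical policy levels; order matters, so a plain >= comparison
// suffices to decide whether the virtual path must be taken.
enum class SizesStridesPolicy : uint8_t {
  Default = 0,        // dense tensor: built-in fields used directly
  CustomStrides = 1,  // strides()/is_contiguous() go through _custom() vcalls
  CustomSizes = 2,    // sizes()/numel()/dim() (and strides) go through vcalls
};

class TensorImplSketch {
 public:
  virtual ~TensorImplSketch() = default;

  // Non-virtual entry point: only takes the virtual path when the policy
  // demands it, so conventional dense tensors bypass the vcall entirely.
  const std::vector<int64_t>& sizes() const {
    if (policy_ >= SizesStridesPolicy::CustomSizes) {
      return sizes_custom();  // vcall, only for exotic tensors
    }
    return sizes_default();   // fast path, inlineable
  }

  // The conventional semantics, always available by its _default() name.
  const std::vector<int64_t>& sizes_default() const { return sizes_; }

 protected:
  // The default _custom() implementation errors; subclasses such as a
  // nested tensor would override it with their own representation.
  virtual const std::vector<int64_t>& sizes_custom() const {
    throw std::runtime_error("this tensor does not support sizes()");
  }

  SizesStridesPolicy policy_ = SizesStridesPolicy::Default;
  std::vector<int64_t> sizes_;
};

// A tensor type with no conventional sizes opts into CustomSizes and
// simply inherits the erroring default of sizes_custom().
class NoSizesTensor : public TensorImplSketch {
 public:
  NoSizesTensor() { policy_ = SizesStridesPolicy::CustomSizes; }
};
```

Since the policy levels are ordered, CustomSizes implies the CustomStrides behavior as well, which is why a single comparison per accessor is enough.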
There is no special policy for erroring; instead, we just do a vcall and expect the virtual method to raise an exception (the performance hit from the vcall doesn't matter, because you're about to raise a C++ exception anyway). The default implementations of all overridable functions are available as _default() variants, which is helpful when you just want to do a "sync" and then run the conventional semantics.

This PR could be extended further in two ways, but I did not do so due to time constraints:

- Ideally, all TENSORIMPL_MAYBE_VIRTUAL would be eliminated from TensorImpl by using the same policy trick.
- set_size and set_stride are still virtual; it's not entirely clear the same trick should be used there, though, as these methods are deprecated.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77036
Approved by: https://github.com/bdhirsh
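The "sync, then run the conventional semantics" idiom enabled by the _default() variants can be sketched like this (again with hypothetical names; this is not the real TensorImpl interface):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical XLA/LTC-style tensor: its authoritative sizes live in a
// compiler backend, so the _custom() override first refreshes the built-in
// fields and then reuses the conventional (_default) implementation.
class CompilerBackedTensor {
 public:
  virtual ~CompilerBackedTensor() = default;

  const std::vector<int64_t>& sizes() const { return sizes_custom(); }

 protected:
  virtual const std::vector<int64_t>& sizes_custom() const {
    sync_from_backend();     // "sync" step
    return sizes_default();  // then the conventional semantics
  }

  const std::vector<int64_t>& sizes_default() const { return sizes_; }

  virtual void sync_from_backend() const {
    // Stand-in for querying the compiler backend for current shapes.
    sizes_ = {2, 3};
  }

  // mutable so the const accessor can refresh the cached fields on sync.
  mutable std::vector<int64_t> sizes_;
};
```

The point of exposing _default() separately is that an override never has to re-implement the dense-tensor logic; it only prepends whatever bookkeeping its representation requires.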
36 lines
863 B
C++
#pragma once

#include <c10/core/TensorImpl.h>

namespace c10 {

struct C10_API UndefinedTensorImpl final : public TensorImpl {
 public:
  // Without this, we get:
  //   error: identifier "at::UndefinedTensorImpl::_singleton" is undefined in
  //   device code
  // (ostensibly because the constexpr tricks MSVC into trying to compile this
  // function for device as well).
#ifdef _WIN32
  static inline TensorImpl* singleton() {
#else
  static constexpr inline TensorImpl* singleton() {
#endif
    return &_singleton;
  }
#ifdef DEBUG
  bool has_storage() const override;
#endif
  void set_storage_offset(int64_t offset) override;

 protected:
  bool is_contiguous_custom(MemoryFormat format) const override;

 private:
  UndefinedTensorImpl();
  static UndefinedTensorImpl _singleton;
  const char* tensorimpl_type_name() const override;
};

} // namespace c10