pytorch/c10/core/UndefinedTensorImpl.h
Edward Z. Yang 2896f81dd4 Consolidate customization contiguous/sizes policy into unified policy
Prior to this PR, we had a mish-mash of ways of getting unconventional
sizes/strides behavior:

- In OSS (but not in fbcode), some methods are virtual and you
  can override them directly

- There is an is_contiguous policy, a bitfield tag that lets you
  toggle is_contiguous() to either raise an error or call the virtual
  method is_contiguous_custom() when it is set.  Ordinarily
  is_contiguous() is virtual and you can just override it, but this
  works EVEN IF is_contiguous() is non-virtual (e.g., in fbcode)

- There is also a sizes policy which is the same idea but for sizes

This PR unifies these mechanisms, and in doing so, eliminates the
maybe-virtual/not-virtual status of the methods in question.  The primary
downside of this change is that it is BC-breaking (but the BC break is
very easy to fix!)

The new scheme works like this: we have three levels of policy for
sizes/strides, and order matters: each level also customizes
everything the level before it does.

- The Default policy is a conventional dense tensor, where we use
  all of the built-in fields to directly represent the
  sizes/strides/numel/contiguity of the tensor, and the virtual call
  can be bypassed entirely.

- The CustomStrides policy represents tensors which have a custom
  notion of strides (most typically, that they don't support them),
  shunting strides() and is_contiguous() to virtual methods
  strides_custom() and is_contiguous_custom().  This INCLUDES handling
  for contiguity, since they typically go hand-in-hand (although
  the situation is murky with batched tensors).  The default
  implementations of these functions raise errors saying the tensor
  doesn't support them.

- The CustomSizes policy represents tensors which have a custom
  notion of sizes (the two notable examples are nested tensor, which
  doesn't have a representation of sizes in the conventional form, and
  XLA/LTC tensor, which synchronizes its sizes with an underlying
  compiler backend).  This shunts sizes(), numel() and dim() (along
  with everything from strides) to _custom() variants.
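To make the ordering concrete, here is a minimal C++ sketch of how such a policy could gate the non-virtual accessors. SizesStridesPolicy, MiniTensorImpl, NoStridesImpl, matches_policy() and set_policy() are hypothetical names chosen for illustration, not c10's actual declarations; only is_contiguous() and dim() are modeled, and the stored fields are reduced to a sizes vector and a contiguity flag.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical, simplified model of the three ordered policy levels.
enum class SizesStridesPolicy : std::uint8_t {
  Default = 0,        // built-in fields answer everything, no virtual call
  CustomStrides = 1,  // strides()/is_contiguous() go to *_custom()
  CustomSizes = 2,    // additionally sizes()/numel()/dim() go to *_custom()
};

class MiniTensorImpl {
 public:
  MiniTensorImpl() = default;
  virtual ~MiniTensorImpl() = default;

  // Non-virtual entry points: under Default they just read fields.
  bool is_contiguous() const {
    if (matches_policy(SizesStridesPolicy::CustomStrides)) {
      return is_contiguous_custom();  // one virtual call
    }
    return is_contiguous_;
  }
  std::int64_t dim() const {
    if (matches_policy(SizesStridesPolicy::CustomSizes)) {
      return dim_custom();
    }
    return static_cast<std::int64_t>(sizes_.size());
  }

 protected:
  // "Order matters": a higher level also customizes everything below it.
  bool matches_policy(SizesStridesPolicy p) const {
    return static_cast<std::uint8_t>(policy_) >= static_cast<std::uint8_t>(p);
  }
  void set_policy(SizesStridesPolicy p) { policy_ = p; }

  // Default overrides raise, i.e. "this tensor doesn't support it".
  virtual bool is_contiguous_custom() const {
    throw std::runtime_error("Tensors of this type do not have is_contiguous");
  }
  virtual std::int64_t dim_custom() const {
    throw std::runtime_error("Tensors of this type do not have dim");
  }

  std::vector<std::int64_t> sizes_{2, 3};
  bool is_contiguous_ = true;

 private:
  SizesStridesPolicy policy_ = SizesStridesPolicy::Default;
};

// Hypothetical tensor type with no stride representation.
class NoStridesImpl : public MiniTensorImpl {
 public:
  NoStridesImpl() { set_policy(SizesStridesPolicy::CustomStrides); }
};
```

Because the check is a `>=` comparison, a CustomSizes tensor automatically takes the custom path for strides and contiguity as well, while NoStridesImpl still answers dim() from the built-in fields.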

There is no special policy for erroring; instead, we just do a vcall
and expect the virtual method to raise an exception (the performance
hit from the vcall doesn't matter because you're about to raise a C++
exception anyway).  The default implementations of all overridable
functions are also available as _default() variants, which helps when
you just want to do a "sync" and then run the conventional
semantics.
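The "vcall and raise" approach and the _default() escape hatch can be sketched as follows. This is an illustrative model rather than c10 code: MiniTensorImpl, SyncedImpl, NoSizesImpl, backend_reshape() and the custom_sizes_ flag are hypothetical, and the "backend" is just a second vector standing in for an XLA/LTC-style compiler.

```cpp
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical, trimmed-down base: only sizes() is modeled.
class MiniTensorImpl {
 public:
  virtual ~MiniTensorImpl() = default;

  const std::vector<std::int64_t>& sizes() const {
    if (custom_sizes_) {
      return sizes_custom();  // one vcall; may throw
    }
    return sizes_;
  }

 protected:
  // Conventional semantics, reusable by subclasses after a "sync".
  const std::vector<std::int64_t>& sizes_default() const { return sizes_; }

  // No dedicated "error" policy: the default override simply throws.
  virtual const std::vector<std::int64_t>& sizes_custom() const {
    throw std::runtime_error("Tensors of this type do not have sizes");
  }

  bool custom_sizes_ = false;
  mutable std::vector<std::int64_t> sizes_;  // cache refreshed by syncs
};

// Hypothetical backend-synchronized tensor.
class SyncedImpl : public MiniTensorImpl {
 public:
  explicit SyncedImpl(std::vector<std::int64_t> backend_sizes)
      : backend_sizes_(std::move(backend_sizes)) {
    custom_sizes_ = true;
  }
  void backend_reshape(std::vector<std::int64_t> s) {
    backend_sizes_ = std::move(s);
  }

 protected:
  const std::vector<std::int64_t>& sizes_custom() const override {
    sizes_ = backend_sizes_;  // "sync" the cached field from the backend...
    return sizes_default();   // ...then run the conventional path
  }

 private:
  std::vector<std::int64_t> backend_sizes_;
};

// Hypothetical tensor that opts in but never overrides: sizes() throws.
class NoSizesImpl : public MiniTensorImpl {
 public:
  NoSizesImpl() { custom_sizes_ = true; }
};
```

The throwing case costs one virtual call on a path that is about to throw anyway, so no separate "error" policy bit is needed.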

This PR could be extended further in two ways but I did not do them
due to time constraints:

- Ideally, all uses of TENSORIMPL_MAYBE_VIRTUAL would be eliminated
  from TensorImpl by using the same policy trick.

- set_size and set_stride are still virtual; it's not entirely clear
  that the same trick should be used here, though, as these methods
  are deprecated.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77036

Approved by: https://github.com/bdhirsh
2022-05-11 00:23:07 +00:00


#pragma once

#include <c10/core/TensorImpl.h>

namespace c10 {

struct C10_API UndefinedTensorImpl final : public TensorImpl {
 public:
  // Without this, we get:
  // error: identifier "at::UndefinedTensorImpl::_singleton" is undefined in
  // device code
  // (ostensibly because the constexpr tricks MSVC into trying to compile this
  // function for device as well).
#ifdef _WIN32
  static inline TensorImpl* singleton() {
#else
  static constexpr inline TensorImpl* singleton() {
#endif
    return &_singleton;
  }
#ifdef DEBUG
  bool has_storage() const override;
#endif
  void set_storage_offset(int64_t offset) override;

 protected:
  bool is_contiguous_custom(MemoryFormat format) const override;

 private:
  UndefinedTensorImpl();
  static UndefinedTensorImpl _singleton;
  const char* tensorimpl_type_name() const override;
};

} // namespace c10
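UndefinedTensorImpl acts as a sentinel object: instead of holding nullptr, an undefined handle can point at one shared instance whose overrides (such as set_storage_offset()) refuse to work. The sketch below models that pattern with hypothetical names (ImplBase, UndefinedImpl, Handle); it omits the _WIN32 branch and the DEBUG-only has_storage() override, and the exception type and message are assumptions.

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical stand-in for TensorImpl.
struct ImplBase {
  virtual ~ImplBase() = default;
  virtual std::int64_t storage_offset() const { return offset_; }
  virtual void set_storage_offset(std::int64_t offset) { offset_ = offset; }
  std::int64_t offset_ = 0;
};

// Hypothetical stand-in for UndefinedTensorImpl: one shared sentinel.
struct UndefinedImpl final : ImplBase {
  // The address of an object with static storage duration is a constant
  // expression, so this accessor can be constexpr.
  static constexpr ImplBase* singleton() { return &singleton_; }

  void set_storage_offset(std::int64_t) override {
    throw std::logic_error("set_storage_offset() called on an undefined tensor");
  }

 private:
  UndefinedImpl() = default;  // only the singleton can be constructed
  static UndefinedImpl singleton_;
};

// The out-of-class definition may call the private constructor.
UndefinedImpl UndefinedImpl::singleton_;

// Hypothetical handle whose "null" state is the sentinel, never nullptr.
class Handle {
 public:
  Handle() : impl_(UndefinedImpl::singleton()) {}
  explicit Handle(ImplBase* impl) : impl_(impl) {}

  bool defined() const { return impl_ != UndefinedImpl::singleton(); }
  ImplBase* get() const { return impl_; }  // always safe to dereference

 private:
  ImplBase* impl_;
};
```

Since get() never returns nullptr, calling through an undefined handle reaches an override that reports a clear error instead of crashing, and defined() reduces to a single pointer comparison.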