pytorch/c10/core/UndefinedTensorImpl.cpp
Edward Z. Yang 2896f81dd4 Consolidate customization contiguous/sizes policy into unified policy
Prior to this PR, we had a mish-mash of ways of getting unconventional
sizes/strides behavior:

- In OSS (but not in fbcode), some methods are virtual and you
  can override them directly

- There is an is_contiguous policy, a bitfield tag that, when set,
  makes is_contiguous() either error or call a virtual method
  is_contiguous_custom().  Ordinarily is_contiguous() is virtual and
  you can just override it directly, but this works EVEN IF
  is_contiguous() is non-virtual (e.g., in fbcode)

- There is also a sizes policy, which is the same idea but for sizes

This PR unifies these mechanisms, and in doing so, eliminates the
maybe-virtual/non-virtual split of the methods in question.  The primary
downside of this change is that it is BC-breaking (but the BC break is
very easy to fix!)

The new scheme works like this: we have three levels of policy for
sizes/strides (order matters; see the sketch after this list).

- The Default policy is a conventional dense tensor, where we use
  all of the built-in fields to directly represent the
  sizes/strides/numel/contiguity of the tensor, and it is possible
  to bypass virtual calls entirely.

- The CustomStrides policy represents tensors which have a custom
  notion of strides (most typically, that they don't support them),
  shunting strides() and is_contiguous() to the virtual methods
  strides_custom() and is_contiguous_custom().  This INCLUDES handling
  for contiguity, since strides and contiguity typically go
  hand-in-hand (although the situation is murky with batched tensors).
  The default implementations of these functions raise errors saying
  the tensor doesn't support them.

- The CustomSizes policy represents tensors which have a custom
  notion of sizes (the two notable examples are nested tensor, which
  doesn't have a representation of sizes in the conventional form, and
  XLA/LTC tensor, which synchronizes its sizes with an underlying
  compiler backend).  This shunts sizes(), numel() and dim() (along
  with everything from CustomStrides) to _custom() variants.

There is no special policy for erroring; instead, we just do a vcall
and expect the virtual method to raise an exception (the performance
hit from the vcall doesn't matter because you're about to raise a C++
exception anyway).  The default implementations of all overridable
functions are available as _default() variants, which is helpful in
situations where you just want to do a "sync" and then run the
conventional semantics.
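
For instance, an XLA/LTC-style subclass could synchronize and then
defer to the stock behavior.  This is a hypothetical sketch:
MyLazyTensorImpl and sync_shape_from_backend() are made-up names,
though sizes_custom()/sizes_default() are the hooks this PR describes:

    // Hypothetical override: refresh shape metadata from the compiler
    // backend, then run the conventional semantics via the _default()
    // variant.
    c10::IntArrayRef MyLazyTensorImpl::sizes_custom() const {
      sync_shape_from_backend(); // assumed helper, not a real API
      return sizes_default();    // stock dense behavior after the sync
    }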

This PR could be extended further in two ways but I did not do them
due to time constraints:

- Ideally, all TENSORIMPL_MAYBE_VIRTUAL would be eliminated from
  TensorImpl, by using the same policy trick.

- set_size and set_stride are still virtual; it's not entirely clear
  that the same trick should be used here, though, as these methods
  are deprecated.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77036

Approved by: https://github.com/bdhirsh
2022-05-11 00:23:07 +00:00


#include <c10/core/UndefinedTensorImpl.h>
#include <c10/util/Exception.h>

namespace c10 {

// should this use the globalContext? Can it get a context passed in somehow?
UndefinedTensorImpl::UndefinedTensorImpl()
    : TensorImpl(DispatchKey::Undefined, caffe2::TypeMeta(), c10::nullopt) {
  set_storage_access_should_throw();
  // TODO: accessing the sizes on an undefined tensor is not meaningful
  // and should error too, but empirically it does not!
  set_sizes_strides_policy(SizesStridesPolicy::CustomStrides);
}

bool UndefinedTensorImpl::is_contiguous_custom(MemoryFormat format) const {
  return is_contiguous_default(format);
}

#ifdef DEBUG
bool UndefinedTensorImpl::has_storage() const {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      !storage_, "UndefinedTensorImpl assumes that storage_ is never set");
  return false;
}
#endif

void UndefinedTensorImpl::set_storage_offset(int64_t) {
  TORCH_CHECK(false, "set_storage_offset() called on an undefined Tensor");
}

const char* UndefinedTensorImpl::tensorimpl_type_name() const {
  return "UndefinedTensorImpl";
}

UndefinedTensorImpl UndefinedTensorImpl::_singleton;

} // namespace c10
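
For context (not part of the diff): with CustomStrides set and only
is_contiguous_custom() overridden, the intended caller-visible behavior
is roughly the following, assuming (as the commit message says) that
the un-overridden _custom() defaults raise:

    #include <ATen/ATen.h>

    void example() {
      at::Tensor t; // default-constructed: backed by the undefined singleton
      // is_contiguous() is shunted to is_contiguous_custom(), which this
      // file overrides to return the stock default answer:
      bool ok = t.unsafeGetTensorImpl()->is_contiguous();
      (void)ok;
      // strides() is shunted to strides_custom(); UndefinedTensorImpl does
      // not override it, so the default implementation raises an error:
      // t.unsafeGetTensorImpl()->strides(); // would throw c10::Error
    }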