Prior to this PR, we had a mish-mash of ways of getting unconventional sizes/strides behavior:

- In OSS (but not in fbcode), some methods are virtual and you can override them directly.
- There is an is_contiguous policy, a bitfield tag that lets you toggle is_contiguous() to error or hit a virtual method is_contiguous_custom() when it is set. Ordinarily is_contiguous() is virtual and you can just override it, but this works EVEN IF is_contiguous() is non-virtual (e.g., in fbcode).
- There is also a sizes policy, which is the same idea but for sizes.

This PR unifies these mechanisms, and in doing so eliminates the maybe-virtual/non-virtual ambiguity of the methods in question. The primary downside of this change is that it is BC-breaking (but the BC break is very easy to fix!).

The new scheme works like this: we have three levels of policy for sizes/strides (order matters). A sketch of how a subclass opts in follows this list.

- The Default policy is a conventional dense tensor, where we use all of the built-in fields to directly represent the sizes/strides/numel/contiguity of the tensor, and it is possible to bypass the virtual call entirely.
- The CustomStrides policy represents tensors which have a custom notion of strides (most typically, that they don't support them), shunting strides() and is_contiguous() to the virtual methods strides_custom() and is_contiguous_custom(). This INCLUDES handling for contiguity, since the two typically go hand-in-hand (although the situation is murky with batched tensors). The default implementations of these functions raise errors saying the tensor doesn't support them.
- The CustomSizes policy represents tensors which have a custom notion of sizes (the two notable examples are nested tensor, which doesn't represent sizes in the conventional form, and XLA/LTC tensor, which synchronizes its sizes with an underlying compiler backend). This shunts sizes(), numel() and dim() (along with everything from strides) to _custom() variants.

There is no special policy for erroring; instead, we just do a vcall and expect the virtual method to raise an exception (the performance hit from the vcall doesn't matter because you're about to raise a C++ exception anyway). The default implementations of all overridable functions are available as _default() variants, which is helpful in situations where you just want to do a "sync" and then run the conventional semantics.

This PR could be extended further in two ways, but I did not do them due to time constraints:

- Ideally, all TENSORIMPL_MAYBE_VIRTUAL would be eliminated from TensorImpl, by using the same policy trick.
- set_size and set_stride are still virtual; it's not entirely clear the same trick should be used here, though, as these methods are deprecated.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77036
Approved by: https://github.com/bdhirsh
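To make the scheme concrete, here is a minimal sketch of a TensorImpl subclass opting into the CustomStrides policy. This is not code from the PR: MyOpaqueTensorImpl and its error messages are invented for illustration; only the TensorImpl hooks (set_sizes_strides_policy, strides_custom, is_contiguous_custom, is_contiguous_default) come from the mechanism described above.

#include <c10/core/TensorImpl.h>
#include <c10/util/Exception.h>

// Hypothetical backend tensor that has sizes but no meaningful strides.
// Sketch only; everything except the TensorImpl hooks is made up.
struct MyOpaqueTensorImpl : public c10::TensorImpl {
  MyOpaqueTensorImpl(
      c10::DispatchKeySet key_set,
      const caffe2::TypeMeta& data_type,
      c10::optional<c10::Device> device)
      : TensorImpl(key_set, data_type, device) {
    // Shunt strides() and is_contiguous() to the *_custom() hooks below.
    set_sizes_strides_policy(SizesStridesPolicy::CustomStrides);
  }

  // This backend has no strides, so the hook raises (mirroring what the
  // default strides_custom() implementation would do anyway).
  c10::IntArrayRef strides_custom() const override {
    TORCH_CHECK(false, "MyOpaqueTensorImpl does not support strides()");
  }

  // Contiguity, by contrast, is answered with the conventional dense
  // semantics via the *_default() escape hatch described above.
  bool is_contiguous_custom(c10::MemoryFormat memory_format) const override {
    return is_contiguous_default(memory_format);
  }
};

UndefinedTensorImpl below uses exactly this CustomStrides arrangement, including the is_contiguous_default() fallback.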
#include <c10/core/UndefinedTensorImpl.h>

#include <c10/util/Exception.h>

namespace c10 {

// should this use the globalContext? Can it get a context passed in somehow?
UndefinedTensorImpl::UndefinedTensorImpl()
    : TensorImpl(DispatchKey::Undefined, caffe2::TypeMeta(), c10::nullopt) {
  set_storage_access_should_throw();
  // TODO: accessing the sizes on an undefined tensor is not meaningful
  // and should error too, but empirically it does not!
  set_sizes_strides_policy(SizesStridesPolicy::CustomStrides);
}

// Contiguity queries fall back to the conventional (dense) semantics
// rather than erroring, consistent with the TODO above.
bool UndefinedTensorImpl::is_contiguous_custom(MemoryFormat format) const {
  return is_contiguous_default(format);
}

#ifdef DEBUG
bool UndefinedTensorImpl::has_storage() const {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      !storage_, "UndefinedTensorImpl assumes that storage_ is never set");
  return false;
}
#endif

void UndefinedTensorImpl::set_storage_offset(int64_t) {
  TORCH_CHECK(false, "set_storage_offset() called on an undefined Tensor");
}

const char* UndefinedTensorImpl::tensorimpl_type_name() const {
  return "UndefinedTensorImpl";
}

// The single shared instance returned by UndefinedTensorImpl::singleton().
UndefinedTensorImpl UndefinedTensorImpl::_singleton;

} // namespace c10
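For context on how this class is reached in practice, here is a minimal usage sketch (not part of the file above): a default-constructed at::Tensor is backed by the UndefinedTensorImpl singleton, which is why the storage accessors are wired to throw.

#include <ATen/ATen.h>

int main() {
  at::Tensor t; // default-constructed: impl is UndefinedTensorImpl::singleton()
  TORCH_CHECK(!t.defined(), "a fresh Tensor starts out undefined");

  // Storage access throws, per set_storage_access_should_throw() above:
  // t.storage();

  t = at::ones({2, 2}); // assigning a real tensor swaps in a defined impl
  TORCH_CHECK(t.defined());
  return 0;
}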