mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Continued cleanup work on c10/util/*.h. Pull Request resolved: https://github.com/pytorch/pytorch/pull/116326 Approved by: https://github.com/Skylion007
76 lines
2.1 KiB
C++
76 lines
2.1 KiB
C++
#pragma once
|
|
|
|
#include <c10/core/TensorImpl.h>
|
|
#include <c10/core/UndefinedTensorImpl.h>
|
|
|
|
#include <utility>
|
|
|
|
namespace c10 {
|
|
// Shared ExclusivelyOwnedTraits implementation between caffe2::Tensor and
|
|
// at::TensorBase.
|
|
template <typename TensorType>
|
|
struct ExclusivelyOwnedTensorTraits {
|
|
using repr_type = TensorType;
|
|
using pointer_type = TensorType*;
|
|
using const_pointer_type = const TensorType*;
|
|
|
|
static repr_type nullRepr() {
|
|
return TensorType();
|
|
}
|
|
|
|
template <class... Args>
|
|
static repr_type createInPlace(Args&&... args) {
|
|
return TensorType(std::forward<Args>(args)...);
|
|
}
|
|
|
|
static repr_type moveToRepr(TensorType&& x) {
|
|
return std::move(x);
|
|
}
|
|
|
|
static void destroyOwned(TensorType& x) {
|
|
TensorImpl* const toDestroy = x.unsafeReleaseTensorImpl();
|
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
|
toDestroy != nullptr, "Tensor somehow got null TensorImpl?");
|
|
// May be 0 because UndefinedTensorImpl doesn't get its refcount
|
|
// incremented.
|
|
const bool isUndefined = toDestroy == UndefinedTensorImpl::singleton();
|
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
|
toDestroy->refcount_ == 1 || (toDestroy->refcount_ == 0 && isUndefined),
|
|
"ExclusivelyOwned<Tensor> destroyed with isUndefined ",
|
|
isUndefined,
|
|
" and refcount ",
|
|
toDestroy->refcount_,
|
|
", expected 1 or, if isUndefined, 0!");
|
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
|
|
toDestroy->weakcount_ == 1 ||
|
|
(toDestroy->weakcount_ == 0 &&
|
|
toDestroy == UndefinedTensorImpl::singleton()),
|
|
"ExclusivelyOwned<Tensor> destroyed with isUndefined ",
|
|
isUndefined,
|
|
" and weakcount ",
|
|
toDestroy->weakcount_,
|
|
", expected 1 or, if isUndefined, 0!");
|
|
if (!isUndefined) {
|
|
#ifndef NDEBUG
|
|
// Needed to pass the debug assertions in ~intrusive_ptr_target.
|
|
toDestroy->refcount_ = 0;
|
|
toDestroy->weakcount_ = 0;
|
|
#endif
|
|
delete toDestroy;
|
|
}
|
|
}
|
|
|
|
static TensorType take(TensorType& x) {
|
|
return std::move(x);
|
|
}
|
|
|
|
static pointer_type getImpl(repr_type& x) {
|
|
return &x;
|
|
}
|
|
|
|
static const_pointer_type getImpl(const repr_type& x) {
|
|
return &x;
|
|
}
|
|
};
|
|
} // namespace c10
|