mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49097

RFC: https://github.com/pytorch/rfcs/pull/11

This PR adds the basic logic to handle forward grad as dual Tensors. It contains the following:
- A mechanism to save dual state on a Tensor and clear it when the dual level ends
- C++ and Python user-facing API
- An updated view system that can track both forward and backward views

The current PR has the following limitations:
- Extensive tests are in the next PR in the stack, as formulas are needed to write full tests.
- Only the manual formulas have been audited; no other formula is implemented here (they are in the next PR in the stack).
- Only level 0 is allowed for now. This was discussed and agreed that it is not needed for the first version of this PR.
- We can save one ViewInfo creation when both the forward and backward views have the same base. This can be done by adding a boolean flag to the DifferentiableViewMeta and extra logic in the `as_view` method. This is left out to keep this PR concise.
- We can skip tracking forward views if the base has a forward grad. This can be done by adding extra logic in the `as_view` method. This is left out to keep this PR concise.

Reading guide:
- Updated view handling in [gen_variable_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-f6553cec68caeaea36f6c8b14ff76a6d39dfd774e0ea9ef2f76e8d81fd9af5df), [VariableTypeUtils.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-ec71cfa45954dece1236c661d170e6341879c5be637f4abf52e826d61b40695a), [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285) (skip code below "[Forward Grad View]" for now), [variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-1604bcd0e4350ed99ec45e437cee7ac9ebe337392c9ea16a236247aeeb35b02bR266-R542) and [custom_function.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-dd85f452082b5bb6612bbc12adb496f8827defa228509f7b493de1d517522d5d). This introduces the new ViewInfo, which holds view information shared by forward and backward. It also updates the differentiable view meta to use it, and updates the as_view function to handle both forward and backward views.
- New forward grad class that handles storing gradients and tracking at each level: [forward_grad.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c6c5b9ab2d7e5dde4102495faa1b6bbbfc23aa3e47deb7359c0bfe1eb004c0cb), [forward_grad.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-de2ab54ade7312701850d71a119a4f4ee4b9fc5a9c42a467cdd4e73c033531dd) and [build_variables.bzl](https://github.com/pytorch/pytorch/pull/49097/files#diff-dfdfa2efb17beddfd9094524f95351fd197db6c8857e96b436fb599870359325). EDIT: These files also contain the new flag to globally disable forward AD, which allows us to reduce performance issues while this is in development.
- Lowest level API and binding between Tensor and AutogradMeta in [TensorBody.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-7554853205392fa743357bf845ecc350a974ec049383248c12daaf2f4de04911), [TensorImpl.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-052bd9150ef8e09289ddf644b5a6830ede49207201cd41728f6d7cc6d9cead94), [TensorImpl.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-a15aae4cf23da44970db7cece62ff981265575c798c62f7b52d87c8809dfe2e1) and the rest of [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285R557-R677)
- API to access the forward primal that needs to be a differentiable function (and so in native_functions.yaml) [native_functions.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991) [NamedRegistrations.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-69bd3bea510c9b64e1633fa18c3ea63d4b8348dbad3a78ad9de844ab3e43dc1d), [VariableMethodsStub.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-23f5fcb737a2b289811fe0f4b65aef775e7c824b2e629ecd343df51405cd434f), [derivatives.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_python_functions.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_trace_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-54e0b976027bf8debefb959ff360b89ae93466970c843365b1b3a03806d868ce), [TraceTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-f34636741ad4a23d018e0c289bc750c3bad887b45660e1d6eaf440d234a78fbf) and [part of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R198-R243)
- c++ API [autograd.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-349028fbe8291a965a7a263c323b208fe071c35c66179ee997ef84fa81aa4b1e), [autograd.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-a3fe908d67dfec16a1fcde300de68b0701bf68b88db7451f29f2bee255cf30c9)
- python binding [init.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-c58a67c85191c22c9b3bb439117d8053edfd9dea839fa010cf967d404c3c630d)
- python API [forward_ad.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a4efad4ba18fffdfb264c21e5475997a24a743089a899f8ec1a5ff962c6738d9), [autograd/__init__.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-743abcafd32ad0e69f39ac5a91df4197b7e1921c135cacee7ef6dc829a8a7af8)
- c++ and python printing [Formatting.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-881dba501e71662e2e4818b4b016f739b344c8aed2f5edc6b871eda47a2aced0), [_tensor_str.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a7911f8d5e73adbff914d99fd7818ace2a7030b6a3748abe06ec6fc6e3df9cc3)
- Utility for formulas and updated manual functions to respect new view system as well as forward grad [FunctionsManual.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-6378bb6dc81a64dab676d61731341fa5d1088418f32a1473a33a0ccfc2357dc1), [FunctionsManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-4adbd88239afcd60e8198aab65d4f5e43b62314e34b80551e997a1ea503adea5) [rest of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R264-R433)
- Ensure SavedVariable save forward grad properly [saved_variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c1b8039d776241abe177d5aa99b79dd9489a9b3e529da8ab24c2e386c1238ae2), [saved_variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-cc9fba479b5beae06b2eea2e390d17796e0341c5b037a20b5bcaccbb0c341030)

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D25607503

Pulled By: albanD

fbshipit-source-id: f1396290de1d75760f3d380c43cdd56e86fa6099
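
The "forward grad as dual Tensors" idea in the summary can be illustrated on plain scalars. The following is a minimal sketch, not PyTorch code: `Dual`, `dual_sin` and `f` are hypothetical names introduced only for this illustration. Each value carries a primal and a tangent, and every operation propagates the tangent alongside the primal, which is roughly what a dual Tensor does at a given level.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical scalar analogue of a dual Tensor: the primal is the ordinary
// value, the tangent is the forward-mode gradient carried along with it.
struct Dual {
  double primal;
  double tangent;
};

// Sum rule: tangents add.
Dual operator+(Dual a, Dual b) {
  return {a.primal + b.primal, a.tangent + b.tangent};
}

// Product rule: (ab)' = a'b + ab'.
Dual operator*(Dual a, Dual b) {
  return {a.primal * b.primal, a.tangent * b.primal + a.primal * b.tangent};
}

// Chain rule for sin: (sin a)' = cos(a) * a'.
Dual dual_sin(Dual a) {
  return {std::sin(a.primal), std::cos(a.primal) * a.tangent};
}

// Example function f(x) = x * x + sin(x); its derivative is 2x + cos(x).
Dual f(Dual x) {
  return x * x + dual_sin(x);
}
```

Seeding the input as `Dual{2.0, 1.0}` and evaluating `f` yields the function value in `primal` and the derivative at `x = 2` in `tangent`, in a single forward pass and without building a backward graph.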
395 lines
13 KiB
C++
#include <c10/core/TensorImpl.h>

#include <c10/core/Backend.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/util/Optional.h>

C10_DEFINE_bool(
    caffe2_keep_on_shrink,
    true,
    "If set, keeps memory when a tensor is shrinking its size.");

C10_DEFINE_int64(
    caffe2_max_keep_on_shrink_memory,
    LLONG_MAX,
    "The maximum memory in bytes to keep on shrink, if the difference between "
    "tensor sizes is bigger than this then tensor will be reset.");

namespace c10 {

const char * const TensorImpl::err_msg_tensor_metadata_change_not_allowed =
    "is not allowed on a Tensor created from .data or .detach().\n"
    "If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)\n"
    "without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.\n"
    "For example, change:\n"
    "    x.data.set_(y)\n"
    "to:\n"
    "    with torch.no_grad():\n"
    "        x.set_(y)";

at::Tensor& TensorImpl::mutable_grad() {
  if (!autograd_meta_) autograd_meta_ = impl::GetAutogradMetaFactory()->make();
  return autograd_meta_->mutable_grad();
}

const at::Tensor& TensorImpl::grad() const {
  // Yes, I know this looks really weird.  But I don't really have a choice as
  // long as this function returns a const reference to Tensor.  I'm not
  // really sure how I would have designed this API differently, but it
  // is not so easy to fix right now because the mutable counterpart of
  // this function must keep working so that "x.grad() = ..." keeps working
  // (part of public API).
  if (!autograd_meta_) return impl::GetAutogradMetaFactory()->undefined_tensor();
  return autograd_meta_->grad();
}

const at::Tensor& TensorImpl::fw_grad(uint64_t level, const at::Tensor& self) const {
  // See TensorImpl::grad() above for explanation about the line below
  if (!autograd_meta_) return impl::GetAutogradMetaFactory()->undefined_tensor();
  return autograd_meta_->fw_grad(level, self);
}

void TensorImpl::set_fw_grad(const at::Tensor& new_grad, const at::Tensor& self, uint64_t level, bool is_inplace_op) {
  if (!autograd_meta_) autograd_meta_ = impl::GetAutogradMetaFactory()->make();
  autograd_meta_->set_fw_grad(new_grad, self, level, is_inplace_op);
}

TensorImpl::TensorImpl(
    Storage&& storage,
    DispatchKeySet key_set,
    const caffe2::TypeMeta data_type)
    : TensorImpl(std::move(storage), key_set, data_type, storage.device()) {}

TensorImpl::TensorImpl(DispatchKeySet key_set, const caffe2::TypeMeta data_type, c10::optional<c10::Device> device_opt)
    : TensorImpl({}, key_set, data_type, std::move(device_opt)) {}

TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::TypeMeta data_type,
                       c10::optional<c10::Device> device_opt)
    : storage_(std::move(storage)),
      sizes_{0},
      storage_offset_(0),
      numel_(0),
      data_type_(data_type),
      device_opt_(device_opt) {

  init_bitfields();

  if (!key_set.empty()) {
    TORCH_INTERNAL_ASSERT(data_type == ScalarType::Undefined || device_opt_.has_value());
    // UndefinedTensorImpl is a singleton, so we skip logging it
    C10_LOG_API_USAGE_ONCE("tensor.create");
  }
  // After we removed Autograd keys from globally enabled set, every Tensor must be created with
  // a backend DispatchKey and an AutogradBackend key.
  // We automatically add the corresponding autograd key to key_set_ so that backends can stay
  // in the old way of only registering with backend key like DispatchKey::CPU.
  // TODO: Ideally this logic fits best in Variable/Autograd layer so that we only
  // add AutogradBackend key when the tensor requires grad.
  DispatchKey k = key_set.highestPriorityBackendTypeId();
  key_set_ = key_set.add(getAutogradKeyFromBackend(k));

  // we would also like to check that non-cpu devices have an index, but some Caffe2 operators create
  // Storages with default devices.
  strides_.push_back(1);
}

IntArrayRef TensorImpl::sizes() const {
  return sizes_;
}

IntArrayRef TensorImpl::strides() const {
  return strides_;
}

bool TensorImpl::compute_contiguous() const {
  bool is_contiguous = true;
  if (is_empty())
    return is_contiguous;
  int64_t z = 1;
  for (int64_t d = dim() - 1; d >= 0; d--) {
    if (sizes_[d] != 1) {
      if (strides_[d] == z) {
        z *= sizes_[d];
      } else {
        is_contiguous = false;
        break;
      }
    }
  }
  return is_contiguous;
}

bool TensorImpl::compute_channels_last_contiguous_2d() const {
  // Please don't merge these code paths: a constant array is used here so
  // that the compiler can fully unroll the loop for better performance
  switch (sizes_.size()) {
    case 4:
      {
        int64_t expected = 1;
        for (auto& d : {1, 3, 2, 0}) {
          if (sizes_[d] != 1) {
            if (strides_[d] != expected) {
              return false;
            }
            expected *= sizes_[d];
          }
        }
        return true;
      }
    case 3:
      // TODO dim == 3 case will be enabled once it is fully tested
      return false;
    default:
      return false;
  }
}

bool TensorImpl::compute_channels_last_contiguous_3d() const {
  // Please don't merge these code paths: a constant array is used here so
  // that the compiler can fully unroll the loop for better performance
  switch (sizes_.size()) {
    case 5:
      {
        int64_t expected = 1;
        for (auto& d : {1, 4, 3, 2, 0}) {
          if (sizes_[d] != 1) {
            if (strides_[d] != expected) {
              return false;
            }
            expected *= sizes_[d];
          }
        }
        return true;
      }
    case 4:
      // TODO dim == 4 case will be enabled once it is fully tested
      return false;
    default:
      return false;
  }
}

bool TensorImpl::compute_strides_like_channels_last_2d() const {
  return is_channels_last_strides_2d(sizes_, strides_);
}

bool TensorImpl::compute_strides_like_channels_last_3d() const {
  return is_channels_last_strides_3d(sizes_, strides_);
}

bool TensorImpl::compute_non_overlapping_and_dense() const {
  if (dim() == 1) {
    return sizes_[0] < 2 || strides_[0] == 1;
  }
  SmallVector<int64_t,5> perm;
  perm.resize(dim());
  for (int64_t i = 0; i < dim(); i ++) {
    perm[i] = i;
  }
  // Sort by strides, leaving 0 and 1 sized dims at the end of the array
  std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
      if (sizes_[a] < 2) {
        return false;
      } else if (sizes_[b] < 2) {
        return true;
      }
      return strides_[a] < strides_[b];
  });
  // Use int64_t (not a deduced int) so the running product of sizes
  // cannot overflow for large tensors
  int64_t require_stride = 1;
  for (int64_t i = 0; i < dim(); i ++) {
    if (sizes_[perm[i]] < 2) {
      return true;
    }
    if (strides_[perm[i]] != require_stride) {
      return false;
    }
    require_stride *= sizes_[perm[i]];
  }
  return true;
}

void TensorImpl::release_resources() {
  autograd_meta_.reset();
  if (storage_) {
    storage_ = {};
  }
}

int64_t TensorImpl::dim() const {
  return sizes_.size();
}

int64_t TensorImpl::size(int64_t d) const {
  d = at::maybe_wrap_dim(d, dim(), false);
  return sizes_[d];
}

int64_t TensorImpl::stride(int64_t d) const {
  d = at::maybe_wrap_dim(d, dim(), false);
  return strides_[d];
}

bool TensorImpl::has_storage() const {
  return storage_;
}

bool TensorImpl::is_contiguous(at::MemoryFormat memory_format) const {
#ifdef DEBUG
  AT_ASSERT(compute_contiguous() == is_contiguous_);
#endif
  if (memory_format == at::MemoryFormat::ChannelsLast) {
    return is_channels_last_contiguous_;
  }
  else if (memory_format == at::MemoryFormat::ChannelsLast3d) {
    return is_channels_last_3d_contiguous_;
  }
  return is_contiguous_;
}

const Storage& TensorImpl::storage() const {
  return storage_;
}

static void deletePlacementDeleteContext(void* ptr) {
  delete static_cast<PlacementDeleteContext*>(ptr);
}

at::DataPtr PlacementDeleteContext::makeDataPtr(
    at::DataPtr&& data_ptr,
    PlacementDtor placement_dtor,
    size_t size,
    at::Device device) {
  auto* ptr = data_ptr.get();
  return {ptr,
          new PlacementDeleteContext(std::move(data_ptr), placement_dtor, size),
          &deletePlacementDeleteContext,
          device};
}

AutogradMetaInterface::~AutogradMetaInterface() {}

void TensorImpl::set_requires_grad(bool requires_grad) {
  if (!requires_grad && !autograd_meta_) return;
  if (!autograd_meta_) autograd_meta_ = impl::GetAutogradMetaFactory()->make();
  // NB: In principle, setting requires_grad to false could result in
  // the AutogradMeta becoming equal to a default constructed state,
  // in which case we could apply the nullptr AutogradMeta optimization
  // (see autograd_meta_ docs).  But we don't do this right now.  Note
  // that it is unsound to unconditionally set AutogradMeta to false
  // when you set requires_grad to False, as there may be nontrivial
  // information content in the other fields; for example, we may
  // have set the string name for a Variable, or there may be hooks
  // registered for it.
  autograd_meta_->set_requires_grad(requires_grad, this);
}

bool TensorImpl::requires_grad() const {
  if (!autograd_meta_) return false;
  return autograd_meta_->requires_grad();
}

void TensorImpl::set_autograd_meta(std::unique_ptr<c10::AutogradMetaInterface> autograd_meta) {
  // NB: autograd_meta may be null!  That just means it's the default
  // constructor
  autograd_meta_ = std::move(autograd_meta);
}

c10::AutogradMetaInterface* TensorImpl::autograd_meta() const {
  // NB: Might return null!
  return autograd_meta_.get();
}

c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach(
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  auto impl = c10::make_intrusive<TensorImpl>(
      // No need to populate Storage; copy_tensor_metadata will do it for us.
      key_set_, data_type_, device_opt_);
  copy_tensor_metadata(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),
      /*version_counter=*/version_counter,
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  impl->refresh_numel();
  impl->refresh_contiguous();
  return impl;
}

c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach(
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  auto impl = c10::make_intrusive<TensorImpl>(
      // No need to populate Storage; copy_tensor_metadata will do it for us.
      key_set_, data_type_, device_opt_);
  copy_tensor_metadata(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),
      /*version_counter=*/std::move(version_counter),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  impl->refresh_numel();
  impl->refresh_contiguous();
  return impl;
}

void TensorImpl::copy_tensor_metadata_except_version_counter(
    const TensorImpl* src_impl,
    TensorImpl* dest_impl,
    bool allow_tensor_metadata_change) {
  dest_impl->storage_ = src_impl->storage_;
  dest_impl->sizes_ = src_impl->sizes_;
  dest_impl->strides_ = src_impl->strides_;
  dest_impl->storage_offset_ = src_impl->storage_offset_;
  dest_impl->data_type_ = src_impl->data_type_;
  dest_impl->device_opt_ = src_impl->device_opt_;
  dest_impl->key_set_ = src_impl->key_set_;
  dest_impl->is_contiguous_ = src_impl->is_contiguous_;
  dest_impl->is_channels_last_contiguous_ = src_impl->is_channels_last_contiguous_;
  dest_impl->is_channels_last_3d_contiguous_ = src_impl->is_channels_last_3d_contiguous_;
  dest_impl->is_channels_last_ = src_impl->is_channels_last_;
  dest_impl->is_channels_last_3d_ = src_impl->is_channels_last_3d_;
  dest_impl->is_non_overlapping_and_dense_ = src_impl->is_non_overlapping_and_dense_;
  dest_impl->is_wrapped_number_ = src_impl->is_wrapped_number_;
  dest_impl->reserved_ = src_impl->reserved_;
  dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
  if (src_impl->named_tensor_meta_ != nullptr) {
    dest_impl->named_tensor_meta_ = src_impl->named_tensor_meta_->clone();
  }
}

void TensorImpl::copy_tensor_metadata(
    const TensorImpl* src_impl,
    TensorImpl* dest_impl,
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) {
  copy_tensor_metadata_except_version_counter(src_impl, dest_impl, allow_tensor_metadata_change);
  dest_impl->set_version_counter(version_counter);
}

void TensorImpl::copy_tensor_metadata(
    const TensorImpl* src_impl,
    TensorImpl* dest_impl,
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) {
  copy_tensor_metadata_except_version_counter(src_impl, dest_impl, allow_tensor_metadata_change);
  dest_impl->set_version_counter(std::move(version_counter));
}

namespace impl {

namespace {
AutogradMetaFactory* meta_factory = nullptr;
}

void SetAutogradMetaFactory(AutogradMetaFactory* factory) {
  meta_factory = factory;
}
AutogradMetaFactory* GetAutogradMetaFactory() {
  TORCH_CHECK(meta_factory, "Support for autograd has not been loaded; have you linked against libtorch.so?");
  return meta_factory;
}

} // namespace impl

} // namespace c10
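
The stride checks in `compute_contiguous()` and `compute_channels_last_contiguous_2d()` can be tried outside of c10 with a small standalone sketch. The free functions below are hypothetical stand-ins that take sizes and strides as `std::vector<int64_t>` instead of reading `TensorImpl` members; they are meant to mirror the loops above, not to replace them.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical standalone version of the row-major contiguity test.
// Walks dimensions from innermost to outermost and checks that each stride
// equals the product of the sizes of all inner dimensions.
bool is_row_major_contiguous(const std::vector<int64_t>& sizes,
                             const std::vector<int64_t>& strides) {
  // A tensor with zero elements is treated as contiguous.
  for (int64_t s : sizes) {
    if (s == 0) return true;
  }
  int64_t expected = 1;
  for (int64_t d = static_cast<int64_t>(sizes.size()) - 1; d >= 0; --d) {
    if (sizes[d] == 1) continue;  // strides of size-1 dims are ignored
    if (strides[d] != expected) return false;
    expected *= sizes[d];
  }
  return true;
}

// Hypothetical standalone version of the channels-last (NHWC) test for 4-d
// tensors. Dimensions are visited in the order C (1), W (3), H (2), N (0),
// which is innermost to outermost in channels-last memory.
bool is_channels_last_contiguous_2d(const std::vector<int64_t>& sizes,
                                    const std::vector<int64_t>& strides) {
  if (sizes.size() != 4) return false;
  int64_t expected = 1;
  for (int d : {1, 3, 2, 0}) {
    if (sizes[d] != 1) {
      if (strides[d] != expected) return false;
      expected *= sizes[d];
    }
  }
  return true;
}
```

For a tensor of sizes `{2, 3, 4, 5}`, the strides `{60, 20, 5, 1}` pass the row-major test only, while `{60, 1, 15, 3}` pass the channels-last test only. The two layouts differ solely in the order in which dimensions are visited.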