pytorch/c10/core/TensorImpl.cpp
Edward Yang aacc722aec Dispatch to Python via __torch_dispatch__ (#59760)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59760

See https://github.com/pytorch/pytorch/issues/59049

There are some moving parts to this PR; I'll structure this explanation so the straightforward parts come first, followed by the less straightforward ones.

**The actual dispatch to Python.** The core logic of dispatch to Python lives in `concrete_dispatch_fn` in `torch/csrc/autograd/python_variable.cpp`. It takes the input IValue stack, scans the arguments for Tensors, and defers most of the heavy lifting to `handle_torch_function_no_python_arg_parser`, which does all of the logic for calling out to torch dispatch (in particular, this function handles multiple dispatch situations for you). Because we use a different magic method name than regular `__torch_function__` handling, `handle_torch_function_no_python_arg_parser` is generalized to accept the magic method name to look for when testing whether Tensors have custom handling. Unlike `__torch_function__`, there is no default `__torch_dispatch__` on Tensor classes.
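
For orientation, here is a minimal sketch of what the hook looks like from the Python side. This is illustrative only (the class name is made up); it simply documents the arguments the dispatcher hands to `__torch_dispatch__`:

```python
import torch

class QuietTensor(torch.Tensor):
    # Per the known limitations below, the default __torch_function__
    # handling should be explicitly disabled on __torch_dispatch__ subclasses.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # func:  the ATen operator being dispatched (autograd has already run)
        # types: the participating Tensor subclasses (multiple dispatch)
        # args:  positional arguments; kwargs is never populated for now
        #        (see the known limitations below)
        raise NotImplementedError("sketch only; see the LoggingTensor sketch below")
```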

**Maintaining the Python dispatch key.** In order to reach the dispatch-to-Python logic, we must tag Tensors that have the `__torch_dispatch__` magic method with the newly added Python dispatch key (kept separate from PythonFuncTorch to allow for a transitional period while they migrate to this mechanism). We expose a new private property `_is_python_dispatch` that assists in debugging whether a Tensor is participating in Python dispatch. We apply the Python dispatch key the first time a PyObject for a Tensor is constructed (THPVariable_NewWithVar), testing whether `__torch_dispatch__` exists with the newly added `check_has_torch_dispatch`.
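
As a hedged sketch of the lazy key application (the class name is made up, and `_is_python_dispatch` is only mentioned in a comment since its exact surface is a private detail):

```python
import torch

class Handled(torch.Tensor):
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        raise NotImplementedError  # placeholder; no op is dispatched in this sketch

plain = torch.randn(2)
# A fresh TensorImpl never carries the Python key (see [Note: Python key removal]
# in the C++ file below); the key is applied when the PyObject is first created
# for a class that defines __torch_dispatch__, e.g. via Tensor._make_subclass.
handled = torch.Tensor._make_subclass(Handled, plain)
# `plain` does not participate in Python dispatch; `handled` does, which the
# private _is_python_dispatch accessor mentioned above can confirm when debugging.
```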

**Shallow copy and detach.** For the simple examples tested in this PR, most Tensor creations route through the dispatcher. The exception is `shallow_copy_and_detach`, which bypasses the dispatcher and is used when saving tensors for backward. When a Tensor participates in Python dispatch, we override the behavior of `shallow_copy_and_detach` to instead call directly into `__torch_dispatch__` to perform a `detach` operation (in the same way it would be invoked if you called `detach` directly). Because this Python call is triggered directly from c10::TensorImpl, it must be indirected through `PyInterpreter::detach`, which is the general mechanism for dynamically dispatching to the Python interpreter associated with a TensorImpl.

**torchdeploy compatibility.** The dispatch-to-Python logic cannot be registered directly with the dispatcher, because it is compiled into the Python library, which gets loaded multiple times per torchdeploy interpreter. Thus, we must employ a two-phase process. First, we register a fallback inside a non-Python library (aten/src/ATen/core/PythonFallbackKernel.cpp). Its job is to determine the appropriate PyInterpreter to handle the Python dispatch by going through all of the arguments and finding the first one that has a PyObject/PyInterpreter. With this PyInterpreter, it makes another dynamic dispatch via "dispatch", which goes to the correct torchdeploy interpreter to handle dispatching to actual Python.

**Testing.** We provide a simple example of a LoggingTensor for testing, which can be used to generate TorchScript-like traces to observe which operations are called on a Tensor. Although a LoggingTensor would be better implemented via an is-a relationship rather than the has-a relationship used in the test, we've done it this way to show that arbitrarily complex compositions of tensors inside a tensor work properly; a simplified sketch of the pattern follows.
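
Below is a simplified sketch of such a wrapper ("has-a") tensor in the same spirit as the test; it is not the test's exact implementation:

```python
import torch
from torch.utils._pytree import tree_map

class LoggingTensor(torch.Tensor):
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, elem):
        # The wrapper carries the metadata; the real data lives in self.elem.
        r = torch.Tensor._make_subclass(cls, elem)
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(x):
            return x.elem if isinstance(x, LoggingTensor) else x

        def wrap(x):
            return LoggingTensor(x) if isinstance(x, torch.Tensor) else x

        # Re-run the op on the unwrapped inner tensors, log it, and re-wrap.
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        print(func)  # one trace line per ATen op that reaches Python
        return tree_map(wrap, out)

x = LoggingTensor(torch.ones(2))
y = x * 2  # prints the ATen op(s) hit by the multiplication
```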

**Known limitations.**

* We haven't adjusted any operator code, so some patterns may not work (as they lose the Python subclass in an unrecoverable way)
* `__torch_function__` must be explicitly disabled with `_disabled_torch_function_impl`, otherwise things don't work quite correctly (in particular, what is being disabled is the default subclass preservation behavior).
* We don't ever populate kwargs, even when an argument is kwarg-only

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: D29017912

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Pulled By: ezyang

fbshipit-source-id: a67714d9e541d09203a8cfc85345b8967db86238
2021-06-25 11:50:32 -07:00

#include <c10/core/TensorImpl.h>
#include <c10/core/Backend.h>
#include <c10/core/InferenceMode.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/util/Optional.h>
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
C10_DEFINE_bool(
caffe2_keep_on_shrink,
true,
"If set, keeps memory when a tensor is shrinking its size.");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
C10_DEFINE_int64(
caffe2_max_keep_on_shrink_memory,
LLONG_MAX,
"The maximum memory in bytes to keep on shrink, if the difference between "
"tensor sizes is bigger than this then tensor will be reset.");
namespace c10 {
namespace impl {
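// These no-op implementations replace a PyInterpreter's function pointers when
// its Python interpreter is torn down (see PyInterpreter::disarm below), so a
// TensorImpl that outlives its interpreter either degrades gracefully
// (name/decref) or fails loudly (detach/dispatch) instead of calling into a
// dead interpreter.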
static std::string noop_name_fn(const PyInterpreter*) {
return "<unloaded interpreter>";
}
static void noop_decref_fn(const PyInterpreter*, PyObject*) {
// no-op
}
static c10::intrusive_ptr<TensorImpl> noop_detach_fn(
const PyInterpreter*,
const TensorImpl*) {
TORCH_INTERNAL_ASSERT(
0,
"attempted to detach (shallow_copy_and_detach) Tensor with nontrivial PyObject after corresponding interpreter died");
}
static void noop_dispatch_fn(
const PyInterpreter*,
const c10::OperatorHandle& op,
torch::jit::Stack* stack) {
TORCH_INTERNAL_ASSERT(
0,
"attempted to dispatch (__torch_dispatch__) an operator on Tensor with nontrivial PyObject after corresponding interpreter died");
}
void PyInterpreter::disarm() noexcept {
name_fn_ = &noop_name_fn;
decref_fn_ = &noop_decref_fn;
detach_fn_ = &noop_detach_fn;
dispatch_fn_ = &noop_dispatch_fn;
}
} // namespace impl
const char* const TensorImpl::err_msg_tensor_metadata_change_not_allowed =
"is not allowed on a Tensor created from .data or .detach().\n"
"If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)\n"
"without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.\n"
"For example, change:\n"
" x.data.set_(y)\n"
"to:\n"
" with torch.no_grad():\n"
" x.set_(y)";
at::Tensor& TensorImpl::mutable_grad() {
if (!autograd_meta_)
autograd_meta_ = impl::GetAutogradMetaFactory()->make();
return autograd_meta_->mutable_grad();
}
const at::Tensor& TensorImpl::grad() const {
// Yes, I know this looks really weird. But I don't really have a choice as
// long as this function returns a const reference to Tensor. I'm not
// really sure how I would have designed this API differently, but it
// is not so easy to fix right now because the mutable counterpart of
// this function must keep working so that "x.grad() = ..." keeps working
// (part of public API).
if (!autograd_meta_)
return impl::GetAutogradMetaFactory()->undefined_tensor();
return autograd_meta_->grad();
}
const at::Tensor& TensorImpl::_fw_grad(uint64_t level, const at::Tensor& self)
const {
// See TensorImpl::grad() above for explanation about the line below
if (!autograd_meta_)
return impl::GetAutogradMetaFactory()->undefined_tensor();
return autograd_meta_->fw_grad(level, self);
}
void TensorImpl::_set_fw_grad(
const at::Tensor& new_grad,
const at::Tensor& self,
uint64_t level,
bool is_inplace_op) {
if (!autograd_meta_)
autograd_meta_ = impl::GetAutogradMetaFactory()->make();
autograd_meta_->set_fw_grad(new_grad, self, level, is_inplace_op);
}
TensorImpl::TensorImpl(
Storage&& storage,
DispatchKeySet key_set,
const caffe2::TypeMeta data_type)
// Use std::forward to suppress static analyzer false positive.
: TensorImpl(
std::forward<Storage>(storage),
key_set,
data_type,
storage.device()) {}
// [Note: Python key removal]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// In most constructors for TensorImpl, you will see Python key is removed from
// the passed in DispatchKeySet. Why?
//
// INVARIANT: Python dispatch key is set iff PyObject for the Tensor has a
// nontrivial __torch_dispatch__ implementation.
//
// When a fresh TensorImpl is created, there is *no* PyObject (it only gets
// initialized lazily the first time the Tensor passes into Python).
// So keeping the Python key at construction time would violate the invariant.
//
// In practice, what happens shortly afterwards is that the TensorImpl
// will get its PyObject initialized by Tensor._make_subclass; at that point
// the Python dispatch key will be set and all is well. The idea is to delay
// setting the dispatch key until then.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
TensorImpl::TensorImpl(
ImplType type,
Storage&& storage,
DispatchKeySet key_set,
const caffe2::TypeMeta data_type)
: storage_(std::move(storage)),
pyobj_interpreter_(nullptr),
pyobj_(nullptr),
storage_offset_(0),
numel_(0),
data_type_(data_type),
device_opt_(storage_.device()),
key_set_(key_set.remove(
DispatchKey::Python)) { // See [Note: Python key removal]
init_bitfields();
// Inference tensor doesn't have version counter.
if (!is_inference()) {
version_counter_ = VariableVersion(/*version=*/0);
}
}
TensorImpl::TensorImpl(
DispatchKeySet key_set,
const caffe2::TypeMeta data_type,
c10::optional<c10::Device> device_opt)
// NOLINTNEXTLINE(performance-move-const-arg)
: TensorImpl({}, key_set, data_type, std::move(device_opt)) {}
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
TensorImpl::TensorImpl(
Storage&& storage,
DispatchKeySet key_set,
const caffe2::TypeMeta data_type,
c10::optional<c10::Device> device_opt)
: storage_(std::move(storage)),
pyobj_interpreter_(nullptr),
pyobj_(nullptr),
storage_offset_(0),
numel_(0),
data_type_(data_type),
device_opt_(device_opt) {
init_bitfields();
if (!key_set.empty()) {
TORCH_INTERNAL_ASSERT(
data_type == ScalarType::Undefined || device_opt_.has_value());
// UndefinedTensorImpl is a singleton, so we skip logging it
C10_LOG_API_USAGE_ONCE("tensor.create");
}
bool inference_mode = c10::InferenceMode::is_enabled();
// TODO: be more explicit about the full key set at call sites so we
// don't have to keep recomputing it here
DispatchKey k = key_set.highestPriorityBackendTypeId();
key_set = key_set | getAutocastRelatedKeySetFromBackend(k);
key_set =
key_set.remove(DispatchKey::Python); // See [Note: Python key removal]
// Inference tensor doesn't have autograd related keys.
if (inference_mode) {
// See Note [Expected TLS state in InferenceMode] for why we exclude
// Autograd & ADInplaceOrView keys. Normally key_set only contains backend
// keys but we do the subtraction here to make sure.
key_set_ = key_set - c10::autograd_dispatch_keyset_with_ADInplaceOrView;
} else {
// TODO: Ideally we only add AutogradBackend key when the tensor requires
// grad.
// See Note [Dream: skip VariableType kernel when requires_grad=false]
key_set_ = key_set | getAutogradRelatedKeySetFromBackend(k);
}
// Inference tensor doesn't have version counter.
if (!is_inference()) {
version_counter_ = VariableVersion(/*version=*/0);
}
// We would also like to check that non-CPU devices have an index, but some
// Caffe2 operators create Storages with default devices.
}
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
IntArrayRef TensorImpl::sizes() const {
return sizes_and_strides_.sizes_arrayref();
}
#endif
IntArrayRef TensorImpl::strides() const {
return sizes_and_strides_.strides_arrayref();
}
void TensorImpl::HandleResize() {
// If needed, we will free the data. The next mutable_data() call
// will create the data storage.
bool reset_tensor = false;
if (reserved_) {
// If the tensor is reserved, don't release its memory unless nbytes()
// is smaller than the new size.
reset_tensor =
storage_.nbytes() < (storage_offset_ + numel_) * data_type_.itemsize();
} else {
reset_tensor = storage_.nbytes() <
(storage_offset_ + numel_) * data_type_.itemsize() ||
!FLAGS_caffe2_keep_on_shrink ||
storage_.nbytes() - (storage_offset_ + numel_) * data_type_.itemsize() >
static_cast<size_t>(FLAGS_caffe2_max_keep_on_shrink_memory);
}
if (reset_tensor && storage_initialized()) {
FreeMemory();
}
}
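// A tensor is (row-major) contiguous when, walking dimensions from innermost
// to outermost and skipping size-1 dimensions, each stride equals the product
// of the sizes of the dimensions to its right.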
bool TensorImpl::compute_contiguous() const {
bool is_contiguous = true;
if (is_empty())
return is_contiguous;
int64_t z = 1;
for (int64_t d = dim() - 1; d >= 0; d--) {
const auto size_d = sizes_and_strides_.size_at_unchecked(d);
if (size_d != 1) {
if (sizes_and_strides_.stride_at_unchecked(d) == z) {
z *= size_d;
} else {
is_contiguous = false;
break;
}
}
}
return is_contiguous;
}
bool TensorImpl::compute_channels_last_contiguous_2d() const {
// Please don't combine this code; a constant array is used here to let the
// compiler fully unroll the loop for better performance.
switch (sizes_and_strides_.size()) {
case 4: {
int64_t expected = 1;
for (auto& d : {1, 3, 2, 0}) {
const auto size_d = sizes_and_strides_.size_at_unchecked(d);
if (size_d != 1) {
if (sizes_and_strides_.stride_at_unchecked(d) != expected) {
return false;
}
expected *= size_d;
}
}
return true;
}
// NOLINTNEXTLINE(bugprone-branch-clone)
case 3:
// TODO dim == 3 case will be enabled once it is fully tested
return false;
default:
return false;
}
}
bool TensorImpl::compute_channels_last_contiguous_3d() const {
// Please don't combine this code; a constant array is used here to let the
// compiler fully unroll the loop for better performance.
switch (sizes_and_strides_.size()) {
case 5: {
int64_t expected = 1;
for (auto& d : {1, 4, 3, 2, 0}) {
const auto size_d = sizes_and_strides_.size_at_unchecked(d);
if (size_d != 1) {
if (sizes_and_strides_.stride_at_unchecked(d) != expected) {
return false;
}
expected *= size_d;
}
}
return true;
}
// NOLINTNEXTLINE(bugprone-branch-clone)
case 4:
// TODO dim == 4 case will be enabled once it is fully tested
return false;
default:
return false;
}
}
bool TensorImpl::compute_strides_like_channels_last_2d() const {
return is_channels_last_strides_2d(
TensorImpl::sizes(), TensorImpl::strides());
}
bool TensorImpl::compute_strides_like_channels_last_3d() const {
return is_channels_last_strides_3d(
TensorImpl::sizes(), TensorImpl::strides());
}
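// A tensor is "non-overlapping and dense" when its elements occupy a single
// gapless, non-aliasing block of memory, i.e. its strides are a permutation of
// some contiguous layout (dimensions of size 0 or 1 are ignored, since their
// strides are arbitrary).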
bool TensorImpl::compute_non_overlapping_and_dense() const {
if (dim() == 1) {
return sizes_and_strides_.size_at_unchecked(0) < 2 ||
sizes_and_strides_.stride_at_unchecked(0) == 1;
}
SmallVector<int64_t, 5> perm;
perm.resize(dim());
for (int64_t i = 0; i < dim(); i++) {
perm[i] = i;
}
// Sort by strides, leaving 0 and 1 sized dims at the end of the array
std::sort(perm.begin(), perm.end(), [&](int64_t a, int64_t b) {
if (sizes_and_strides_.size_at_unchecked(a) < 2) {
return false;
} else if (sizes_and_strides_.size_at_unchecked(b) < 2) {
return true;
}
return sizes_and_strides_.stride_at_unchecked(a) <
sizes_and_strides_.stride_at_unchecked(b);
});
auto require_stride = 1;
for (int64_t i = 0; i < dim(); i++) {
const auto size_perm_i = sizes_and_strides_.size_at_unchecked(perm[i]);
if (size_perm_i < 2) {
return true;
}
if (sizes_and_strides_.stride_at_unchecked(perm[i]) != require_stride) {
return false;
}
require_stride *= size_perm_i;
}
return true;
}
void TensorImpl::release_resources() {
autograd_meta_.reset();
if (storage_) {
storage_ = {};
}
if (owns_pyobj_) {
TORCH_INTERNAL_ASSERT(pyobj_interpreter_ != nullptr);
TORCH_INTERNAL_ASSERT(pyobj_ != nullptr);
pyobj_interpreter_.load(std::memory_order_acquire)->decref(pyobj_);
// NB: this destructor can only be entered when there are no
// references to this C++ object (obviously), NOR any references
// to the PyObject (if there are references to the PyObject,
// then the PyObject holds an owning reference to the tensor).
// So it is OK to clear pyobj_ here as it is impossible for it to
// be used again (modulo weak reference races)
pyobj_ = nullptr; // for safety
}
}
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
int64_t TensorImpl::dim() const {
return sizes_and_strides_.size();
}
#endif
int64_t TensorImpl::size(int64_t d) const {
d = at::maybe_wrap_dim(d, dim(), false);
return sizes_and_strides_.size_at_unchecked(d);
}
int64_t TensorImpl::stride(int64_t d) const {
d = at::maybe_wrap_dim(d, dim(), false);
return sizes_and_strides_.stride_at_unchecked(d);
}
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
bool TensorImpl::has_storage() const {
return storage_;
}
#endif
void TensorImpl::throw_storage_access_error() const {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "Cannot access storage of ", tensorimpl_type_name());
}
bool TensorImpl::is_contiguous_nondefault_policy_impl(
at::MemoryFormat memory_format) const {
if (has_contiguity_ ==
static_cast<uint8_t>(HasContiguityPolicy::ContiguityNotSupported)) {
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"Tensors of type ",
tensorimpl_type_name(),
" do not have is_contiguous");
} else {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
has_contiguity_ ==
static_cast<uint8_t>(HasContiguityPolicy::CustomBehavior));
return is_contiguous_custom(memory_format);
}
}
bool TensorImpl::is_contiguous_custom(at::MemoryFormat memory_format) const {
TORCH_INTERNAL_ASSERT(
false,
"TensorImpl::is_contiguous_custom should never be called; did you "
"set_has_contiguity_policy and forget to override is_contiguous_custom?");
}
static void deletePlacementDeleteContext(void* ptr) {
delete static_cast<PlacementDeleteContext*>(ptr);
}
at::DataPtr PlacementDeleteContext::makeDataPtr(
at::DataPtr&& data_ptr,
PlacementDtor placement_dtor,
size_t size,
at::Device device) {
auto* ptr = data_ptr.get();
return {
ptr,
new PlacementDeleteContext(std::move(data_ptr), placement_dtor, size),
&deletePlacementDeleteContext,
device};
}
// NOLINTNEXTLINE(modernize-use-equals-default)
AutogradMetaInterface::~AutogradMetaInterface() {}
// Setting requires_grad to true on an inference tensor outside InferenceMode
// is forbidden. Ideally it would also be illegal inside InferenceMode.
// But there's no way to directly allocate a tensor with
// requires_grad = true in a C++ constructor, so set_requires_grad is widely
// used in the C++ frontend. Forbidding it inside InferenceMode would force
// users to delete these setter calls from their code, which is not ideal.
void TensorImpl::set_requires_grad(bool requires_grad) {
TORCH_CHECK(
!(requires_grad && is_inference() && !c10::InferenceMode::is_enabled()),
"Setting requires_grad=True on inference tensor outside InferenceMode is not allowed.");
if (!requires_grad && !autograd_meta_)
return;
if (!autograd_meta_)
autograd_meta_ = impl::GetAutogradMetaFactory()->make();
// NB: In principle, setting requires_grad to false could result in
// the AutogradMeta becoming equal to a default constructed state,
// in which case we could apply the nullptr AutogradMeta optimization
// (see autograd_meta_ docs). But we don't do this right now. Note
// that it is unsound to unconditionally reset AutogradMeta to nullptr
// when you set requires_grad to False, as there may be nontrivial
// information content in the other fields; for example, we may
// have set the string name for a Variable, or there may be hooks
// registered for it.
autograd_meta_->set_requires_grad(requires_grad, this);
}
bool TensorImpl::requires_grad() const {
if (!autograd_meta_)
return false;
return autograd_meta_->requires_grad();
}
void TensorImpl::set_autograd_meta(
std::unique_ptr<c10::AutogradMetaInterface> autograd_meta) {
// NB: autograd_meta may be null! That just means it's the default
// constructor
autograd_meta_ = std::move(autograd_meta);
}
c10::AutogradMetaInterface* TensorImpl::autograd_meta() const {
// NB: Might return null!
return autograd_meta_.get();
}
c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach(
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) const {
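// If this tensor participates in Python dispatch (and the Python key is not
// excluded in TLS), route the detach through the owning interpreter so that
// __torch_dispatch__ can observe it, instead of copying the TensorImpl here.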
if (key_set_.has(DispatchKey::Python) &&
!c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
auto r = pyobj_interpreter_.load(std::memory_order_acquire)->detach(this);
if (r) {
r->set_version_counter(version_counter);
r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
return r;
}
// otherwise just copy the TensorImpl and not the PyObject. Since
// the interpreter is dead no one can call us out on it
}
auto impl = c10::make_intrusive<TensorImpl>(
// No need to populate Storage; copy_tensor_metadata will do it for us.
key_set_,
data_type_,
device_opt_);
copy_tensor_metadata(
/*src_impl=*/this,
/*dest_impl=*/impl.get(),
/*version_counter=*/version_counter,
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
impl->refresh_numel();
impl->refresh_contiguous();
return impl;
}
c10::intrusive_ptr<TensorImpl> TensorImpl::shallow_copy_and_detach(
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const {
if (key_set_.has(DispatchKey::Python) &&
!c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
auto r = pyobj_interpreter_.load(std::memory_order_acquire)->detach(this);
if (r) {
r->set_version_counter(std::move(version_counter));
r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
return r;
}
// otherwise just copy the TensorImpl and not the PyObject. Since
// the interpreter is dead no one can call us out on it
}
auto impl = c10::make_intrusive<TensorImpl>(
// No need to populate Storage; copy_tensor_metadata will do it for us.
key_set_,
data_type_,
device_opt_);
copy_tensor_metadata(
/*src_impl=*/this,
/*dest_impl=*/impl.get(),
/*version_counter=*/std::move(version_counter),
/*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
impl->refresh_numel();
impl->refresh_contiguous();
return impl;
}
void TensorImpl::copy_tensor_metadata_except_version_counter(
const TensorImpl* src_impl,
TensorImpl* dest_impl,
bool allow_tensor_metadata_change) {
dest_impl->storage_ = src_impl->storage_;
dest_impl->sizes_and_strides_ = src_impl->sizes_and_strides_;
dest_impl->storage_offset_ = src_impl->storage_offset_;
dest_impl->data_type_ = src_impl->data_type_;
dest_impl->device_opt_ = src_impl->device_opt_;
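// See [Note: Python key removal]: the destination does not yet have a PyObject
// with __torch_dispatch__, so it must not carry the Python key.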
dest_impl->key_set_ = src_impl->key_set_.remove(DispatchKey::Python);
dest_impl->is_contiguous_ = src_impl->is_contiguous_;
dest_impl->has_contiguity_ = src_impl->has_contiguity_;
dest_impl->is_channels_last_contiguous_ =
src_impl->is_channels_last_contiguous_;
dest_impl->is_channels_last_3d_contiguous_ =
src_impl->is_channels_last_3d_contiguous_;
dest_impl->is_channels_last_ = src_impl->is_channels_last_;
dest_impl->is_channels_last_3d_ = src_impl->is_channels_last_3d_;
dest_impl->is_non_overlapping_and_dense_ =
src_impl->is_non_overlapping_and_dense_;
dest_impl->is_wrapped_number_ = src_impl->is_wrapped_number_;
dest_impl->reserved_ = src_impl->reserved_;
dest_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
dest_impl->storage_access_should_throw_ =
src_impl->storage_access_should_throw_;
if (src_impl->named_tensor_meta_ != nullptr) {
dest_impl->named_tensor_meta_ = src_impl->named_tensor_meta_->clone();
}
}
void TensorImpl::copy_tensor_metadata(
const TensorImpl* src_impl,
TensorImpl* dest_impl,
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) {
copy_tensor_metadata_except_version_counter(
src_impl, dest_impl, allow_tensor_metadata_change);
// TODO: In the ideal end state, it's okay to set disabled version_counter
// on inference tensor since it's a no-op. This requires refactor on call
// sites.
if (!dest_impl->is_inference()) {
dest_impl->set_version_counter(version_counter);
}
}
void TensorImpl::copy_tensor_metadata(
const TensorImpl* src_impl,
TensorImpl* dest_impl,
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) {
copy_tensor_metadata_except_version_counter(
src_impl, dest_impl, allow_tensor_metadata_change);
if (!dest_impl->is_inference()) {
dest_impl->set_version_counter(std::move(version_counter));
}
}
namespace impl {
namespace {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
AutogradMetaFactory* meta_factory = nullptr;
} // namespace
void SetAutogradMetaFactory(AutogradMetaFactory* factory) {
meta_factory = factory;
}
AutogradMetaFactory* GetAutogradMetaFactory() {
TORCH_CHECK(
meta_factory,
"Support for autograd has not been loaded; have you linked against libtorch.so?")
return meta_factory;
}
} // namespace impl
} // namespace c10