diff --git a/tools/autograd/gen_variable_factories.py b/tools/autograd/gen_variable_factories.py index b193fcc183a..fea70a334c9 100644 --- a/tools/autograd/gen_variable_factories.py +++ b/tools/autograd/gen_variable_factories.py @@ -13,7 +13,7 @@ inline at::Tensor ${name}(${formals}) { ${pre_record_trace} at::Tensor tensor = at::${name}(${actuals}); at::Tensor result = - autograd::make_variable(tensor, /*requires_grad=*/${requires_grad}); + autograd::make_variable_consuming(std::move(tensor), /*requires_grad=*/${requires_grad}); ${post_record_trace} return result; } diff --git a/torch/csrc/autograd/variable.h b/torch/csrc/autograd/variable.h index 7eea2b04246..e726f9cd9d5 100644 --- a/torch/csrc/autograd/variable.h +++ b/torch/csrc/autograd/variable.h @@ -104,7 +104,8 @@ struct TORCH_API Variable : public at::Tensor { bool allow_tensor_metadata_change, Edge gradient_edge); - /// Creates a `Variable` from the given `Tensor`. `requires_grad` should be + /// Creates a `Variable` from the given `Tensor`, copying its underlying `TensorImpl`. + /// `requires_grad` should be /// set only for leaves, and determines whether the `Variable` will accumulate /// gradients. NOTE: `data` must *not* be a `Variable` already. Its dynamic /// type *must* be `Tensor`. @@ -113,6 +114,16 @@ struct TORCH_API Variable : public at::Tensor { bool requires_grad, bool allow_tensor_metadata_change); + /// Creates a `Variable` from the given `Tensor`, consuming its underlying `TensorImpl`. + /// This is intended to be used from functions that immediately create a `Tensor`, + /// convert it to a `Variable`, and then free it; it has been found to + /// decrease the overhead of those operations, in some situations. + /// The comments about `requires_grad` and `data` on the above version also apply to this one. 
+ friend Variable make_variable_consuming( + at::Tensor data, + bool requires_grad, + bool allow_tensor_metadata_change); + /// Creates a `Variable` from the given `Tensor` and specify a /// `gradient_edge`, i.e. a (function, input_nr) pair specifying the function /// in the autograd graph, and what particular input of that function, this @@ -576,6 +587,22 @@ inline Variable make_variable( return Variable(); } +inline Variable make_variable_consuming( + at::Tensor data, + bool requires_grad = false, + bool allow_tensor_metadata_change = true) { + AT_CHECK( + !data.is_variable(), + "Must not create a new variable from a variable, use its .data()"); + if (data.defined()) { + AT_ASSERT(data.getIntrusivePtr().use_count() == 1); + data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(allow_tensor_metadata_change); + auto autograd_meta = c10::guts::make_unique<Variable::AutogradMeta>(); + return Variable(c10::make_intrusive<Variable::Impl>(std::move(data), std::move(autograd_meta), requires_grad)); + } + return Variable(); +} + inline Variable make_variable( at::Tensor data, Edge gradient_edge,