mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Replaces `view_func()` closures with a reified `ViewFunc` data structure. Codegen generates a `ViewFunc` subclass for each view op (e.g. `NarrowViewFunc`) containing state needed to reconstruct the view. The `ViewFunc` API allows for querying and hot-swapping any `SymInt`s or `Tensors` in the state through `get_symints()` / `get_tensors()` / `clone_and_set()`, which will be essential for fake-ification later on.
```cpp
/// Base class for view functions, providing reapplication of a view on a new base.
/// Each view op should get a codegenerated subclass of this class containing
/// any state needed to reconstruct the view. The class also provides convenience
/// accessors for saved SymInts / tensor state. This is useful for e.g. fake-ification,
/// where we want to use symbolic values or fake tensors instead.
struct TORCH_API ViewFunc {
virtual ~ViewFunc() {}
/// Returns any SymInts in the saved state.
virtual std::vector<c10::SymInt> get_symints() const { return {}; }
/// Returns the number of SymInts in the saved state.
virtual size_t num_symints() const { return 0; }
/// Returns any tensors in the saved state.
virtual std::vector<at::Tensor> get_tensors() const { return {}; }
/// Returns the number of tensors in the saved state.
virtual size_t num_tensors() const { return 0; }
/// Reapplies the view on the given base using the saved state.
virtual at::Tensor operator()(const at::Tensor&) const = 0;
/// Returns a clone of this ViewFunc, optionally with the specified saved state.
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const = 0;
protected:
/// Sets the values of any SymInts in the saved state. The input vector size must
/// match the number of SymInts in the saved state (i.e. the size of the list
/// returned by get_symints()).
virtual void set_symints(std::vector<c10::SymInt>) {}
/// Sets the values of any Tensors in the saved state. The input vector size must
/// match the number of Tensors in the saved state (i.e. the size of the list
/// returned by get_tensors()).
virtual void set_tensors(std::vector<at::Tensor>) {}
};
```
New codegen files:
* `torch/csrc/autograd/generated/ViewFunc.h`
* `torch/csrc/autograd/generated/ViewFuncs.cpp`
The templates for these also contain implementations for `ChainedViewFunc` and `ErroringViewFunc`, which are used in a few places within autograd.
Example codegen for `slice.Tensor`:
```cpp
// torch/csrc/autograd/generated/ViewFuncs.h
#define SLICE_TENSOR_VIEW_FUNC_AVAILABLE
struct SliceTensorViewFunc : public torch::autograd::ViewFunc {
SliceTensorViewFunc(int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step) : dim(dim), start(start), end(end), step(step)
{};
virtual ~SliceTensorViewFunc() override {};
virtual std::vector<c10::SymInt> get_symints() const override;
virtual size_t num_symints() const override;
virtual std::vector<at::Tensor> get_tensors() const override;
virtual size_t num_tensors() const override;
virtual at::Tensor operator()(const at::Tensor&) const override;
virtual std::unique_ptr<ViewFunc> clone_and_set(
std::optional<std::vector<c10::SymInt>> = c10::nullopt,
std::optional<std::vector<at::Tensor>> = c10::nullopt) const override;
protected:
virtual void set_symints(std::vector<c10::SymInt>) override;
virtual void set_tensors(std::vector<at::Tensor>) override;
private:
int64_t dim;
c10::optional<c10::SymInt> start;
c10::optional<c10::SymInt> end;
c10::SymInt step;
};
...
// torch/csrc/autograd/generated/ViewFuncs.cpp
std::vector<c10::SymInt> SliceTensorViewFunc::get_symints() const {
::std::vector<c10::SymInt> symints;
symints.reserve((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
if(start.has_value()) symints.insert(symints.end(), *(start));
if(end.has_value()) symints.insert(symints.end(), *(end));
symints.push_back(step);
return symints;
}
size_t SliceTensorViewFunc::num_symints() const {
return static_cast<size_t>((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
}
void SliceTensorViewFunc::set_symints(std::vector<c10::SymInt> symints) {
TORCH_INTERNAL_ASSERT(symints.size() == num_symints());
auto i = 0;
if(start.has_value()) start = symints[i];
i += (start.has_value() ? 1 : 0);
if(end.has_value()) end = symints[i];
i += (end.has_value() ? 1 : 0);
step = symints[i];
}
std::vector<at::Tensor> SliceTensorViewFunc::get_tensors() const {
::std::vector<at::Tensor> tensors;
return tensors;
}
size_t SliceTensorViewFunc::num_tensors() const {
return static_cast<size_t>(0);
}
void SliceTensorViewFunc::set_tensors(std::vector<at::Tensor> tensors) {
TORCH_INTERNAL_ASSERT(tensors.size() == num_tensors());
}
at::Tensor SliceTensorViewFunc::operator()(const at::Tensor& input_base) const {
return at::_ops::slice_Tensor::call(input_base, dim, start, end, step);
}
std::unique_ptr<ViewFunc> SliceTensorViewFunc::clone_and_set(
std::optional<std::vector<c10::SymInt>> symints,
std::optional<std::vector<at::Tensor>> tensors) const {
auto output = std::make_unique<SliceTensorViewFunc>(dim, start, end, step);
if (symints.has_value()) {
output->set_symints(std::move(*(symints)));
}
if (tensors.has_value()) {
output->set_tensors(std::move(*(tensors)));
}
return output;
}
```
The `_view_func()` / `_view_func_unsafe()` methods now accept two additional (optional) args for `symint_visitor_fn` / `tensor_visitor_fn`. If these are defined, they are expected to be python callables that operate on a single SymInt / tensor and return a new one. This allows for the hot-swapping needed during fake-ification.
For testing, there are extensive pre-existing tests, and I added a test to ensure that hot-swapping functions correctly.
```sh
python test/test_autograd.py -k test_view_func_replay
python test/test_ops.py -k test_view_replay
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118404
Approved by: https://github.com/ezyang
218 lines
7.2 KiB
C++
218 lines
7.2 KiB
C++
#include <torch/csrc/autograd/functions/tensor.h>
|
|
|
|
#include <torch/csrc/autograd/function.h>
|
|
#include <torch/csrc/autograd/functions/basic_ops.h>
|
|
#include <torch/csrc/autograd/functions/utils.h>
|
|
#include <torch/csrc/autograd/graph_task.h>
|
|
#include <torch/csrc/autograd/variable.h>
|
|
#include <torch/csrc/dynamo/compiled_autograd.h>
|
|
|
|
#include <ATen/ATen.h>
|
|
#include <c10/util/irange.h>
|
|
|
|
#include <cstddef>
|
|
#include <memory>
|
|
#include <stdexcept>
|
|
#include <utility>
|
|
|
|
namespace torch {
|
|
namespace autograd {
|
|
|
|
auto CopyBackwards::apply(variable_list&& grads) -> variable_list {
|
|
check_input_variables("CopyBackwards", grads, 1, -1, true);
|
|
auto grad = c10::MaybeOwned<at::Tensor>::borrowed(grads[0]);
|
|
variable_list grad_inputs(2);
|
|
if (grad->defined()) {
|
|
if (task_should_compute_output(0)) {
|
|
grad_inputs[0] = at::zeros_like(*grad, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
|
}
|
|
if (task_should_compute_output(1)) {
|
|
// Handle R->C copies without raising a warning
|
|
const auto src_type = src_options.dtype().toScalarType();
|
|
if (!c10::isComplexType(src_type) && grad->is_complex()) {
|
|
grad = c10::MaybeOwned<at::Tensor>::owned(at::real(grads[0]));
|
|
}
|
|
|
|
at::DeviceGuard device_guard(src_options.device());
|
|
grad_inputs[1] = grad->to(src_options);
|
|
}
|
|
}
|
|
return grad_inputs;
|
|
}
|
|
|
|
// Collects this node's saved state (only the source TensorOptions) into
// `args` for compiled autograd.
void CopyBackwards::compiled_args(CompiledNodeArgs& args) {
  args.collect(src_options);
}
|
|
// Runs the normal backward with saved state temporarily swapped by `saved`:
// swap in before apply(), swap back out after.
variable_list CopyBackwards::apply_with_saved(
    const variable_list& inputs,
    SwapSavedVariables& saved) {
  saved.before(src_options);
  auto outputs = apply(variable_list(inputs));
  saved.after(src_options);
  return outputs;
}
|
|
|
|
// Stores the base variable, the geometry of the view that was written into,
// an optional ViewFunc able to replay the view on a fresh base, and the
// wrapped node `fn_` whose outgoing edges this node adopts.
CopySlices::CopySlices(
    const Variable& base_var,
    at::TensorGeometry view_,
    std::unique_ptr<ViewFunc> view_fn_,
    std::shared_ptr<Node> fn_)
    : Node(),
      base(base_var),
      view(std::move(view_)),
      view_fn(std::move(view_fn_)),
      fn(std::move(fn_)) {
  add_input_metadata(base_var);
  // Adopt fn's next_edges as our own, with one exception: slot 0 is rewired
  // to flow into `base` instead of into the view.
  const auto n_edges = fn->num_outputs();
  next_edges_.reserve(n_edges);
  add_next_edge(impl::gradient_edge(base_var));
  for (const auto idx : c10::irange(1, n_edges)) {
    add_next_edge(fn->next_edge(idx));
  }
}
|
|
|
|
// common code between apply/apply_with_saved
|
|
template <typename T>
|
|
inline variable_list CopySlices::apply_impl(
|
|
variable_list&& inputs,
|
|
const T& call_fn) {
|
|
check_input_variables("CopySlices", inputs, 1, -1, true);
|
|
auto& grad = inputs[0];
|
|
if (!grad.defined()) {
|
|
return variable_list(num_outputs());
|
|
}
|
|
|
|
// Acquire lock to here protect thread safety on fn
|
|
// see Note [Thread Safety on Autograd Node]
|
|
std::lock_guard<std::mutex> lock(mutex_);
|
|
|
|
if (!fn) {
|
|
throw std::runtime_error(ERR_BACKWARD_TWICE);
|
|
}
|
|
|
|
auto result =
|
|
grad.new_empty_strided_symint(base.sym_sizes(), base.sym_strides());
|
|
result.copy_(grad);
|
|
|
|
at::Tensor grad_slice;
|
|
if (view_fn) {
|
|
grad_slice = (*view_fn)(result);
|
|
} else {
|
|
auto offset = view.sym_storage_offset() - base.sym_storage_offset();
|
|
grad_slice =
|
|
result.as_strided_symint(view.sym_sizes(), view.sym_strides(), offset);
|
|
}
|
|
|
|
// See Note [View + Inplace update for view tensor] For more details on this
|
|
// block Since the gradient edge for the 0th input is different between `this`
|
|
// and `fn`, make sure that the one from `fn` has the same metadata in the
|
|
// current GraphTask's exec_info as the one on `this`.
|
|
const auto exec_info = get_current_graph_task_exec_info();
|
|
if (exec_info && !exec_info->empty()) {
|
|
const auto& fn_edge = fn->next_edge(0);
|
|
const auto& this_edge = this->next_edge(0);
|
|
TORCH_INTERNAL_ASSERT(fn_edge.is_valid() == this_edge.is_valid());
|
|
if (fn_edge.is_valid()) {
|
|
const auto fn_next_node = fn_edge.function.get();
|
|
auto it = exec_info->find(fn_next_node);
|
|
if (it == exec_info->end()) {
|
|
// Node is not in the exec_info already
|
|
if (task_should_compute_output(0)) {
|
|
// And we need gradient for the corresponding output
|
|
add_node_to_current_graph_task_exec_info(fn_next_node);
|
|
// There is no need to remove this after execution because we are
|
|
// guaranteed that this->next_edge(0) must be in the history of
|
|
// fn->next_edge(0) (we cannot easily assert this as it might be far
|
|
// away if there were many chained views). This means that, since
|
|
// fn->next_edge(0) was not needed (no exec_info entry for it), we
|
|
// know that nothing downstream of fn->next_edge(0) is needed either
|
|
// (otherwise the whole path from that Node to this->next_edge(0)
|
|
// would be needed as well). This means that no other Node will ever
|
|
// look at fn->next_edge(0) metadata and thus there is no need to
|
|
// clean them up.
|
|
}
|
|
} else {
|
|
TORCH_INTERNAL_ASSERT(
|
|
it->second.should_execute() == task_should_compute_output(0));
|
|
}
|
|
}
|
|
}
|
|
|
|
// Sanity check that the graph was never modified after the fact (it is
|
|
// read-only!)
|
|
TORCH_INTERNAL_ASSERT(num_outputs() == fn->num_outputs());
|
|
for (const auto i : c10::irange(1, this->num_outputs())) {
|
|
TORCH_INTERNAL_ASSERT(
|
|
fn->next_edge(i).function.get() == this->next_edge(i).function.get());
|
|
}
|
|
|
|
// TODO: We clone grad_slice because we modify it below and "fn" might save
|
|
// it for the backward of res. We might be able to avoid the clone() if
|
|
// double-backprop is disabled.
|
|
auto res = call_fn({grad_slice.clone(at::MemoryFormat::Contiguous)});
|
|
|
|
variable_list grad_inputs(num_outputs());
|
|
for (const auto i : c10::irange(res.size())) {
|
|
if (task_should_compute_output(i)) {
|
|
if (!res[i].defined()) {
|
|
// If the output is not defined, treat it as if it was a zero tensor.
|
|
// This can happen if users define a custom Function.
|
|
continue;
|
|
}
|
|
if (i == 0) {
|
|
grad_slice.copy_(res[i]);
|
|
// NOLINTNEXTLINE(clang-analyzer-cplusplus.Move)
|
|
grad_inputs[i] = std::move(result); // NOLINT(bugprone-use-after-move)
|
|
} else {
|
|
grad_inputs[i] = std::move(res[i]);
|
|
}
|
|
}
|
|
}
|
|
|
|
return grad_inputs;
|
|
}
|
|
|
|
// Drops the wrapped node. A subsequent apply() will then throw
// ERR_BACKWARD_TWICE (see the !fn check in apply_impl).
void CopySlices::release_variables() {
  // Acquire lock here to protect thread safety on fn
  std::lock_guard<std::mutex> lock(mutex_);
  fn = nullptr;
}
|
|
|
|
// Collects this node's saved state for compiled autograd, then recurses into
// the wrapped node so its state is collected too.
void CopySlices::compiled_args(CompiledNodeArgs& args) {
  // Reified view functions are not traceable by compiled autograd (yet).
  TORCH_CHECK(!view_fn, "view_fn not supported by compiled autograd")
  TORCH_INTERNAL_ASSERT((bool)fn);
  args.collect(base);
  args.collect(view);
  args.collect(fn);
  fn->compiled_args(args);
}
|
|
|
|
// Compiled-autograd entry point: swaps saved state in around apply_impl and
// routes the inner call through fn->apply_with_saved. The wrapped node must
// be invoked exactly once per backward.
variable_list CopySlices::apply_with_saved(
    const variable_list& grads,
    SwapSavedVariables& saved) {
  saved.before(base);
  saved.before(view);
  int num_fn_calls = 0;
  auto result = apply_impl(
      variable_list(grads),
      [this, &saved, &num_fn_calls](const variable_list& fn_inputs) {
        ++num_fn_calls;
        return fn->apply_with_saved(fn_inputs, saved);
      });
  TORCH_INTERNAL_ASSERT(num_fn_calls == 1);
  saved.after(base);
  saved.after(view);
  return result;
}
|
|
|
|
auto CopySlices::apply(variable_list&& inputs1) -> variable_list {
|
|
return apply_impl(std::move(inputs1), [this](variable_list&& inputs2) {
|
|
return (*fn)(std::move(inputs2));
|
|
});
|
|
}
|
|
|
|
} // namespace autograd
|
|
} // namespace torch
|