pytorch/torch/csrc/autograd/custom_function.cpp
mal 3fa2df7c9a Support custom autograd functions in C++ (#23572)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23572

### **(The stack from #23020  was moved into this PR)**

Adding API for custom autograd operations, with user defined forward and backward, [like in python](https://pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd).

The custom operation should be a subclass of Function, with static forward and backward functions. `forward()` can accept any arguments similar to the Python API and `backward()` should accept a variable list as an argument.

Both `forward()` and `backward()` accept an `AutogradContext*`, which can be used to share data between them.
Variables can be saved in the context using `save_for_backward()`, and other data can be saved in the map `saved_data` in the form of `<std::string, at::IValue>` pairs. Variables saved in forward can be accessed with `get_saved_variables()`.

Example usage:
```
class MyFunction : public Function<MyFunction> {
  public:
  static variable_list forward(AutogradContext *ctx, int n, Variable var) {
     // Save data for backward in context
     ctx->saved_data["n"] = n;
     return {var};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
     // Use data saved in forward
     auto n = ctx->saved_data["n"].toInt();
     return {grad_output[0]*n};
  }
};

```
Then, it can be used with:
```
Variable x;
MyFunction::apply(6, x);
```

Also, `AutogradContext` has methods to mark outputs as non-differentiable and to mark inputs as dirty, similar to the [Python API](https://github.com/pytorch/pytorch/blob/ff23a02ac4/torch/autograd/function.py#L26).

Test Plan: Added tests for the custom autograd function API based on test_autograd.py. Currently only the tests for the basic functionality have been added. More tests will be added later.

Differential Revision: D16583428

fbshipit-source-id: 0bd42f19ce37bcd99d3080d16195ad74d40d0413
2019-07-31 11:30:48 -07:00

156 lines
5.2 KiB
C++

#include <torch/csrc/autograd/custom_function.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
namespace torch { namespace autograd {
// Records a snapshot of `var`'s metadata: its backend (derived from the
// tensor type id), device, scalar type, sizes (via `sizes().vec()`), and
// requires_grad flag. Only these properties are read from `var`; the
// Variable itself is not stored in any of the fields initialized here.
VariableInfo::VariableInfo(const Variable& var)
: backend(tensorTypeIdToBackend(var.type_id()))
, device(var.device())
, scalar_type(var.scalar_type())
, size(var.sizes().vec())
, requires_grad(var.requires_grad()) {
}
// Builds a zero-filled Variable matching the recorded size, scalar type,
// backend-derived device type and backend-derived layout.
//
// `device_guard` is switched to the recorded device before the tensor is
// created; the guard is owned by the caller and is left pointing at that
// device when this function returns.
Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
  // NB: a single guard is reset here, so this does NOT cope with gradients
  // that are spread across mixed devices.
  device_guard.reset_device(device);

  const auto zero_options = at::TensorOptions(scalar_type)
                                .device(backendToDeviceType(backend))
                                .layout(layout_from_backend(backend))
                                .is_variable(true);
  return at::zeros(size, zero_options);
}
// Wraps the raw outputs of a custom function's forward pass and wires them
// into the autograd graph.
//
// Parameters:
//   input_vars         - Variables that were passed to forward.
//   non_differentiable - TensorImpls of outputs flagged as non-differentiable.
//   dirty_inputs       - TensorImpls of inputs flagged as modified in-place.
//   raw_outputs        - Variables returned by forward.
//   cdata              - Node that becomes the grad_fn of differentiable
//                        outputs. May be null; every output is then treated
//                        as non-differentiable.
//
// Returns one Variable per entry of raw_outputs, in the same order.
//
// Throws std::runtime_error when a non-differentiable output that requires
// grad is a view, or when a leaf Variable that requires grad is marked as
// modified in-place. Internal invariants are checked with AT_ASSERT.
variable_list _wrap_outputs(const variable_list &input_vars,
  const std::unordered_set<at::TensorImpl*> &non_differentiable,
  const std::unordered_set<at::TensorImpl*> &dirty_inputs,
  const at::ArrayRef<Variable> raw_outputs,
  const std::shared_ptr<Node> &cdata) {

  // Identity set of the inputs, used to detect outputs that are inputs.
  std::unordered_set<at::TensorImpl*> inputs;
  inputs.reserve(input_vars.size());
  for (auto& var : input_vars) {
    inputs.emplace(var.unsafeGetTensorImpl());
  }

  // Sets the grad_fn and output_nr of an output Variable.
  // Note: is_differentiable is only ever true when cdata is non-null (see the
  // loop below), so every branch other than the first runs with a valid cdata.
  auto set_history = [&](Variable& var, uint32_t output_nr, bool is_input, bool is_modified,
                         bool is_differentiable) {
    if (!is_differentiable) {
      if (!var.requires_grad()) {
        return;
      }
      // NB: we don't support returning non-differentiable views that could require grad
      if (var.is_view()) {
        throw std::runtime_error("Returning Variables sharing storage with other Variables "
                                 "that require grad is not supported in Python functions. "
                                 "Please submit a feature request if you hit this error.");
      }
      // Return detached aliases of inputs, instead of changing their requires_grad
      // property.
      if (is_input) {
        var = var.detach();
      } else {
        var.detach_();
      }
    } else if (is_modified) {
      if (var.is_leaf() && var.requires_grad()) {
        throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation.");
      }
      // If the input was modified, transplant the grad_fn in the graph:
      // grad_fn <- variable <- self  ==>  grad_fn <- self <- variable
      var.grad().reset();
      var.clear_hooks();
      if (auto grad_acc_fn = var.try_get_grad_accumulator()) {
        auto grad_acc = dynamic_cast<AccumulateGrad*>(grad_acc_fn.get());
        // dynamic_cast yields nullptr on a type mismatch; fail loudly instead
        // of dereferencing a null pointer.
        AT_ASSERT(grad_acc != nullptr);
        grad_acc->variable.reset();
      }
      if (cdata) {
        var.rebase_history({cdata, output_nr});
      }
    } else if (is_input) {
      // An input has been returned, but it wasn't modified. Return it as a view
      // so that we can attach a new grad_fn to the Variable.
      var = var.view_as(var);
      var.set_gradient_edge({cdata, output_nr});
    } else if (cdata) {
      var.set_gradient_edge({cdata, output_nr});
    }
  };

  const size_t num_outputs = raw_outputs.size();
  std::vector<torch::autograd::Variable> outputs;
  outputs.reserve(num_outputs);
  for (size_t i = 0; i < num_outputs; ++i) {
    auto out_tensor_impl = raw_outputs[i].unsafeGetTensorImpl();
    bool is_input = inputs.count(out_tensor_impl) > 0;
    bool is_modified = dirty_inputs.count(out_tensor_impl) > 0;
    bool is_differentiable = cdata && non_differentiable.count(out_tensor_impl) == 0;

    Variable var = raw_outputs[i];
    if (cdata) {
      // Metadata is registered in output order, so the slot handed back by
      // the node must equal the position of this output.
      auto output_nr = cdata->add_input_metadata(var);
      AT_ASSERT(i == output_nr);
    }
    set_history(var, static_cast<uint32_t>(i), is_input, is_modified, is_differentiable);
    outputs.emplace_back(var);
  }
  return outputs;
}
// Replaces the pending list of Variables to save with `to_save`. Whatever was
// pending from an earlier call is dropped; the list is consumed later by
// save_variables().
void AutogradContext::save_for_backward(variable_list to_save) {
  // `to_save` is a by-value parameter, so exchanging contents hands the new
  // list to the member and lets the previous list die with the parameter.
  to_save_.swap(to_save);
}
// Saved-variable handling mirrors python_function.cpp; see
// _save_variables() and unpack_saved_variables() there.
//
// Rebuilds saved_variables_ from the pending to_save_ list and then empties
// to_save_. A Variable is recorded as an output when its grad_fn is the node
// that grad_fn_ refers to.
void AutogradContext::save_variables() {
  const auto owner = grad_fn_.lock();
  const auto* const owner_node = owner.get();

  saved_variables_.clear();
  for (const auto& pending : to_save_) {
    saved_variables_.emplace_back(pending, pending.grad_fn().get() == owner_node);
  }
  to_save_.clear();
}
// Returns the Variables held in saved_variables_, each unpacked against the
// node referenced by grad_fn_, in the order they were stored.
//
// Fails the TORCH_CHECK with ERR_BACKWARD_TWICE when has_freed_buffers_ is
// set, and trips an internal assert when grad_fn_ can no longer be locked.
variable_list AutogradContext::get_saved_variables() const {
  TORCH_CHECK(!has_freed_buffers_, ERR_BACKWARD_TWICE);

  const auto owner = grad_fn_.lock();
  TORCH_INTERNAL_ASSERT(owner);

  variable_list unpacked;
  unpacked.reserve(saved_variables_.size());
  for (const auto& entry : saved_variables_) {
    unpacked.emplace_back(entry.unpack(owner));
  }
  return unpacked;
}
// Records the TensorImpl of every Variable in `inputs` as dirty. Each call
// starts from an empty set, so it replaces rather than extends whatever an
// earlier call recorded.
void AutogradContext::mark_dirty(const variable_list &inputs) {
  dirty_inputs_.clear();
  dirty_inputs_.reserve(inputs.size());
  for (const auto& input : inputs) {
    auto* const impl = input.unsafeGetTensorImpl();
    dirty_inputs_.insert(impl);
  }
}
// Records the TensorImpl of every Variable in `outputs` as
// non-differentiable. Each call starts from an empty set, so it replaces
// rather than extends whatever an earlier call recorded.
void AutogradContext::mark_non_differentiable(const variable_list &outputs) {
  non_differentiable_.clear();
  non_differentiable_.reserve(outputs.size());
  for (const auto& output : outputs) {
    auto* const impl = output.unsafeGetTensorImpl();
    non_differentiable_.insert(impl);
  }
}
// Read-only access to the set of TensorImpl pointers populated by
// mark_dirty(). Returns a reference to the member itself, not a copy.
const std::unordered_set<at::TensorImpl*>& AutogradContext::get_dirty() const {
return dirty_inputs_;
}
// Read-only access to the set of TensorImpl pointers populated by
// mark_non_differentiable(). Returns a reference to the member itself, not a
// copy.
const std::unordered_set<at::TensorImpl*>& AutogradContext::get_non_differentiable() const {
return non_differentiable_;
}
}} // namespace torch::autograd