[functional autograd] Refactor validate_outputs into a functional variant (#141348)

Today, validate_outputs is stateful (it depends on the autograd graph).
This PR refactors it into a stateless form that just depends on
InputMetadata.
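
For illustration, a minimal sketch of calling the new functional overload
(it mirrors the unit tests added below; `input` and `grad` are assumed to be
existing tensors):

    std::vector<c10::optional<InputMetadata>> input_metadata;
    input_metadata.emplace_back(InputMetadata(input));
    std::vector<torch::Tensor> grads{grad};
    // validate_outputs may rewrite grads in place, e.g. sum-reducing a grad
    // to the shape recorded in its InputMetadata.
    torch::autograd::validate_outputs(
        input_metadata, grads, [](const std::string& msg) { return msg; });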

Test Plan:
- new unittest
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141348
Approved by: https://github.com/soulitzer
ghstack dependencies: #141278
Author: rzou
Date: 2024-12-02 13:28:53 -08:00
Committed by: PyTorch MergeBot
Parent: 2b4f1f4990
Commit: 215f5d77b5
3 changed files with 103 additions and 8 deletions

diff --git a/test/cpp/api/autograd.cpp b/test/cpp/api/autograd.cpp

@@ -5,6 +5,7 @@
 #include <torch/torch.h>
 
 #include <torch/csrc/autograd/FunctionsManual.h>
+#include <torch/csrc/autograd/engine.h>
 #include <torch/csrc/autograd/functions/basic_ops.h>
 
 #include <test/cpp/api/support.h>
@@ -1668,6 +1669,36 @@ TEST(TestAutogradNotImplementedFallback, TensorlistOp) {
   ASSERT_TRUE(at::allclose(op(a, vec), tensorlist_op(a, vec)));
 }
 
+static std::string test_format_error(const std::string& s) {
+  return s;
+}
+
+TEST(TestAutogradUtils, ValidateOutputsReduce) {
+  auto input = torch::ones({}, {torch::kFloat32});
+  auto grad = torch::ones({2, 3}, {torch::kFloat32});
+  std::vector<c10::optional<InputMetadata>> input_metadata;
+  input_metadata.emplace_back(InputMetadata(input));
+  std::vector<torch::Tensor> grads;
+  grads.emplace_back(grad);
+  torch::autograd::validate_outputs(input_metadata, grads, test_format_error);
+  ASSERT_TRUE(at::allclose(grads[0], grad.sum()));
+}
+
+TEST(TestAutogradUtils, ValidateOutputsBasic) {
+  auto input = torch::zeros({2, 3}, {torch::kFloat32});
+  auto grad = torch::ones({2, 3}, {torch::kFloat32});
+  std::vector<c10::optional<InputMetadata>> input_metadata;
+  input_metadata.emplace_back(InputMetadata(input));
+  std::vector<torch::Tensor> grads;
+  grads.emplace_back(grad);
+  torch::autograd::validate_outputs(input_metadata, grads, test_format_error);
+  ASSERT_TRUE(at::allclose(grad, torch::ones({2, 3})));
+}
+
+// TODO: add these tests if needed
+// test_once_differentiable
+// test_sparse_backward
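
For context (not part of the patch): ValidateOutputsReduce exercises the
sum-to reduction path, which behaves roughly like this standalone sketch:

    // Hypothetical equivalent of the reduction performed in the test above:
    auto grad = torch::ones({2, 3});
    auto reduced = at::sum_to(grad, c10::IntArrayRef{});  // scalar target shape
    // reduced is now a 0-dim tensor equal to grad.sum()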

diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp

@@ -855,22 +855,68 @@ void set_device(int device) {
   worker_device = device;
 }
 
-void validate_outputs(
-    const edge_list& edges,
+// validate_outputs has two overloads: one that accepts edge_list and one that
+// accepts vector<optional<InputMetadata>>. The former is stateful (it requires
+// the autograd graph) and the latter is for functional autograd (where we want
+// to be able to take an autograd graph and then construct an FX graph out of
+// it without specializing on the properties of the gradients).
+//
+// We do some templating to avoid dynamic allocations in the hot path (the
+// eager autograd case). Otherwise, the problem is that we are given a
+// vector<Edge> and would need to materialize a vector<optional<InputMetadata>>
+// (or some other vector) to pass to a common helper function. The alternative
+// is to use C++20's ranges, which we don't have access to yet.
+
+// Given an Edge or optional<InputMetadata>, return the InputMetadata.
+template <typename T>
+const InputMetadata& get_input_metadata(const T& thing);
+
+template <>
+const InputMetadata& get_input_metadata<c10::optional<InputMetadata>>(
+    const c10::optional<InputMetadata>& thing) {
+  return thing.value();
+}
+
+template <>
+const InputMetadata& get_input_metadata<Edge>(const Edge& thing) {
+  return thing.function->input_metadata(thing.input_nr);
+}
+
+// Given an Edge or optional<InputMetadata>, return whether there is an
+// InputMetadata.
+template <typename T>
+bool has_input_metadata(const T& thing);
+
+template <>
+bool has_input_metadata<c10::optional<InputMetadata>>(
+    const c10::optional<InputMetadata>& thing) {
+  return thing.has_value();
+}
+
+template <>
+bool has_input_metadata<Edge>(const Edge& thing) {
+  return thing.is_valid();
+}
+
+// Given a vector<Edge> or vector<optional<InputMetadata>>, validate the
+// outputs. This involves using the InputMetadata to check the outputs and also
+// potentially calling .sum_to on the outputs.
+template <typename T>
+void validate_outputs_impl(
+    const std::vector<T>& input_metadata_container,
     variable_list& grads,
     const std::function<std::string(const std::string&)>& format_error) {
-  if (grads.size() != edges.size()) {
+  if (grads.size() != input_metadata_container.size()) {
     std::stringstream ss;
     ss << "invalid number of gradients - expected ";
-    ss << edges.size() << ", but got " << grads.size();
+    ss << input_metadata_container.size() << ", but got " << grads.size();
     TORCH_CHECK(false, format_error(ss.str()));
   }
   for (const auto i : c10::irange(grads.size())) {
-    const auto& edge = edges[i];
-    if (!edge.is_valid())
+    if (!has_input_metadata(input_metadata_container[i])) {
       continue;
-    const auto& metadata = edge.function->input_metadata(edge.input_nr);
+    }
+    const auto& metadata = get_input_metadata(input_metadata_container[i]);
     auto& grad = grads[i];
     if (!grad.defined()) {
       // FIXME: TestJit.test_ge_optimized fails this assertion.
@@ -938,6 +984,20 @@ void validate_outputs(
   }
 }
 
+void validate_outputs(
+    const edge_list& edges,
+    variable_list& grads,
+    const std::function<std::string(const std::string&)>& format_error) {
+  return validate_outputs_impl(edges, grads, format_error);
+}
+
+void validate_outputs(
+    const std::vector<c10::optional<InputMetadata>>& input_metadata,
+    variable_list& grads,
+    const std::function<std::string(const std::string&)>& format_error) {
+  return validate_outputs_impl(input_metadata, grads, format_error);
+}
+
 static variable_list call_function(
     std::shared_ptr<GraphTask>& graph_task,
     Node* func,

diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h

@@ -43,6 +43,10 @@ TORCH_API void validate_outputs(
     const edge_list& edges,
     variable_list& grads,
     const std::function<std::string(const std::string&)>& format_error);
+TORCH_API void validate_outputs(
+    const std::vector<c10::optional<InputMetadata>>& input_metadata,
+    variable_list& grads,
+    const std::function<std::string(const std::string&)>& format_error);
 
 struct NodeTask {
   std::weak_ptr<GraphTask> base_;