mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[functional autograd] Refactor validate_outputs into a functional variant (#141348)
Today, validate_outputs is stateful (it depends on the autograd graph). This PR refactors it into a stateless form that just depends on InputMetadata. Test Plan: - new unittest Pull Request resolved: https://github.com/pytorch/pytorch/pull/141348 Approved by: https://github.com/soulitzer ghstack dependencies: #141278
This commit is contained in:
parent
2b4f1f4990
commit
215f5d77b5
|
|
@ -5,6 +5,7 @@
|
|||
#include <torch/torch.h>
|
||||
|
||||
#include <torch/csrc/autograd/FunctionsManual.h>
|
||||
#include <torch/csrc/autograd/engine.h>
|
||||
#include <torch/csrc/autograd/functions/basic_ops.h>
|
||||
|
||||
#include <test/cpp/api/support.h>
|
||||
|
|
@ -1668,6 +1669,36 @@ TEST(TestAutogradNotImplementedFallback, TensorlistOp) {
|
|||
ASSERT_TRUE(at::allclose(op(a, vec), tensorlist_op(a, vec)));
|
||||
}
|
||||
|
||||
// Identity error formatter handed to validate_outputs in these tests:
// passes the message through unmodified.
static std::string test_format_error(const std::string& msg) {
  return msg;
}
|
||||
|
||||
// validate_outputs should sum-reduce a gradient that is broader than the
// input it corresponds to: a scalar input with a {2, 3} grad ends up with
// the grad reduced to grad.sum().
TEST(TestAutogradUtils, ValidateOutputsReduce) {
  const auto scalar_input = torch::ones({}, {torch::kFloat32});
  const auto wide_grad = torch::ones({2, 3}, {torch::kFloat32});

  std::vector<c10::optional<InputMetadata>> metadata;
  metadata.emplace_back(InputMetadata(scalar_input));

  std::vector<torch::Tensor> grads;
  grads.emplace_back(wide_grad);

  torch::autograd::validate_outputs(metadata, grads, test_format_error);
  ASSERT_TRUE(at::allclose(grads[0], wide_grad.sum()));
}
|
||||
|
||||
// validate_outputs should accept a gradient whose shape already matches the
// input's metadata and leave it untouched.
TEST(TestAutogradUtils, ValidateOutputsBasic) {
  const auto matching_input = torch::zeros({2, 3}, {torch::kFloat32});
  const auto matching_grad = torch::ones({2, 3}, {torch::kFloat32});

  std::vector<c10::optional<InputMetadata>> metadata;
  metadata.emplace_back(InputMetadata(matching_input));

  std::vector<torch::Tensor> grads;
  grads.emplace_back(matching_grad);

  torch::autograd::validate_outputs(metadata, grads, test_format_error);
  ASSERT_TRUE(at::allclose(matching_grad, torch::ones({2, 3})));
}
|
||||
|
||||
// TODO add these tests if needed
|
||||
// test_once_differentiable
|
||||
// test_sparse_backward
|
||||
|
|
|
|||
|
|
@ -855,22 +855,68 @@ void set_device(int device) {
|
|||
worker_device = device;
|
||||
}
|
||||
|
||||
void validate_outputs(
|
||||
const edge_list& edges,
|
||||
// validate_outputs has two overloads, one that accepts edge_list and one that
|
||||
// accepts vector<optional<InputMetadata>>. The former is stateful (it requires
|
||||
// the autograd graph to actually use) and the latter is for functional
|
||||
// autograd. (where we want to be able to take an autograd graph and then
|
||||
// construct a FX graph out of it without specializing on the properties of the
|
||||
// gradients).
|
||||
//
|
||||
// We do some templating to avoid dynamic allocations in the hot path (the eager
|
||||
// autograd case). Otherwise, the problem is that we are given a vector<Edge>
|
||||
// and would need to materialize a vector<optional<InputMetadata>> (or some
|
||||
// other vector) to pass to a common helper function. The alternative is to use
|
||||
// C++20's ranges which we don't have access to yet.
|
||||
|
||||
// Given an Edge or optional<InputMetadata>, return the InputMetadata
|
||||
template <typename T>
|
||||
const InputMetadata& get_input_metadata(const T& thing);
|
||||
|
||||
template <>
|
||||
const InputMetadata& get_input_metadata<c10::optional<InputMetadata>>(
|
||||
const c10::optional<InputMetadata>& thing) {
|
||||
return thing.value();
|
||||
}
|
||||
|
||||
template <>
|
||||
const InputMetadata& get_input_metadata<Edge>(const Edge& thing) {
|
||||
return thing.function->input_metadata(thing.input_nr);
|
||||
}
|
||||
|
||||
// Given an Edge or optional<InputMetadata>, return whether an InputMetadata is present.
|
||||
template <typename T>
|
||||
bool has_input_metadata(const T& thing);
|
||||
|
||||
template <>
|
||||
bool has_input_metadata<c10::optional<InputMetadata>>(
|
||||
const c10::optional<InputMetadata>& thing) {
|
||||
return thing.has_value();
|
||||
}
|
||||
|
||||
template <>
|
||||
bool has_input_metadata<Edge>(const Edge& thing) {
|
||||
return thing.is_valid();
|
||||
}
|
||||
|
||||
// Given a vector<Edge> or vector<optional<InputMetadata>>, validate the
|
||||
// outputs. This involves using the InputMetadata to check the outputs and also
|
||||
// potentially calling .sum_to on the outputs.
|
||||
template <typename T>
|
||||
void validate_outputs_impl(
|
||||
const std::vector<T>& input_metadata_container,
|
||||
variable_list& grads,
|
||||
const std::function<std::string(const std::string&)>& format_error) {
|
||||
if (grads.size() != edges.size()) {
|
||||
if (grads.size() != input_metadata_container.size()) {
|
||||
std::stringstream ss;
|
||||
ss << "invalid number of gradients - expected ";
|
||||
ss << edges.size() << ", but got " << grads.size();
|
||||
ss << input_metadata_container.size() << ", but got " << grads.size();
|
||||
TORCH_CHECK(false, format_error(ss.str()));
|
||||
}
|
||||
for (const auto i : c10::irange(grads.size())) {
|
||||
const auto& edge = edges[i];
|
||||
if (!edge.is_valid())
|
||||
if (!has_input_metadata(input_metadata_container[i])) {
|
||||
continue;
|
||||
|
||||
const auto& metadata = edge.function->input_metadata(edge.input_nr);
|
||||
}
|
||||
const auto& metadata = get_input_metadata(input_metadata_container[i]);
|
||||
auto& grad = grads[i];
|
||||
if (!grad.defined()) {
|
||||
// FIXME: TestJit.test_ge_optimized fails this assertion.
|
||||
|
|
@ -938,6 +984,20 @@ void validate_outputs(
|
|||
}
|
||||
}
|
||||
|
||||
void validate_outputs(
|
||||
const edge_list& edges,
|
||||
variable_list& grads,
|
||||
const std::function<std::string(const std::string&)>& format_error) {
|
||||
return validate_outputs_impl(edges, grads, format_error);
|
||||
}
|
||||
|
||||
void validate_outputs(
|
||||
const std::vector<c10::optional<InputMetadata>>& input_metadata,
|
||||
variable_list& grads,
|
||||
const std::function<std::string(const std::string&)>& format_error) {
|
||||
return validate_outputs_impl(input_metadata, grads, format_error);
|
||||
}
|
||||
|
||||
static variable_list call_function(
|
||||
std::shared_ptr<GraphTask>& graph_task,
|
||||
Node* func,
|
||||
|
|
|
|||
|
|
@ -43,6 +43,10 @@ TORCH_API void validate_outputs(
|
|||
const edge_list& edges,
|
||||
variable_list& grads,
|
||||
const std::function<std::string(const std::string&)>& format_error);
|
||||
TORCH_API void validate_outputs(
|
||||
const std::vector<c10::optional<InputMetadata>>& input_metadata,
|
||||
variable_list& grads,
|
||||
const std::function<std::string(const std::string&)>& format_error);
|
||||
|
||||
struct NodeTask {
|
||||
std::weak_ptr<GraphTask> base_;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user