Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69835
`StaticRuntimeBlockRunner` moves its outputs to the return value at the end of `run_impl`. However, there's a corner case where this can cause problems. If we return a constant, then the only reference in the `constants_` array can be destroyed by this move. We could add special logic to handle this in `run_impl`. But since this is a relatively rare corner case, it's simpler to just add an op that does nothing but create an owned reference to its input. This owned reference can be safely moved out of `StaticRuntimeBlockRunner`.
Note that this also applies to values returned from sub-blocks when those values come from outer scopes.
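For illustration, a minimal IR graph that hits this corner case (a sketch for exposition, not necessarily the literal test graph):

    graph():
      %c : int = prim::Constant[value=42]()
      return (%c)

Here the only IValue for %c lives in the `constants_` array, so moving the output out of `StaticRuntimeBlockRunner` would hollow out that entry and corrupt subsequent runs. Routing the return through the new owned-reference op means the moved value is a fresh owned reference rather than the `constants_` slot itself.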
ghstack-source-id: 148186452
Test Plan:
`buck test caffe2/benchmarks/static_runtime/...`
Added a new unit test with a graph that simply returns a constant.
Tests covering sub-blocks are at the top of this diff stack.
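For reference, a sketch of what such a test can look like with the `testStaticRuntime` helper from test_utils (the test name, constant value, and flag values are illustrative, not the literal test in this diff):

    TEST(StaticRuntime, ReturnConstantSketch) {
      const auto src = R"IR(
        graph():
          %c : int = prim::Constant[value=1]()
          return (%c)
      )IR";
      // The graph takes no inputs, so both argument lists are empty; the
      // helper runs static runtime several times and compares each result
      // against the JIT reference.
      testStaticRuntime(
          src,
          /*args=*/{},
          /*args2=*/{},
          /*use_allclose=*/false,
          /*use_equalnan=*/false,
          /*check_resize=*/false);
    }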
Reviewed By: d1jang
Differential Revision: D33047519
fbshipit-source-id: 22b6058f0d1da8a6d1d61a6f2866bc518bff482b
(cherry picked from commit a8f89a12ee)
406 lines · 13 KiB · C++
// (c) Facebook, Inc. and its affiliates. Confidential and proprietary.

#include "test_utils.h"

#include <ATen/core/ivalue.h>
#include <gtest/gtest.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/memory_planner.h>
#include <torch/csrc/jit/runtime/static/passes.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/allclose.h>
#endif

#include <memory>
#include <unordered_map>

using namespace torch::jit;
using namespace torch;
using c10::IValue;

namespace torch {
namespace jit {
namespace test {

namespace {

// Thin wrapper that runs a graph through the JIT's GraphExecutor and
// returns a single IValue (or a tuple if the graph has multiple outputs).
class GraphExecutorWrapper {
 public:
  GraphExecutorWrapper() = default;

  explicit GraphExecutorWrapper(const std::shared_ptr<Graph>& graph)
      : graph_exec_(graph, "") {}

  c10::IValue operator()(const std::vector<c10::IValue>& args) {
    Stack stack(args);
    graph_exec_.run(stack);

    if (stack.size() == 1) {
      return stack[0];
    }
    return c10::ivalue::Tuple::create(stack);
  }

 private:
  GraphExecutor graph_exec_;
};

// Test scripts passed to testStaticRuntime can either be IR or JIT.
// The logic for running the script and producing a corresponding StaticModule
// is a bit different for each case. This logic is encapsulated within concrete
// implementations of this class, and testStaticRuntime is only aware of this
// interface.
class StaticRuntimeTestContext {
 public:
  virtual ~StaticRuntimeTestContext() = default;

  virtual IValue getExpected(const std::vector<IValue>& args) = 0;
  virtual StaticModule makeStaticModule(
      const StaticModuleOptions& opt) const = 0;
};

class ModuleStaticRuntimeTestContext : public StaticRuntimeTestContext {
 public:
  explicit ModuleStaticRuntimeTestContext(const std::string& source_jit)
      : module_("module") {
    module_.define(source_jit);
  }

  IValue getExpected(const std::vector<IValue>& args) override {
    return module_.forward(args);
  }

  StaticModule makeStaticModule(const StaticModuleOptions& opt) const override {
    return torch::jit::StaticModule(
        module_, /* is_frozen */ false, opt, /* sample_inputs */ {});
  }

 private:
  Module module_;
};

class GraphStaticRuntimeContext : public StaticRuntimeTestContext {
 public:
  explicit GraphStaticRuntimeContext(const std::string& source_ir) {
    graph_ = std::make_shared<Graph>();
    std::unordered_map<std::string, Value*> vmap;
    parseIR(source_ir, graph_.get(), vmap);

    graph_exec_ = GraphExecutorWrapper(graph_);
  }

  IValue getExpected(const std::vector<IValue>& args) override {
    return graph_exec_(args);
  }

  StaticModule makeStaticModule(const StaticModuleOptions& opt) const override {
    return StaticModule(graph_, opt, /* sample_inputs */ {});
  }

 private:
  std::shared_ptr<Graph> graph_;
  GraphExecutorWrapper graph_exec_;
};

std::unique_ptr<StaticRuntimeTestContext> makeTestContext(
    const std::string& source) {
  try {
    return std::make_unique<ModuleStaticRuntimeTestContext>(source);
  } catch (const std::runtime_error&) {
    // Could not parse as TorchScript; assume the source is IR.
    return std::make_unique<GraphStaticRuntimeContext>(source);
  }
}

void compareTensorLists(
    const std::vector<IValue>& l, /* expected */
    const std::vector<IValue>& r, /* actual */
    const bool use_allclose,
    const bool use_equalnan) {
  EXPECT_TRUE(l.size() == r.size());
  // size_t avoids a signed/unsigned comparison with l.size().
  for (size_t i = 0; i < l.size(); ++i) {
    ASSERT_TRUE(l[i].isTensor());
    ASSERT_TRUE(r[i].isTensor());
    VLOG(2) << "expect " << i << ": \n" << l[i] << std::endl;
    VLOG(2) << "output " << i << ": \n" << r[i] << std::endl;
    if (!l[i].toTensor().defined()) {
      EXPECT_TRUE(!r[i].toTensor().defined());
    } else {
      if (use_allclose) {
        EXPECT_TRUE(at::allclose(
            l[i].toTensor(),
            r[i].toTensor(),
            /*rtol*/ 1e-05,
            /*atol*/ 1e-08,
            use_equalnan));
      } else {
        EXPECT_TRUE(l[i].toTensor().equal(r[i].toTensor()));
      }
    }
  }
}

void compareResults(
    const IValue& expect,
    const IValue& actual,
    const bool use_allclose = false,
    const bool use_equalnan = false) {
  if (expect.isTensor()) {
    VLOG(2) << "expect " << expect.toTensor() << std::endl;
    VLOG(2) << "output " << actual.toTensor() << std::endl;
    EXPECT_TRUE(actual.isTensor());
    if (use_allclose) {
      EXPECT_TRUE(at::allclose(
          expect.toTensor(),
          actual.toTensor(),
          /*rtol*/ 1e-05,
          /*atol*/ 1e-08,
          use_equalnan));
    } else {
      EXPECT_TRUE(expect.toTensor().equal(actual.toTensor()));
    }
    return;
  } else if (expect.isTuple()) {
    EXPECT_TRUE(actual.isTuple());
    auto lhs = expect.toTupleRef().elements();
    auto rhs = actual.toTupleRef().elements();
    EXPECT_TRUE(lhs.size() == rhs.size());
    for (size_t i = 0; i < lhs.size(); i++) {
      compareResults(lhs[i], rhs[i]);
    }
  } else if (expect.isList()) {
    EXPECT_TRUE(actual.isList());
    auto lhs = expect.toList();
    auto rhs = actual.toList();
    EXPECT_TRUE(lhs.size() == rhs.size());
    for (size_t i = 0; i < lhs.size(); i++) {
      compareResults(lhs[i], rhs[i]);
    }
  } else if (expect.isGenericDict()) {
    EXPECT_TRUE(actual.isGenericDict());
    auto lhs = expect.toGenericDict();
    auto rhs = actual.toGenericDict();
    EXPECT_TRUE(lhs.size() == rhs.size());
    for (auto& lh : lhs) {
      auto f = rhs.find(lh.key());
      EXPECT_FALSE(f == rhs.end());
      compareResults(lh.value(), f->value());
    }
  } else {
    // fall back to the default comparison impl in IValue
    EXPECT_TRUE(expect == actual);
  }
}

} // namespace

at::Tensor getTensor(const at::IValue& ival) {
  if (ival.isTensor()) {
    return ival.toTensor();
  } else if (ival.isTensorList()) {
    auto tensor_vec = ival.toTensorVector();
    TORCH_CHECK(tensor_vec.size() == 1);
    return tensor_vec[0];
  } else if (ival.isTuple()) {
    auto tuple = ival.toTuple();
    auto ivalue_vec = tuple->elements();
    TORCH_CHECK(ivalue_vec.size() == 1);
    return ivalue_vec[0].toTensor();
  } else {
    CAFFE_THROW("Unknown input IValue");
  }
}

Node* getNodeWithKind(const StaticModule& smodule, const std::string& kind) {
  return smodule.findNodeWithKindForTesting(kind);
}

Node* getNodeWithKind(std::shared_ptr<Graph>& graph, const std::string& kind) {
  const auto symbol = c10::Symbol::fromQualString(kind);
  DepthFirstGraphNodeIterator it(graph);
  for (auto* node = it.next(); node != nullptr; node = it.next()) {
    if (node->kind() == symbol) {
      return node;
    }
  }
  return nullptr;
}

bool hasNodeWithKind(const StaticModule& smodule, const std::string& kind) {
  return getNodeWithKind(smodule, kind) != nullptr;
}

bool hasNodeWithKind(std::shared_ptr<Graph>& graph, const std::string& kind) {
  return getNodeWithKind(graph, kind) != nullptr;
}

std::shared_ptr<Graph> getGraphFromScript(const std::string& jit_script) {
  script::Module module("module");
  module.define(jit_script);

  return module.get_method("forward").graph();
}

std::shared_ptr<Graph> getGraphFromIR(const std::string& ir) {
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);
  return graph;
}

void compareResultsWithJIT(
    StaticRuntime& runtime,
    const std::shared_ptr<Graph>& graph,
    const std::vector<c10::IValue>& args,
    const bool use_allclose,
    const bool use_equalnan) {
  GraphExecutorWrapper graph_exec(graph);
  auto expected = graph_exec(args);
  auto actual = runtime(args, {});
  runtime.check_for_memory_leak();
  compareResults(expected, actual, use_allclose, use_equalnan);
}

void testStaticRuntime(
    const std::string& source,
    const std::vector<IValue>& args,
    const std::vector<IValue>& args2,
    const bool use_allclose,
    const bool use_equalnan,
    const bool check_resize) {
  auto test_context = makeTestContext(source);

  // Keep cloned copies of the input tensors so we can verify at the end
  // that the runs did not modify the inputs in place.
  std::vector<IValue> args_tensors, args_copy;
  for (const auto& ival : args) {
    if (ival.isTensor()) {
      args_tensors.emplace_back(ival);
      const at::Tensor& t = ival.toTensor();
      args_copy.emplace_back(t.clone());
    }
  }

  auto expect = test_context->getExpected(args);

  for (bool enable_out_variant : {true, false}) {
    for (bool manage_output_tensors : {true, false}) {
      if (!enable_out_variant && manage_output_tensors) {
        continue;
      }
      // Run static runtime three times:
      // 1st run: collect allocation profiles (args)
      // 2nd run: exercise memory planner and resizing with args2
      // 3rd run: run with args again
      StaticModuleOptions opts{
          .enable_out_variant = enable_out_variant,
          .optimize_memory = enable_out_variant,
          .manage_output_tensors = manage_output_tensors};
      auto smodule = test_context->makeStaticModule(opts);
      StaticRuntime runtime(smodule);
      auto actual = runtime(args, {});
      if (actual.isTensor()) {
        EXPECT_GE(smodule.num_nodes(), 2)
            << "If we only have one node, the output of the op we are testing is "
            << "not being managed by the memory planner! A failure here "
            << "can typically be fixed by clone()ing the output of the test script.";
      }
      runtime.check_for_memory_leak();
      // first run
      VLOG(2) << "enable_out_variant: " << enable_out_variant;
      VLOG(2) << "manage_output_tensors: " << manage_output_tensors;
      VLOG(2) << "args: " << args;
      VLOG(2) << "args2: " << args2;
      VLOG(2) << "expect: " << expect;
      VLOG(2) << "actual: " << actual;
      compareResults(expect, actual, use_allclose, use_equalnan);
      VLOG(2) << "first run comparison done";
      if (manage_output_tensors) {
        actual = IValue();
        runtime.deallocateOutputTensors();
        runtime.checkOutputTensorMemoryLeaks();
      }

      if (!args2.empty()) {
        auto* memory_planner = runtime.get_memory_planner();
        size_t managed_bytes =
            memory_planner ? memory_planner->total_managed() : 0;

        // Run static runtime again with inputs of a different shape.
        expect = test_context->getExpected(args2);
        actual = runtime(args2, {});
        runtime.check_for_memory_leak();
        VLOG(2) << "comparing with args2";
        compareResults(expect, actual, use_allclose, use_equalnan);
        VLOG(2) << "second run comparison done";
        if (manage_output_tensors) {
          actual = IValue();
          runtime.deallocateOutputTensors();
          runtime.checkOutputTensorMemoryLeaks();
        }

        size_t new_managed_bytes =
            memory_planner ? memory_planner->total_managed() : 0;
        if (check_resize && new_managed_bytes > 0) {
          EXPECT_GT(new_managed_bytes, managed_bytes);
        }

        // Run static runtime again with an input of the shape observed during
        // the profile run.
        expect = test_context->getExpected(args);
        actual = runtime(args, {});
        runtime.check_for_memory_leak();
        // third run
        VLOG(2) << "comparing third run";
        compareResults(expect, actual, use_allclose, use_equalnan);
        VLOG(2) << "third run comparison done";
        if (manage_output_tensors) {
          actual = IValue();
          runtime.deallocateOutputTensors();
          runtime.checkOutputTensorMemoryLeaks();
        }
      } else {
        // run static runtime again to exercise the memory planner
        // and allocate managed tensors.
        actual = runtime(args, {});
        runtime.check_for_memory_leak();
        VLOG(2) << "comparing second run with same args";
        compareResults(expect, actual, use_allclose, use_equalnan);
        VLOG(2) << "second run comparison done";
        if (manage_output_tensors) {
          actual = IValue();
          runtime.deallocateOutputTensors();
          runtime.checkOutputTensorMemoryLeaks();
        }
        // third run to use the allocated managed tensors.
        actual = runtime(args, {});
        runtime.check_for_memory_leak();
        if (manage_output_tensors) {
          actual = IValue();
          runtime.deallocateOutputTensors();
          runtime.checkOutputTensorMemoryLeaks();
        }
      }
    }
  }

  // make sure inputs were not modified
  VLOG(2) << "Printing out input tensors";
  compareTensorLists(args_tensors, args_copy, use_allclose, use_equalnan);
}

bool hasProcessedNodeWithName(
    torch::jit::StaticModule& smodule,
    const char* name) {
  return smodule.findNodeWithKindForTesting(name) != nullptr;
}

} // namespace test
} // namespace jit
} // namespace torch