pytorch/test/cpp/jit/test_inliner.cpp
Mikhail Zolotukhin 2cf1183ec1 Use optimized graph in Inline (essentially, making Inline recursive now). (#26489)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26489

This fixes Inline(recurse=true) and makes it the default. One
reservation against running inlining recursively in the original
implementation was that we might hit quadratic behavior. This
implementation avoids that: it inlines only graphs that have already
been inlined, and as it recursively descends the call tree it caches
the graphs it has already optimized.

Test Plan: Imported from OSS

Differential Revision: D17485744

Pulled By: ZolotukhinM

fbshipit-source-id: 2ed7bdc69863b90a8c10a385d63f8e7c9e7b05f5
2019-09-24 00:22:29 -07:00

55 lines
1.1 KiB
C++

#include <test/cpp/jit/test_base.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/script/compilation_unit.h>
#include <torch/csrc/jit/script/module.h>
#include <torch/csrc/jit/testing/file_check.h>
// TorchScript program used by testInliner. It forms a three-deep call chain,
// foo3 -> foo2 -> foo1, where each function emits exactly one print. Fully
// inlining foo3 therefore has to pull in all three prints.
//
// TorchScript uses Python's indentation-sensitive syntax, so the function
// bodies below MUST stay indented; without that the source does not parse.
const auto testSource = R"JIT(
def foo1(x):
    print("one")
    return x

def foo2(x):
    print("two")
    return foo1(x)

def foo3(x):
    print("three")
    return foo2(x)
)JIT";
namespace torch {
namespace jit {

using namespace script;
using namespace testing;

namespace {

// Scope guard for the mode returned by getInlineEverythingMode().
// The constructor saves the current value and overwrites it with
// `shouldInline`. The destructor writes the saved value back, so the change
// is undone however the enclosing scope is left.
struct InlinerGuard {
  explicit InlinerGuard(bool shouldInline)
      : oldState_(getInlineEverythingMode()) {
    getInlineEverythingMode() = shouldInline;
  }

  ~InlinerGuard() {
    getInlineEverythingMode() = oldState_;
  }

  // Each guard must restore the saved mode exactly once. A copy (or a
  // moved-from object) would write the saved value back a second time when
  // it is destroyed, clobbering whatever mode is current at that point.
  InlinerGuard(const InlinerGuard&) = delete;
  InlinerGuard& operator=(const InlinerGuard&) = delete;
  InlinerGuard(InlinerGuard&&) = delete;
  InlinerGuard& operator=(InlinerGuard&&) = delete;

  // Value of getInlineEverythingMode() captured at construction.
  bool oldState_;
};

} // namespace

// Compiles testSource with automatic inlining turned off, then manually runs
// Inline on foo3's graph. foo3 calls foo2, which calls foo1, and each prints
// once. Finding three prim::Print nodes in foo3's graph afterwards shows that
// Inline followed the whole call chain rather than stopping at one level.
void testInliner() {
  {
    // disable automatic inlining so we can test it manually
    InlinerGuard guard(/*shouldInline=*/false);
    CompilationUnit cu(testSource);
    auto& fn = cu.get_function("foo3");
    auto g = fn.graph();
    Inline(*g);
    FileCheck().check_count("prim::Print", 3)->run(*g);
  }
}

} // namespace jit
} // namespace torch