mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26489. This fixes Inline(recurse=true) and makes it the default. One reservation against running inlining recursively in the original implementation was that it might exhibit quadratic behavior, but that is not an issue in this implementation: we only inline graphs that have already been inlined, and as we recursively descend the call tree we cache the graphs we have already optimized. Test Plan: Imported from OSS. Differential Revision: D17485744. Pulled By: ZolotukhinM. fbshipit-source-id: 2ed7bdc69863b90a8c10a385d63f8e7c9e7b05f5
55 lines
1.1 KiB
C++
55 lines
1.1 KiB
C++
#include <test/cpp/jit/test_base.h>
|
|
|
|
#include <torch/csrc/jit/passes/inliner.h>
|
|
#include <torch/csrc/jit/script/compilation_unit.h>
|
|
#include <torch/csrc/jit/script/module.h>
|
|
#include <torch/csrc/jit/testing/file_check.h>
|
|
|
|
// TorchScript source compiled by the inliner test below.
//
// It defines a three-deep call chain, foo3 -> foo2 -> foo1, in which every
// function executes exactly one print statement before delegating (or, for
// foo1, returning its argument). The function bodies must be indented:
// the script uses Python syntax, where indentation delimits each `def` body.
const auto testSource = R"JIT(
def foo1(x):
    print("one")
    return x

def foo2(x):
    print("two")
    return foo1(x)

def foo3(x):
    print("three")
    return foo2(x)
)JIT";
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
using namespace script;
|
|
using namespace testing;
|
|
|
|
struct InlinerGuard {
|
|
explicit InlinerGuard(bool shouldInline)
|
|
: oldState_(getInlineEverythingMode()) {
|
|
getInlineEverythingMode() = shouldInline;
|
|
}
|
|
|
|
~InlinerGuard() {
|
|
getInlineEverythingMode() = oldState_;
|
|
}
|
|
|
|
bool oldState_;
|
|
};
|
|
|
|
void testInliner() {
|
|
{
|
|
// disable automatic inlining so we can test it manually
|
|
InlinerGuard guard(/*shouldInline=*/false);
|
|
|
|
CompilationUnit cu(testSource);
|
|
auto& fn = cu.get_function("foo3");
|
|
|
|
auto g = fn.graph();
|
|
Inline(*g);
|
|
FileCheck().check_count("prim::Print", 3)->run(*g);
|
|
}
|
|
}
|
|
} // namespace jit
|
|
} // namespace torch
|