From 80f60f2cc098e0132ececb321a35a1d3132fe676 Mon Sep 17 00:00:00 2001
From: Raghavan Raman
Date: Fri, 18 Feb 2022 10:15:48 -0800
Subject: [PATCH] [Static Runtime] Handle fallback graphs that are generated as
 part of the TE Fuser (#72945)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72945

ghstack-source-id: 149429754

Test Plan:
```
buck run mode/opt //caffe2/benchmarks/static_runtime:static_runtime_cpptest — --gtest_filter=CpuFusion.FallbackGraph
```

Reviewed By: mikeiovine

Differential Revision: D34283840

fbshipit-source-id: 868bd340a50fe691797164524f2400d07998d304
---
 benchmarks/static_runtime/test_cpu_fusion.cc | 83 ++++++++++++++++++++
 torch/csrc/jit/runtime/static/fusion.cpp     | 13 +++
 2 files changed, 96 insertions(+)
 create mode 100644 benchmarks/static_runtime/test_cpu_fusion.cc

diff --git a/benchmarks/static_runtime/test_cpu_fusion.cc b/benchmarks/static_runtime/test_cpu_fusion.cc
new file mode 100644
index 00000000000..f482b87957c
--- /dev/null
+++ b/benchmarks/static_runtime/test_cpu_fusion.cc
@@ -0,0 +1,83 @@
+#include <gtest/gtest.h>
+#include <torch/csrc/jit/runtime/static/impl.h>
+#include <torch/torch.h>
+
+#include "test_utils.h"
+
+using namespace torch;
+using namespace torch::jit;
+using namespace torch::jit::test;
+
+TEST(CpuFusion, Simple) {
+  const auto simple_script = R"JIT(
+    def forward(self, a, b):
+        return (a + b).relu().tanh()
+  )JIT";
+
+  Module m("module");
+  m.define(simple_script);
+
+  StaticModuleOptions opts; // start with the defaults.
+  opts.enable_tensorexpr_fusion = true;
+
+  auto input1 = at::randn({2, 3});
+  auto input2 = at::ones({2, 3});
+
+  auto smodule = StaticModule(m, /* is_frozen */ false, opts, {input1, input2});
+  StaticRuntime runtime(smodule);
+
+  // Test with sample inputs
+  {
+    auto actual = runtime({input1, input2}, {});
+    auto expect = at::tanh(at::relu(input1 + input2));
+    EXPECT_TRUE(at::allclose(expect, actual.toTensor()));
+  }
+
+  // Test with different inputs
+  {
+    auto new_input1 = at::randn({5, 14});
+    auto new_input2 = at::randn({5, 14});
+    auto actual = runtime({new_input1, new_input2}, {});
+    auto expect = at::tanh(at::relu(new_input1 + new_input2));
+    EXPECT_TRUE(at::allclose(expect, actual.toTensor()));
+  }
+}
+
+TEST(CpuFusion, FallbackGraph) {
+  const auto simple_script = R"JIT(
+    def forward(self, a, b):
+        return (a + b).relu().tanh()
+  )JIT";
+
+  Module m("module");
+  m.define(simple_script);
+
+  StaticModuleOptions opts; // start with the defaults.
+  opts.enable_tensorexpr_fusion = true;
+
+  auto sample_input1 = at::randn({2, 3});
+  auto sample_input2 = at::ones({2, 3});
+  auto smodule = StaticModule(
+      m, /* is_frozen */ false, opts, {sample_input1, sample_input2});
+
+  StaticRuntime runtime(smodule);
+
+  // The sample inputs above were contiguous. Now, use a strided input
+  // to trigger running the fallback graph.
+  {
+    auto input1 = at::narrow(at::randn({2, 6}), 1, 0, 3);
+    auto input2 = at::ones({2, 3});
+    auto expect = at::tanh(at::relu(input1 + input2));
+    auto actual = runtime({input1, input2}, {});
+    EXPECT_TRUE(at::allclose(expect, actual.toTensor()));
+  }
+
+  // Test with strided inputs of different size.
+  {
+    auto input1 = at::narrow(at::randn({10, 30}), 1, 0, 25);
+    auto input2 = at::randn({10, 25});
+    auto expect = at::tanh(at::relu(input1 + input2));
+    auto actual = runtime({input1, input2}, {});
+    EXPECT_TRUE(at::allclose(expect, actual.toTensor()));
+  }
+}
diff --git a/torch/csrc/jit/runtime/static/fusion.cpp b/torch/csrc/jit/runtime/static/fusion.cpp
index 556d1bc0b91..038e03c6f2e 100644
--- a/torch/csrc/jit/runtime/static/fusion.cpp
+++ b/torch/csrc/jit/runtime/static/fusion.cpp
@@ -11,6 +11,7 @@
 #include
 #include
 #include
+#include <torch/csrc/jit/ir/graph_iterator.h>
 #include
 #include
 #include
@@ -322,6 +323,17 @@ void createFusionGroups(Block* block, AliasDb* aliasDb, size_t min_size) {
   inlineSmallFusionGroups(block, min_size);
 }
 
+void inlineFallbackGraphs(std::shared_ptr<Graph> graph) {
+  DepthFirstGraphNodeIterator it(graph);
+
+  Node* n = nullptr;
+  while ((n = it.next()) != nullptr) {
+    if (n->kind() == prim::FallbackGraph) {
+      SubgraphUtils::unmergeSubgraph(n);
+    }
+  }
+}
+
 void performTensorExprFusion(
     std::shared_ptr<Graph> graph,
     std::vector<IValue> sample_inputs) {
@@ -335,6 +347,7 @@ void performTensorExprFusion(
       /*min_group_size*/ 2,
       /*add_composed_op*/ false,
       /*fuse_to_dynamic_shapes*/ true);
+  inlineFallbackGraphs(traced_graph);
   graph->block()->clear();
   graph->block()->cloneFrom(traced_graph->block(), nullptr);
   GRAPH_DUMP("Graph after fusion: ", graph);