mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Static Runtime] Handle fallback graphs that are generated as part of the TE Fuser (#72945)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72945
ghstack-source-id: 149429754
Test Plan:
```
buck run mode/opt //caffe2/benchmarks/static_runtime:static_runtime_cpptest — --gtest_filter=CpuFusion.FallbackGraph
```
Reviewed By: mikeiovine
Differential Revision: D34283840
fbshipit-source-id: 868bd340a50fe691797164524f2400d07998d304
(cherry picked from commit 80f60f2cc0)
This commit is contained in:
parent
87f882b056
commit
02afdd54b9
83
benchmarks/static_runtime/test_cpu_fusion.cc
Normal file
83
benchmarks/static_runtime/test_cpu_fusion.cc
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
#include <gtest/gtest.h>
|
||||
#include <torch/csrc/jit/runtime/static/impl.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include "test_utils.h"
|
||||
|
||||
using namespace torch;
|
||||
using namespace torch::jit;
|
||||
using namespace torch::jit::test;
|
||||
|
||||
TEST(CpuFusion, Simple) {
  // Exercise the happy path: build a StaticModule with TE fusion enabled,
  // then check that the fused runtime matches eager-mode results both for
  // the sample inputs used at fusion time and for inputs of a new shape.
  const auto simple_script = R"JIT(
    def forward(self, a, b):
        return (a + b).relu().tanh()
  )JIT";

  Module m("module");
  m.define(simple_script);

  // Start from the default options and opt in to TE fusion only.
  StaticModuleOptions opts;
  opts.enable_tensorexpr_fusion = true;

  auto input1 = at::randn({2, 3});
  auto input2 = at::ones({2, 3});

  auto smodule = StaticModule(m, /* is_frozen */ false, opts, {input1, input2});
  StaticRuntime runtime(smodule);

  // Run with the same inputs that were used to specialize the fused graph.
  {
    auto expect = at::tanh(at::relu(input1 + input2));
    auto actual = runtime({input1, input2}, {});
    EXPECT_TRUE(at::allclose(expect, actual.toTensor()));
  }

  // Run again with inputs of a different shape.
  {
    auto other1 = at::randn({5, 14});
    auto other2 = at::randn({5, 14});
    auto expect = at::tanh(at::relu(other1 + other2));
    auto actual = runtime({other1, other2}, {});
    EXPECT_TRUE(at::allclose(expect, actual.toTensor()));
  }
}
|
||||
|
||||
TEST(CpuFusion, FallbackGraph) {
  // The fused kernel is specialized on the sample inputs (which are
  // contiguous). Feeding strided tensors afterwards should route execution
  // through the fallback graph, which must still produce correct results.
  const auto simple_script = R"JIT(
    def forward(self, a, b):
        return (a + b).relu().tanh()
  )JIT";

  Module m("module");
  m.define(simple_script);

  // Start from the default options and opt in to TE fusion only.
  StaticModuleOptions opts;
  opts.enable_tensorexpr_fusion = true;

  auto sample_input1 = at::randn({2, 3});
  auto sample_input2 = at::ones({2, 3});
  auto smodule = StaticModule(
      m, /* is_frozen */ false, opts, {sample_input1, sample_input2});

  StaticRuntime runtime(smodule);

  // A narrowed view is strided (non-contiguous), so it does not match the
  // contiguous sample inputs and should trigger the fallback path.
  {
    auto strided1 = at::narrow(at::randn({2, 6}), 1, 0, 3);
    auto dense2 = at::ones({2, 3});
    auto expect = at::tanh(at::relu(strided1 + dense2));
    auto actual = runtime({strided1, dense2}, {});
    EXPECT_TRUE(at::allclose(expect, actual.toTensor()));
  }

  // Same fallback path, but with a different input size as well.
  {
    auto strided1 = at::narrow(at::randn({10, 30}), 1, 0, 25);
    auto dense2 = at::randn({10, 25});
    auto expect = at::tanh(at::relu(strided1 + dense2));
    auto actual = runtime({strided1, dense2}, {});
    EXPECT_TRUE(at::allclose(expect, actual.toTensor()));
  }
}
|
||||
|
|
@ -11,6 +11,7 @@
|
|||
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
|
||||
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
|
||||
#include <torch/csrc/jit/runtime/custom_operator.h>
|
||||
#include <torch/csrc/jit/runtime/graph_iterator.h>
|
||||
#include <torch/csrc/jit/runtime/jit_trace.h>
|
||||
#include <torch/csrc/jit/runtime/static/impl.h>
|
||||
#include <torch/csrc/jit/runtime/static/ops.h>
|
||||
|
|
@ -322,6 +323,17 @@ void createFusionGroups(Block* block, AliasDb* aliasDb, size_t min_size) {
|
|||
inlineSmallFusionGroups(block, min_size);
|
||||
}
|
||||
|
||||
// Splice the body of every prim::FallbackGraph node back into its parent
// graph, removing the subgraph indirection that the TE fuser introduced.
// Uses a depth-first traversal, so nodes inside nested blocks are visited
// as well.
void inlineFallbackGraphs(std::shared_ptr<Graph> graph) {
  DepthFirstGraphNodeIterator it(graph);
  for (Node* node = it.next(); node != nullptr; node = it.next()) {
    if (node->kind() == prim::FallbackGraph) {
      SubgraphUtils::unmergeSubgraph(node);
    }
  }
}
|
||||
|
||||
void performTensorExprFusion(
|
||||
std::shared_ptr<Graph> graph,
|
||||
std::vector<IValue> sample_inputs) {
|
||||
|
|
@ -335,6 +347,7 @@ void performTensorExprFusion(
|
|||
/*min_group_size*/ 2,
|
||||
/*add_composed_op*/ false,
|
||||
/*fuse_to_dynamic_shapes*/ true);
|
||||
inlineFallbackGraphs(traced_graph);
|
||||
graph->block()->clear();
|
||||
graph->block()->cloneFrom(traced_graph->block(), nullptr);
|
||||
GRAPH_DUMP("Graph after fusion: ", graph);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user