Add Initial NNC Dynamic Shapes Flow (#66136)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66136

FOR REVIEWERS: this is ready to review; the test failures come from somewhere else in the stack.

Takes in a TensorExprGraph of static shapes and generalizes the input shapes to symbolic dimensions. Dimensions of value 1 will be preserved; otherwise, dimensions with the same value will be bucketed to the same symbolic shape. E.g. `Tensor(5, 3), Tensor(3, 1) -> Tensor(SS(-1), SS(-2)), Tensor(SS(-2), 1)`

From there, runs symbolic shape inference on the graph and creates a versioning if node in the graph with prim::TensorExprDynamicGuard checking whether the inputs at runtime match the generalized symbolic shapes that are inputs to the TE kernel. The computation to calculate all symbolic dimensions is inlined into the if block with the TE kernel. All Sym Dim Value* are appended to the end of the TE kernel Graph/Node inputs, and the Node is augmented with an integer list attr `symbolic_shape_inputs` that gives the mapping from Value* -> symbolic shape int64_t value. For more lengthy IR examples and a walkthrough, look at ShapeAnalysisTest.DynamicShapesFusion in `test_shape_analysis`.

Returns True on success, False on failure; it can fail if shape propagation fails to propagate the # of dims or if complete shapes are not set on the inputs.

Example transformation:
```
graph(%x_inp : Float(10, 5, strides=[5, 1], requires_grad=0, device=cpu),
      %y_inp : Float(4, 5, strides=[5, 1], requires_grad=0, device=cpu),
      %z_inp : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)):
  %3 : Tensor = prim::TensorExprGroup_0(%x_inp, %y_inp, %z_inp)
  return ()
with prim::TensorExprGroup_0 = graph(%x.1 : Float(10, 5, strides=[5, 1], requires_grad=0, device=cpu),
      %y.1 : Float(4, 5, strides=[5, 1], requires_grad=0, device=cpu),
      %z : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)):
  %3 : int = prim::Constant[value=0]()
  %4 : Tensor = aten::tanh(%x.1)
  %5 : Tensor = aten::erf(%4)
  %6 : Tensor = aten::relu(%y.1)
  %7 : Tensor[] = prim::ListConstruct(%5, %6)
  %8 : Tensor = aten::cat(%7, %3)
  %9 : Tensor = aten::hardswish(%8)
  %10 : Tensor = aten::mul(%9, %z)
  return (%9)
```
->
```
graph(%x_inp : Float(10, 5, strides=[5, 1], requires_grad=0, device=cpu),
      %y_inp : Float(4, 5, strides=[5, 1], requires_grad=0, device=cpu),
      %z_inp : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)):
  %4 : bool = prim::TensorExprDynamicGuard[types=[Float(SS(-2), SS(-3), strides=[5, 1], requires_grad=0, device=cpu), Float(SS(-4), SS(-3), strides=[5, 1], requires_grad=0, device=cpu), Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)]](%x_inp, %y_inp, %z_inp)
  %5 : Tensor = prim::If(%4)
    block0():
      %15 : int[] = aten::size(%x_inp)
      %16 : int[] = aten::size(%y_inp)
      %17 : int = prim::Constant[value=1]()
      %18 : int = prim::Constant[value=0]()
      %elem.3 : int = aten::__getitem__(%15, %18) # <string>:40:10
      %elem.5 : int = aten::__getitem__(%15, %17) # <string>:40:10
      %elem.11 : int = aten::__getitem__(%16, %18) # <string>:40:10
      %cat_dim_size.48 : int = aten::add(%elem.3, %elem.11) # <string>:321:29
      %3 : Tensor = prim::TensorExprGroup_0[symbolic_shape_inputs=[-5, -4, -3, -2]](%x_inp, %y_inp, %z_inp, %cat_dim_size.48, %elem.11, %elem.5, %elem.3)
      -> (%3)
    block1():
      %14 : Tensor = prim::FallbackGraph_1(%x_inp, %y_inp, %z_inp)
      -> (%14)
  return ()
with prim::TensorExprGroup_0 = graph(%x.1 : Float(SS(-2), SS(-3), strides=[5, 1], requires_grad=0, device=cpu),
      %y.1 : Float(SS(-4), SS(-3), strides=[5, 1], requires_grad=0, device=cpu),
      %z : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu),
      %SS_5 : int, %SS_4 : int, %SS_3 : int, %SS_2 : int):
  %3 : int = prim::Constant[value=0]()
  %4 : Tensor(SS(-2), SS(-3)) = aten::tanh(%x.1)
  %5 : Tensor(SS(-2), SS(-3)) = aten::erf(%4)
  %6 : Tensor(SS(-4), SS(-3)) = aten::relu(%y.1)
  %7 : Tensor[] = prim::ListConstruct(%5, %6)
  %8 : Tensor(SS(-5), SS(-3)) = aten::cat(%7, %3)
  %9 : Tensor(SS(-5), SS(-3)) = aten::hardswish(%8)
  %10 : Tensor(SS(-5), SS(-3)) = aten::mul(%9, %z)
  return (%9)
```

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31732414

Pulled By: eellison

fbshipit-source-id: 290a94a667c20467717202a43c60e4f9ca4c00e2
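The bucketing rule above (dimensions of value 1 are kept, equal dimensions share a fresh symbolic id) can be illustrated with a minimal standalone sketch. This is not part of the patch: the helper name `bucketToSymbolicDims` and the plain-negative-integer ids are illustrative stand-ins; the real logic lives in `TryGeneralizeInputDimensionsToSymbolicShapes` in the diff below.

```
// Minimal standalone sketch of the shape-generalization rule described above.
// Illustrative only; not the pass's real implementation.
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

// Map each concrete dimension to a symbolic id: keep 1s as-is, and give
// every other distinct value its own fresh negative id, reused on repeats.
std::vector<std::vector<int64_t>> bucketToSymbolicDims(
    const std::vector<std::vector<int64_t>>& shapes) {
  std::map<int64_t, int64_t> value_to_sym;
  int64_t next_sym = -1; // fresh symbolic ids are negative, like SS(-1), SS(-2), ...
  std::vector<std::vector<int64_t>> out;
  for (const auto& shape : shapes) {
    std::vector<int64_t> sym_shape;
    for (int64_t dim : shape) {
      if (dim == 1) {
        sym_shape.push_back(1); // dimensions of value 1 are preserved
      } else if (value_to_sym.count(dim)) {
        sym_shape.push_back(value_to_sym[dim]); // same value -> same symbol
      } else {
        value_to_sym[dim] = next_sym--;
        sym_shape.push_back(value_to_sym[dim]);
      }
    }
    out.push_back(sym_shape);
  }
  return out;
}

int main() {
  // Tensor(5, 3), Tensor(3, 1) -> Tensor(SS(-1), SS(-2)), Tensor(SS(-2), 1)
  auto result = bucketToSymbolicDims({{5, 3}, {3, 1}});
  for (const auto& shape : result) {
    for (int64_t d : shape) {
      std::cout << d << " ";
    }
    std::cout << "\n"; // prints: "-1 -2" then "-2 1"
  }
}
```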
This commit is contained in: parent b4db5174fe, commit de4fe7a38c
```
@@ -463,6 +463,7 @@ namespace c10 {
   _(attr, df_output_vjps) \
   _(attr, axes) \
   _(attr, axis) \
+  _(attr, symbolic_shape_inputs) \
   _(attr, broadcast) \
   _(attr, direction) \
   _(attr, ends) \
```
```
@@ -75,6 +75,7 @@ set(JIT_TEST_SRCS
   ${JIT_TEST_ROOT}/test_union.cpp
   ${JIT_TEST_ROOT}/test_utils.cpp
   ${JIT_TEST_ROOT}/test_script_profile.cpp
+  ${JIT_TEST_ROOT}/test_shape_analysis.cpp
   ${JIT_TEST_ROOT}/test_jit_logging_levels.cpp
 )
```
test/cpp/jit/test_shape_analysis.cpp (new file, 250 lines)
```
#include <gtest/gtest.h>

#include <ATen/core/interned_strings.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/cuda.h>
#include <unordered_map>
#include <torch/csrc/jit/testing/file_check.h>

namespace torch {
namespace jit {

namespace {

Node* findNode(std::shared_ptr<Graph>& g, Symbol k) {
  DepthFirstGraphNodeIterator graph_it(g);
  for (auto node = graph_it.next(); node != nullptr; node = graph_it.next()) {
    if (node->kind() == k) {
      return node;
    }
  }
  TORCH_INTERNAL_ASSERT(false, "Couldn't find node");
}

} // namespace

TEST(ShapeAnalysisTest, DynamicShapesFusion) {
  // Test generalizing shapes to symbolic dimensions, guarding those symbolic
  // dimensions, and passing in runtime-computed symbolic dimensions via
  // inlined shape functions
  std::shared_ptr<Graph> subgraph = std::make_shared<Graph>();
  const auto graph_string = R"IR(
      graph(%x.1 : Tensor, %y.1 : Tensor, %z: Tensor):
        %11 : int = prim::Constant[value=0]()
        %3 : Tensor = aten::tanh(%x.1)
        %out1.1 : Tensor = aten::erf(%3)
        %out2.1 : Tensor = aten::relu(%y.1)
        %10 : Tensor[] = prim::ListConstruct(%out1.1, %out2.1)
        %25 : Tensor = aten::cat(%10, %11)
        %28 : Tensor = aten::hardswish(%25)
        %29 : Tensor = aten::mul(%28, %z)
        return (%28))IR";
  torch::jit::parseIR(graph_string, subgraph.get());

  /*
    Set up fused TensorExprGroup
  */

  std::shared_ptr<Graph> g = std::make_shared<Graph>();
  auto x_inp = g->addInput("x_inp");
  auto y_inp = g->addInput("y_inp");
  auto z_inp = g->addInput("z_inp");
  auto x_type = TensorType::create(at::rand({10, 5}));
  auto y_type = TensorType::create(at::rand({4, 5}));
  auto z_type = TensorType::create(at::rand({1, 1}));
  x_inp->setType(x_type);
  y_inp->setType(y_type);
  z_inp->setType(z_type);
  subgraph->inputs().at(0)->setType(x_type);
  subgraph->inputs().at(1)->setType(y_type);
  subgraph->inputs().at(2)->setType(z_type);
  auto output = g->insertNode(g->create(prim::TensorExprGroup))->output();
  output->node()->addInput(x_inp);
  output->node()->addInput(y_inp);
  output->node()->addInput(z_inp);
  output->node()->g_(attr::Subgraph, subgraph);

  auto success = GenerateGuard(output->node());
  TORCH_INTERNAL_ASSERT(success);
  testing::FileCheck()
      .check("TensorExprDynamicGuard")
      ->check_next("prim::If")
      ->check("aten::add")
      ->check("TensorExprGroup")
      ->check_same("symbolic_shape_inputs")
      ->check("block1")
      ->check("FallbackGraph")
      ->run(*g);

  // clang-format off
  /* Graph should look something like: (note: strides not yet handled)
  graph(%x_inp : Float(10, 5, strides=[5, 1], requires_grad=0, device=cpu),
        %y_inp : Float(4, 5, strides=[5, 1], requires_grad=0, device=cpu),
        %z_inp : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)):
    %4 : bool = prim::TensorExprDynamicGuard[types=[Float(SS(-2), SS(-3), strides=[5, 1], requires_grad=0, device=cpu), Float(SS(-4), SS(-3), strides=[5, 1], requires_grad=0, device=cpu), Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu)]](%x_inp, %y_inp, %z_inp)
    %5 : Tensor = prim::If(%4)
      block0():
        %15 : int[] = aten::size(%x_inp)
        %16 : int[] = aten::size(%y_inp)
        %17 : int = prim::Constant[value=1]()
        %18 : int = prim::Constant[value=0]()
        %elem.3 : int = aten::__getitem__(%15, %18) # <string>:40:10
        %elem.5 : int = aten::__getitem__(%15, %17) # <string>:40:10
        %elem.11 : int = aten::__getitem__(%16, %18) # <string>:40:10
        %cat_dim_size.48 : int = aten::add(%elem.3, %elem.11) # <string>:321:29
        %3 : Tensor = prim::TensorExprGroup_0[symbolic_shape_inputs=[-5, -4, -3, -2]](%x_inp, %y_inp, %z_inp, %cat_dim_size.48, %elem.11, %elem.5, %elem.3)
        -> (%3)
      block1():
        %14 : Tensor = prim::FallbackGraph_1(%x_inp, %y_inp, %z_inp)
        -> (%14)
    return ()
  with prim::TensorExprGroup_0 = graph(%x.1 : Float(SS(-2), SS(-3), strides=[5, 1], requires_grad=0, device=cpu),
        %y.1 : Float(SS(-4), SS(-3), strides=[5, 1], requires_grad=0, device=cpu),
        %z : Float(1, 1, strides=[1, 1], requires_grad=0, device=cpu),
        %SS_5 : int,
        %SS_4 : int,
        %SS_3 : int,
        %SS_2 : int):
    %3 : int = prim::Constant[value=0]()
    %4 : Tensor(SS(-2), SS(-3)) = aten::tanh(%x.1)
    %5 : Tensor(SS(-2), SS(-3)) = aten::erf(%4)
    %6 : Tensor(SS(-4), SS(-3)) = aten::relu(%y.1)
    %7 : Tensor[] = prim::ListConstruct(%5, %6)
    %8 : Tensor(SS(-5), SS(-3)) = aten::cat(%7, %3)
    %9 : Tensor(SS(-5), SS(-3)) = aten::hardswish(%8)
    %10 : Tensor(SS(-5), SS(-3)) = aten::mul(%9, %z)
    return (%9)
  */
  // clang-format on

  DepthFirstGraphNodeIterator graph_it(g);
  Node* te_group = findNode(g, prim::TensorExprGroup);

  /*
    Test that the inputs to the kernel - (10, 5), (4, 5), (1, 1) - are
    correctly generalized to sym dimensions, and that the output - (10 + 4, 5)
    - correctly preserves the non-catted dim as a sym shape and the catted dim
    as a new sym shape
  */

  auto tensorexpr_graph = te_group->g(attr::Subgraph);
  auto inp1 = tensorexpr_graph->inputs().at(0)->type()->expect<TensorType>();
  auto inp2 = tensorexpr_graph->inputs().at(1)->type()->expect<TensorType>();
  auto inp3 = tensorexpr_graph->inputs().at(2)->type()->expect<TensorType>();
  auto out = tensorexpr_graph->outputs().at(0)->type()->expect<TensorType>();

  // 1 dims are preserved
  auto inp3_sizes = inp3->sizes().concrete_sizes();
  TORCH_INTERNAL_ASSERT(inp3_sizes);
  TORCH_INTERNAL_ASSERT(
      inp3_sizes->size() == 2 && inp3_sizes->at(0) == 1 &&
      inp3_sizes->at(1) == 1);

  // 5 made into sym shape
  ASSERT_EQ(
      inp1->symbolic_sizes()[1].value(), inp2->symbolic_sizes()[1].value());
  ASSERT_EQ(
      out->symbolic_sizes()[1].value(), inp2->symbolic_sizes()[1].value());

  // 4, 10, 14 are different sym shapes
  ASSERT_NE(
      inp1->symbolic_sizes()[0].value(), inp2->symbolic_sizes()[0].value());
  ASSERT_NE(
      out->symbolic_sizes()[0].value(), inp1->symbolic_sizes()[0].value());
  ASSERT_NE(
      out->symbolic_sizes()[0].value(), inp2->symbolic_sizes()[0].value());

  /*
    Test that the guard behaves correctly at runtime and that symbolic shapes
    are computed correctly. As we don't have TE Kernel support for dynamic
    shapes, we're going to return all of the computed runtime symbolic
    dimensions as outputs of the graph on guard success, and return None on
    guard failure
  */

  // Setting up guard to return sym shapes on guard success and None on failure
  Node* if_node = findNode(g, prim::If);
  IfView if_v(if_node);
  if_node->eraseOutput(0);
  if_v.thenBlock()->eraseOutput(0);
  if_v.elseBlock()->eraseOutput(0);
  WithInsertPoint guard(if_node);
  auto none_val = g->insertConstant(IValue());

  auto sym_shapes = te_group->is(Symbol::attr("symbolic_shape_inputs"));
  auto offset = te_group->inputs().size() - sym_shapes.size();
  for (size_t i = 0; i < sym_shapes.size(); ++i) {
    if_v.thenBlock()->insertOutput(i, te_group->inputs().at(offset + i));
    if_v.elseBlock()->insertOutput(i, none_val);
    if_node->insertOutput(i)->setType(OptionalType::create(IntType::get()));
  }

  auto new_outputs = g->createTuple(if_node->outputs())->insertAfter(if_node);

  g->registerOutput(new_outputs->output());
  te_group->destroy();
  findNode(g, prim::FallbackGraph)->destroy();

  // Testing bad inputs

  auto first_inp = at::rand({2, 5});
  std::vector<std::vector<at::Tensor>> second_inps = {
      {at::rand({3, 4}), at::rand({1, 1})}, // sym shape mismatch
      {at::rand({5, 2}).transpose(0, 1), at::rand({1, 1})}, // discontiguous
      {at::zeros({2, 5}).to(at::ScalarType::Int),
       at::rand({1, 1})}, // wrong dtype
      {at::rand({2, 5, 1}), at::rand({1, 1})}, // wrong # dims
      {at::rand({2, 5}).requires_grad_(true),
       at::rand({1, 1})}, // requires grad
      {at::rand({2, 5}), at::rand({1, 12})}, // concrete dim mismatch (1)
  };
  if (torch::cuda::is_available()) {
    second_inps.push_back({at::rand({2, 5}).cuda(), at::rand({1, 1})});
  }
  for (const auto& last_inps : second_inps) {
    // todo - reusing interpreter across iters gave error
    Code code(g, "");
    InterpreterState interp(code);
    auto stack = createStack({at::rand({2, 5}), last_inps[0], last_inps[1]});
    interp.run(stack);
    TORCH_INTERNAL_ASSERT(pop(stack).toTuple()->elements().at(0).isNone());
  }

  // Test good inputs
  Code code(g, "");
  InterpreterState interp(code);
  std::vector<at::Tensor> inps = {
      at::rand({2, 5}), at::rand({4, 5}), at::rand({1, 1})};
  Stack stack(inps.begin(), inps.end());
  interp.run(stack);
  auto tuple = pop(stack).toTuple();
  TORCH_INTERNAL_ASSERT(tuple->elements().at(0).isInt());

  // Testing that the sym shape calculation was correct
  for (size_t i = 0; i < sym_shapes.size(); ++i) {
    auto sym_shape = sym_shapes[i];
    auto computed_value = tuple->elements().at(i).toInt();
    if (sym_shape == inp1->symbolic_sizes().at(0).value()) {
      ASSERT_EQ(computed_value, 2);
    } else if (sym_shape == inp1->symbolic_sizes().at(1).value()) {
      ASSERT_EQ(computed_value, 5);
    } else if (sym_shape == inp2->symbolic_sizes().at(0).value()) {
      ASSERT_EQ(computed_value, 4);
    } else if (sym_shape == out->symbolic_sizes().at(0).value()) {
      ASSERT_EQ(computed_value, 6);
    } else {
      TORCH_INTERNAL_ASSERT(false);
    }
  }
}

} // namespace jit
} // namespace torch
```
```
@@ -254,6 +254,7 @@ core_sources_full_mobile_no_backend_interface = [
     "torch/csrc/jit/passes/shape_analysis.cpp",
     "torch/csrc/jit/passes/integer_value_refinement.cpp",
     "torch/csrc/jit/passes/symbolic_shape_analysis.cpp",
+    "torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp",
     "torch/csrc/jit/passes/specialize_autogradzero.cpp",
     "torch/csrc/jit/passes/update_differentiable_graph_requires_grad.cpp",
     "torch/csrc/jit/passes/variadic_ops.cpp",
```
```
@@ -826,7 +826,8 @@ struct SymbolicShapeGraphAnalyzer {
           new_list_output;

       TORCH_INTERNAL_ASSERT(
-          new_list_output->node()->kind() == prim::ListConstruct);
+          new_list_output->node()->kind() == prim::ListConstruct ||
+          new_list_output->node()->kind() == prim::Constant);
       TORCH_INTERNAL_ASSERT(!new_list_output->node()->hasUses());

       auto symbolic_sizes =
```
torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp (new file, 360 lines)
```
#include <ATen/core/functional.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>
#include <sstream>

namespace torch {
namespace jit {

// Inserts the compute for each symbolic shape in the TensorExpr graph
// and returns a map from symbolic shape value to its runtime Value *
std::map<int64_t, Value*> InsertSymbolicShapesCompute(
    const ShapeComputeGraphMapping& shape_mapping,
    Node* tensorexpr_graph) {
  WithInsertPoint guard(tensorexpr_graph);
  auto enclosing_graph = tensorexpr_graph->owningGraph();

  std::map<Value*, Value*> shape_graph_input_to_enclosing_graph_value;
  for (const auto& pair :
       shape_mapping.enclosing_graph_value_to_shape_graph_input_) {
    shape_graph_input_to_enclosing_graph_value[pair.second] = pair.first;
  }
  std::vector<Value*> shape_compute_graph_inputs;
  for (Value* shape_graph_input :
       shape_mapping.partial_eval_shape_graph->inputs()) {
    auto enclosing_graph_input =
        shape_graph_input_to_enclosing_graph_value.find(shape_graph_input);
    TORCH_INTERNAL_ASSERT(
        enclosing_graph_input !=
        shape_graph_input_to_enclosing_graph_value.end());
    if (*enclosing_graph_input->second->type() == *shape_graph_input->type()) {
      shape_compute_graph_inputs.push_back(tensorexpr_graph->inputs().at(
          enclosing_graph_input->second->offset()));
    } else {
      TORCH_INTERNAL_ASSERT(
          enclosing_graph_input->second->type()->cast<TensorType>() &&
          shape_graph_input->type()->isSubtypeOf(ListType::ofInts()));
      shape_compute_graph_inputs.push_back(enclosing_graph->insert(
          aten::size,
          {tensorexpr_graph->inputs().at(
              enclosing_graph_input->second->offset())}));
    }
  }
  auto sym_shape_values = insertGraph(
      *enclosing_graph,
      *shape_mapping.partial_eval_shape_graph,
      shape_compute_graph_inputs);
  std::map<int64_t, Value*> sym_shape_to_enclosing_graph_value;
  for (size_t i = 0;
       i < shape_mapping.partial_eval_shape_graph->outputs().size();
       ++i) {
    Value* output = shape_mapping.partial_eval_shape_graph->outputs().at(i);
    auto sym_shape =
        shape_mapping.graph_output_to_symbolic_shape_dim_.find(output);
    TORCH_INTERNAL_ASSERT(
        sym_shape != shape_mapping.graph_output_to_symbolic_shape_dim_.end());
    sym_shape_to_enclosing_graph_value[sym_shape->second] = sym_shape_values[i];
  }
  return sym_shape_to_enclosing_graph_value;
}

void insertDynamicShapesGuard(
    const ShapeComputeGraphMapping& shape_mapping,
    Node* guarded_node);

// Generalize complete shape inputs to symbolic shapes.
// Dimensions of value 1 will be preserved, otherwise
// dimensions with the same value will be bucketed to the same
// symbolic shape.
// E.g. Tensor(5, 3), Tensor(3, 1) -> Tensor(SS(-1), SS(-2)), Tensor(SS(-2), 1)
bool TryGeneralizeInputDimensionsToSymbolicShapes(
    std::shared_ptr<Graph> tensorexpr_graph) {
  std::map<size_t, int64_t> shape_to_sym_shape;
  for (Value* v : tensorexpr_graph->inputs()) {
    if (!v->type()->cast<TensorType>()) {
      continue;
    }
    if (!v->type()->expect<TensorType>()->sizes().isComplete()) {
      return false;
    }
    auto tt = v->type()->expect<TensorType>();
    std::vector<at::ShapeSymbol> shape_vec = *tt->symbolic_sizes().sizes();
    auto new_sizes = c10::fmap(shape_vec, [&](const at::ShapeSymbol& shape) {
      auto value = shape.value();
      TORCH_INTERNAL_ASSERT(value >= 0, "Expected complete tensor");
      if (value == 1) {
        return value;
      } else if (shape_to_sym_shape.count(static_cast<size_t>(value))) {
        return shape_to_sym_shape[value];
      } else {
        auto new_shape_symbol = at::ShapeSymbol::newSymbol().value();
        shape_to_sym_shape[static_cast<size_t>(value)] = new_shape_symbol;
        return new_shape_symbol;
      }
    });
    v->setType(tt->withSymbolicShapes(c10::SymbolicShape(new_sizes)));
  }
  return true;
}

bool GenerateGuard(Node* tensorexpr_graph_node) {
  auto tensorexpr_graph = SubgraphUtils::getSubgraph(tensorexpr_graph_node);

  // Generalize inputs
  if (!TryGeneralizeInputDimensionsToSymbolicShapes(tensorexpr_graph)) {
    return false;
  }

  // Try to propagate shapes
  auto maybe_shape_compute_mapping =
      PropagateShapesAndBuildLargeShapeComputeGraph(
          tensorexpr_graph,
          *tensorexpr_graph->nodes().begin(),
          *tensorexpr_graph->nodes().end());
  if (!maybe_shape_compute_mapping) {
    return false;
  }

  // Insert guard
  insertDynamicShapesGuard(*maybe_shape_compute_mapping, tensorexpr_graph_node);
  return true;
}

// TODO: share more logic with tensorexpr_fuser ?
void insertDynamicShapesGuard(
    const ShapeComputeGraphMapping& shape_mapping,
    Node* guarded_node) {
  GRAPH_DEBUG(
      "Inserting a prim::TensorExprDynamicGuard guard for a node",
      *guarded_node);
  auto subgraph = SubgraphUtils::getSubgraph(guarded_node);

  // Fixup types of the subgraph inputs
  std::vector<Value*> inputs_to_check;
  std::vector<TypePtr> guard_types;
  for (const auto i : c10::irange(guarded_node->inputs().size())) {
    Value* node_input = guarded_node->inputs().at(i);
    // We only check inputs of the guarded nodes
    if (!node_input->type()->cast<TensorType>()) {
      continue;
    }
    inputs_to_check.push_back(node_input);
    guard_types.push_back(
        subgraph->inputs().at(i)->type()->expect<TensorType>());
  }
  TORCH_INTERNAL_ASSERT(inputs_to_check.size());

  // prim::TensorExprDynamicGuard nodes look like the following:
  //   %types_match : bool = prim::TypeCheck[attr:types](%inp1 : Tensor, %inp2 : Tensor)
  // The input tensors are checked against the expected types on attr::types.
  // Omitting refining the input Tensors for now because they are not actually
  // used within tensorexpr/kernel.cpp (only the inputs to the Graph are, not
  // the inputs to the node) and we would have to redo the mapping to compute
  // symbolic shapes

  Node* typecheck_node =
      guarded_node->owningGraph()
          ->create(Symbol::prim("TensorExprDynamicGuard"), inputs_to_check, 1)
          ->insertBefore(guarded_node);

  typecheck_node->tys_(attr::types, guard_types);
  Value* typecheck_result = typecheck_node->output()->setType(BoolType::get());

  // Insert if
  auto versioning_if =
      guarded_node->owningGraph()
          ->create(prim::If, {typecheck_result}, guarded_node->outputs().size())
          ->insertAfter(typecheck_node);

  for (size_t idx = 0; idx < guarded_node->outputs().size(); ++idx) {
    versioning_if->output(idx)->setType(guarded_node->output(idx)->type());
    guarded_node->output(idx)->replaceAllUsesWith(versioning_if->output(idx));
  }
  auto true_block = versioning_if->addBlock();
  auto false_block = versioning_if->addBlock();

  // Fill in the false block. It should contain the unoptimized
  // copy of the fused subgraph.
  WithInsertPoint guard(false_block->return_node());
  const auto subgraph_outputs = insertGraph(
      *guarded_node->owningGraph(), *subgraph, guarded_node->inputs());
  for (Value* output : subgraph_outputs) {
    false_block->registerOutput(output);
  }

  // types get copied to the fallback graph, so remove specializations before
  // replacing
  removeTensorTypeSpecializations(false_block);
  replaceBlockWithFallbackGraph(false_block, guarded_node->inputs());

  // Fill in the true block. It has all inputs type-checked and its
  // body should be the fusion group node.
  guarded_node->moveBefore(true_block->return_node());

  for (Value* output : guarded_node->outputs()) {
    true_block->registerOutput(output);
  }

  // Insert symbolic shapes compute and add as inputs to the TE node/graph;
  // symbolic_shape_inputs will be a list of each symbolic shape,
  // and the last N inputs to the TE graph/node will be the N
  // symbolic shape values
  auto map = InsertSymbolicShapesCompute(shape_mapping, guarded_node);
  std::vector<int64_t> symbolic_shape_inputs;
  for (const auto& pair : map) {
    symbolic_shape_inputs.push_back(pair.first);
    guarded_node->addInput(pair.second);
    std::stringstream ss;
    ss << "SS_" << -pair.first;
    subgraph->addInput(ss.str())->setType(IntType::get());
  }
  guarded_node->is_(attr::symbolic_shape_inputs, symbolic_shape_inputs);
}

// On each invocation of this guard, we need to check all of the static
// information (dtype/device/requires grad/contiguity/static dims),
// and also that the symbolic shape dimensions are observed.
// For any symbolic dimension we need to set its value on its first
// use and for all subsequent uses check that the values are equal
RegisterOperators reg_guard({
    Operator(
        "prim::TensorExprDynamicGuard(...) -> bool",
        [](const Node* node) -> Operation {
          const auto& types = node->tys(attr::types);

          // Each input's expected # of dims
          std::vector<size_t> expected_dims;

          // A flattened vector of all the expected values for all
          // tensor dims. A positive value corresponds to a static
          // shape to check and a negative value corresponds to a symbolic
          // dimension index to check
          std::vector<int64_t> flattened_input_dims;

          // Each input's expected scalar type
          std::vector<c10::ScalarType> expected_scalar_types;

          // Map from symbolic dimension value to its set's index
          std::map<int64_t, size_t> sym_dim_flat_index;
          TORCH_INTERNAL_ASSERT(types.size() >= 1);

          // we should just be fusing fusion groups with a single device
          // and with tensors not requiring grad
          auto maybe_device = types[0]->expect<TensorType>()->device();
          TORCH_INTERNAL_ASSERT(maybe_device);
          auto device = *maybe_device;

          for (auto type : types) {
            auto tt = type->expect<TensorType>();
            auto ss = tt->symbolic_sizes();
            TORCH_INTERNAL_ASSERT(ss.rank());
            expected_dims.push_back(*ss.rank());
            TORCH_INTERNAL_ASSERT(tt->scalarType());
            expected_scalar_types.push_back(*tt->scalarType());
            TORCH_INTERNAL_ASSERT(tt->device() && *tt->device() == device);
            for (size_t i = 0; i < *ss.rank(); ++i) {
              auto sym_dim = ss[i];
              auto value = sym_dim.value();
              if (value >= 0) {
                flattened_input_dims.push_back(value);
              } else {
                // use index for set if it exists, otherwise extend the vector
                // of sym shapes by 1
                int64_t sym_dim_index;
                if (sym_dim_flat_index.count(value)) {
                  sym_dim_index = sym_dim_flat_index[value];
                } else {
                  sym_dim_flat_index[value] = (-1) - sym_dim_flat_index.size();
                  sym_dim_index = sym_dim_flat_index[value];
                }
                // TODO: potential optimization - if there is a Symbolic
                // Sym with only one use we dont need to test anything
                flattened_input_dims.push_back(sym_dim_index);
              }
            }
          }
          const auto num_inputs = types.size();
          const auto num_symbolic_dims = sym_dim_flat_index.size();
          return [num_inputs,
                  expected_dims,
                  device,
                  expected_scalar_types,
                  flattened_input_dims,
                  num_symbolic_dims](Stack& stack) {
            at::ArrayRef<IValue> inputs = last(stack, num_inputs);
            drop(stack, num_inputs);
            // on each invocation we need to reset what the value of each
            // symbolic symbol is.
            // TODO: could this be a reference and not allocated on
            // each invocation or would that mess up with multithreaded
            // inference since we are writing to it?
            // TODO - smallvector here ?
            std::vector<int64_t> flattened_symbolic_dims(num_symbolic_dims, -1);
            size_t flattened_dim_offset = 0;
            for (const auto i : c10::irange(num_inputs)) {
              at::Tensor tensor = inputs[i].toTensor();
              if (C10_UNLIKELY(
                      tensor.device() != device ||
                      tensor.dtype() != expected_scalar_types[i]) ||
                  tensor.requires_grad()) {
                push(stack, false);
                return;
              }
              // TODO: striding
              if (C10_UNLIKELY(
                      !tensor.is_contiguous(at::MemoryFormat::Contiguous))) {
                push(stack, false);
                return;
              }
              const auto& sizes = tensor.sizes();
              const auto num_dims = sizes.size();
              if (C10_UNLIKELY(num_dims != expected_dims[i])) {
                push(stack, false);
                return;
              }
              for (const auto dim_index : c10::irange(num_dims)) {
                const int64_t dim_value =
                    flattened_input_dims[dim_index + flattened_dim_offset];
                const int64_t tensor_dim = sizes[dim_index];
                if (dim_value >= 0) {
                  if (C10_UNLIKELY(dim_value != tensor_dim)) {
                    push(stack, false);
                    return;
                  }
                } else {
                  // flattened sym indices start at -1,
                  // so -1 -> index 0, -2 -> index 1
                  const auto flattened_sym_index = (-dim_value) - 1;
                  const auto flattened_sym_value =
                      flattened_symbolic_dims[flattened_sym_index];
                  // sym symbol already seen, check value
                  if (flattened_symbolic_dims[flattened_sym_index] >= 0) {
                    if (C10_UNLIKELY(flattened_sym_value != tensor_dim)) {
                      push(stack, false);
                      return;
                    }
                  } else {
                    // not seen, write value
                    flattened_symbolic_dims[flattened_sym_index] = tensor_dim;
                  }
                }
              }
              flattened_dim_offset += num_dims;
            }

            push(stack, true);
            return;
          };
        },
        aliasAnalysisFromSchema()),
});

} // namespace jit
} // namespace torch
```
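For illustration only (not part of the patch): for the guarded types in the test above - Float(SS(-2), SS(-3)), Float(SS(-4), SS(-3)), Float(1, 1) - the flattened encoding built by the guard would, in effect, be [-1, -2, -3, -2, 1, 1], where a positive entry is a static dim to match exactly and a negative entry -k refers to symbolic slot k-1, bound on first use and compared on later uses. Below is a standalone sketch of that check with a made-up helper name, `checkFlattenedDims`; it mirrors the idea, not the operator's exact code.

```
// Standalone sketch of the guard's flattened dimension check (illustrative
// only; the real logic is inside the prim::TensorExprDynamicGuard operator
// above). Positive entries must match exactly; a negative entry -k binds or
// checks symbolic slot k-1.
#include <cassert>
#include <cstdint>
#include <vector>

bool checkFlattenedDims(
    const std::vector<int64_t>& flattened_input_dims,
    const std::vector<std::vector<int64_t>>& runtime_shapes,
    size_t num_symbolic_dims) {
  std::vector<int64_t> symbolic_dims(num_symbolic_dims, -1); // -1 = unset
  size_t offset = 0;
  for (const auto& sizes : runtime_shapes) {
    for (size_t d = 0; d < sizes.size(); ++d) {
      int64_t expected = flattened_input_dims[offset + d];
      int64_t actual = sizes[d];
      if (expected >= 0) {
        if (expected != actual) {
          return false; // static dim mismatch
        }
      } else {
        size_t slot = (-expected) - 1; // -1 -> slot 0, -2 -> slot 1, ...
        if (symbolic_dims[slot] >= 0 && symbolic_dims[slot] != actual) {
          return false; // previously-bound symbolic dim disagrees
        }
        symbolic_dims[slot] = actual; // bind on first use
      }
    }
    offset += sizes.size();
  }
  return true;
}

int main() {
  // Encoding for Float(SS(-2), SS(-3)), Float(SS(-4), SS(-3)), Float(1, 1).
  std::vector<int64_t> flattened = {-1, -2, -3, -2, 1, 1};
  // The "good inputs" from the test, (2, 5), (4, 5), (1, 1), pass the guard...
  assert(checkFlattenedDims(flattened, {{2, 5}, {4, 5}, {1, 1}}, 3));
  // ...while a symbolic-shape mismatch such as (2, 5), (3, 4), (1, 1) fails.
  assert(!checkFlattenedDims(flattened, {{2, 5}, {3, 4}, {1, 1}}, 3));
}
```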
torch/csrc/jit/passes/symbolic_shape_runtime_fusion.h (new file, 31 lines)
```
#pragma once

#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/symbolic_shape_analysis.h>
#include <unordered_map>

namespace torch {
namespace jit {

// Takes in a TensorExprGraph of static shapes and generalizes the input shapes
// to symbolic dimensions. Dimensions of value 1 will be preserved, otherwise
// dimensions with the same value will be bucketed to the same symbolic shape.
// E.g. Tensor(5, 3), Tensor(3, 1) -> Tensor(SS(-1), SS(-2)), Tensor(SS(-2), 1)
// From there, runs symbolic shape inference on the graph, and creates a
// versioning if in the graph with prim::TensorExprDynamicGuard checking if
// the inputs at runtime match the Generalized Symbolic Shapes that are inputs
// to the TE Kernel. The computation to calculate all symbolic dimensions is
// inlined in to the if block with the TE Kernel. All Sym Dim Value* are
// appended to the end of the TE Kernel Graph/Node inputs, and the Node is
// augmented with an integer list attr `symbolic_shape_inputs` that gives the
// mapping from Value * -> Symbolic Shape int64_t value. For more lengthy IR
// examples and a walkthrough look at ShapeAnalysisTest.DynamicShapesFusion in
// `test_shape_analysis`. Returns True on Success, False on Failure; can fail
// if shape propagation fails to propagate # of dims or if complete shapes on
// inputs are not set

TORCH_API bool GenerateGuard(Node* tensorexpr_graph_node);

} // namespace jit
} // namespace torch
```