Add support for cat in output stitching (#66098)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66098

`cat` is somewhat special-cased for now: it is currently the only op we handle with a list-of-Tensor input, and the only lists we encounter are ones constructed directly in the JIT IR graph. While that assumption generally holds for fusion (it is why we have ConstantChunk), it may not hold for shape analysis in general, so I'm waiting for more use cases before generalizing.
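
As a rough illustration, here is a minimal sketch of the case being handled, mirroring the new `test_stitching_concat` test below: the Tensor list consumed by `aten::cat` is produced by a `prim::ListConstruct` node in the same graph, so the stitching pass can use the sizes of each list element (plus `dim`) as inputs to the stitched shape-compute graph.

```python
import torch

# Minimal sketch: the list passed to torch.cat becomes a prim::ListConstruct
# node in the scripted graph, feeding aten::cat.
@torch.jit.script
def foo(x, y):
    return torch.cat([x, y])

# Erase concrete sizes so shape analysis assigns symbolic dimensions.
for inp in foo.graph.inputs():
    inp.setType(inp.type().with_sizes([None, None]))

# Propagate shapes and stitch the per-node shape functions into one graph.
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(
    foo.graph)
print(shape_compute_graph.partial_eval_shape_graph())
```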

Test Plan: Imported from OSS

Reviewed By: navahgar, anjali411

Differential Revision: D31797467

Pulled By: eellison

fbshipit-source-id: ca761e214dfd7f3bba8d189f3b3f42ffec064f63
Author: Elias Ellison (committed by Facebook GitHub Bot)
Date: 2021-10-20 16:09:33 -07:00
Commit: 17889ad26e (parent: 2dd23ebfdb)
2 changed files with 89 additions and 32 deletions


@@ -302,20 +302,9 @@ class TestSymbolicShapeAnalysis(JitTestCase):
for o, oe in zip(output, output_eager[0:1] + output_eager[2:]):
self.assertEqual(o, oe)
def test_partial_eval_stitching(self):
conv1 = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
conv2 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
mod = torch.jit.freeze(torch.jit.script(nn.Sequential(conv1, max_pool, conv2).eval()))
conv1_output = conv1(torch.rand(1, 3, 224, 224))
max_pool_ouput = max_pool(conv1_output)
conv2_output = conv2(max_pool_ouput)
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph)
def checkSymShapeCompute(self, shape_compute_graph, nodes, node_output_sizes, shape_inputs):
g = shape_compute_graph.partial_eval_shape_graph()
self.assertTrue(len(list(g.inputs())) == 1)
self.assertTrue(len(list(g.inputs())) == len(shape_inputs))
output_sym_map = shape_compute_graph.graph_output_to_symbolic_shape_dim()
# map from sym shape -> index
sym_shape_to_index = {}
@@ -324,18 +313,32 @@ class TestSymbolicShapeAnalysis(JitTestCase):
g.makeMultiOutputIntoTuple()
func = torch._C._create_function_from_graph("partial_eval_graph", g)
sym_outputs = func([1, 3, 224, 224])
nodes = [mod.graph.findNode("aten::max_pool2d")] + list(mod.graph.findAllNodes("aten::conv2d"))
output_shapes = [max_pool_ouput, conv1_output, conv2_output]
sym_outputs = func(*shape_inputs)
for node, output_shape in zip(nodes, output_shapes):
for node, output_shape in zip(nodes, node_output_sizes):
output_type_sizes = node.output().type().symbolic_sizes()
for i, sym_shape in enumerate(output_type_sizes):
if sym_shape >= 0:
self.assertEqual(sym_shape, output_shape.size(i))
self.assertEqual(sym_shape, output_shape[i])
else:
sym_shape_index = sym_shape_to_index[sym_shape]
self.assertEqual(sym_outputs[sym_shape_index], output_shape.size(i))
self.assertEqual(sym_outputs[sym_shape_index], output_shape[i])
def test_partial_eval_stitching(self):
conv1 = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
conv2 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
mod = torch.jit.freeze(torch.jit.script(nn.Sequential(conv1, max_pool, conv2).eval()))
conv1_output = conv1(torch.rand(1, 3, 224, 224))
max_pool_output = max_pool(conv1_output)
conv2_output = conv2(max_pool_output)
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph)
nodes = [mod.graph.findNode("aten::max_pool2d")] + list(mod.graph.findAllNodes("aten::conv2d"))
output_shapes = [max_pool_output.size(), conv1_output.size(), conv2_output.size()]
self.checkSymShapeCompute(shape_compute_graph, nodes, output_shapes, ([1, 3, 224, 224],))
def test_refinement_through_graph_stitching(self):
class TwoConvs(torch.nn.Module):
@@ -386,3 +389,20 @@ class TestSymbolicShapeAnalysis(JitTestCase):
func = torch._C._create_function_from_graph("partial_eval_graph", g)
output_shape = func(tensor.size())
self.assertEqual(list(output_shape), list(mod(tensor)[0].size()))
def test_stitching_concat(self):
@torch.jit.script
def foo(a, b, x, y):
return (a / b) + torch.cat([x, y])
g = foo.graph
for inp in foo.graph.inputs():
inp.setType(inp.type().with_sizes([None, None]))
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(foo.graph)
nodes = [g.findNode("aten::div")] + [g.findNode("aten::add")] + [g.findNode("aten::cat")]
inps = [1, 10], [20, 10], [15, 1], [5, 1]
output_shapes = [[20, 10], [20, 10], [20, 1]]
self.checkSymShapeCompute(shape_compute_graph, nodes, output_shapes, inps)


@@ -22,6 +22,7 @@
#include <torch/csrc/jit/runtime/exception_message.h>
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
#include <torch/csrc/utils/memory.h>
#include <algorithm>
#include <memory>
#include <numeric>
#include <unordered_map>
@@ -151,6 +152,11 @@ bool isListOfInts(const TypePtr& type) {
type->cast<ListType>()->getElementType()->cast<IntType>();
}
bool isListOfTensors(const TypePtr& type) {
return type->cast<ListType>() &&
type->cast<ListType>()->getElementType()->cast<TensorType>();
}
c10::optional<size_t> normIndex(int64_t index, size_t len) {
if (index < 0) {
index = index + len;
@@ -237,9 +243,10 @@ struct SymbolicShapeNodeAnalyzer {
if (auto tt = type->castRaw<TensorType>()) {
addTensorInputMetaData(node_->input(node_index), graph_index);
} else if (
type->cast<ListType>() &&
type->cast<ListType>()->getElementType()->cast<TensorType>()) {
} else if (isListOfTensors(type)) {
// waiting for more use cases to decide on best generalization
TORCH_INTERNAL_ASSERT(
node_->kind() == aten::cat, "TODO: generalize logic");
// When we partially evaluate a list of Tensors like cat(tensor[])
// We have a few problems:
// - optimizing out calls to the length of the list: len(tensors)
@@ -635,6 +642,19 @@ struct SymbolicShapeGraphAnalyzer {
if (curr->kind() == prim::Constant) {
continue;
}
// TODO: generalize logic to other tensor input ops when they are
// added
if (curr->kind() == prim::ListConstruct) {
auto uses = curr->output()->uses();
if (!std::all_of(uses.begin(), uses.end(), [](const Use& use) {
return use.user->kind() == aten::cat;
})) {
GRAPH_DEBUG("Non cat list use ", getHeader(curr));
return c10::nullopt;
}
continue;
}
if (!partial_evaluated_graphs.count(curr)) {
GRAPH_DEBUG("No graph ", getHeader(curr));
return c10::nullopt;
@@ -757,31 +777,48 @@
// cleanup the graph and remove the unneeded complete shapes as outputs,
// leaving us only compute for calculating the runtime value of symbolic
// dimensions
std::vector<Value*> inputs;
// NB: node can have more inputs than the shape graph
// so iterate on the # of shape graph inputs
for (size_t i = 0; i < partial_eval_graph->inputs().size(); ++i) {
auto node_input = curr->input(i);
std::vector<Value*> node_inputs;
// TODO: generalize logic
if (curr->kind() == aten::cat) {
TORCH_INTERNAL_ASSERT(
curr->input(0)->node()->kind() == prim::ListConstruct);
for (Value* v : curr->input(0)->node()->inputs()) {
node_inputs.push_back(v);
}
node_inputs.push_back(curr->namedInput("dim"));
} else {
for (size_t i = 0; i < partial_eval_graph->inputs().size(); ++i) {
node_inputs.push_back(curr->input(i));
}
}
std::vector<Value*> partial_eval_inputs;
for (size_t i = 0; i < node_inputs.size(); ++i) {
auto node_input = node_inputs[i];
auto existing_graph_mapping =
enclosing_graph_value_to_shape_graph_input_.find(curr->input(i));
enclosing_graph_value_to_shape_graph_input_.find(node_input);
if (existing_graph_mapping !=
enclosing_graph_value_to_shape_graph_input_.end()) {
inputs.push_back(existing_graph_mapping->second);
partial_eval_inputs.push_back(existing_graph_mapping->second);
} else {
Value* shape_graph_input =
stitched_shape_compute_graph->addInput()->copyMetadata(
partial_eval_graph->inputs().at(i));
enclosing_graph_value_to_shape_graph_input_[node_input] =
shape_graph_input;
inputs.push_back(shape_graph_input);
partial_eval_inputs.push_back(shape_graph_input);
}
}
WithInsertPoint guard(stitched_shape_compute_graph->block());
std::unordered_map<Value*, Value*> value_map;
insertGraph(
*stitched_shape_compute_graph, *partial_eval_graph, inputs, value_map);
*stitched_shape_compute_graph,
*partial_eval_graph,
partial_eval_inputs,
value_map);
for (size_t i = 0; i < curr->outputs().size(); ++i) {
Value* new_list_output = value_map[partial_eval_graph->outputs().at(i)];