mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add support for cat in output stitching (#66098)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/66098 `cat` is somewhat special-cased right now because currently we only have list of Tensor inputs where the list is constructed in the JIT IR graph. While that is generally true for Fusion (e.g. why we have ConstantChunk) that may not be true for shape analysis generally, so I'm waiting a bit to generalize. Test Plan: Imported from OSS Reviewed By: navahgar, anjali411 Differential Revision: D31797467 Pulled By: eellison fbshipit-source-id: ca761e214dfd7f3bba8d189f3b3f42ffec064f63
This commit is contained in:
parent
2dd23ebfdb
commit
17889ad26e
|
|
@ -302,20 +302,9 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
|||
for o, oe in zip(output, output_eager[0:1] + output_eager[2:]):
|
||||
self.assertEqual(o, oe)
|
||||
|
||||
def test_partial_eval_stitching(self):
|
||||
conv1 = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
|
||||
max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
|
||||
conv2 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
|
||||
|
||||
mod = torch.jit.freeze(torch.jit.script(nn.Sequential(conv1, max_pool, conv2).eval()))
|
||||
|
||||
conv1_output = conv1(torch.rand(1, 3, 224, 224))
|
||||
max_pool_ouput = max_pool(conv1_output)
|
||||
conv2_output = conv2(max_pool_ouput)
|
||||
|
||||
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph)
|
||||
def checkSymShapeCompute(self, shape_compute_graph, nodes, node_output_sizes, shape_inputs):
|
||||
g = shape_compute_graph.partial_eval_shape_graph()
|
||||
self.assertTrue(len(list(g.inputs())) == 1)
|
||||
self.assertTrue(len(list(g.inputs())) == len(shape_inputs))
|
||||
output_sym_map = shape_compute_graph.graph_output_to_symbolic_shape_dim()
|
||||
# map from sym shape -> index
|
||||
sym_shape_to_index = {}
|
||||
|
|
@ -324,18 +313,32 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
|||
|
||||
g.makeMultiOutputIntoTuple()
|
||||
func = torch._C._create_function_from_graph("partial_eval_graph", g)
|
||||
sym_outputs = func([1, 3, 224, 224])
|
||||
nodes = [mod.graph.findNode("aten::max_pool2d")] + list(mod.graph.findAllNodes("aten::conv2d"))
|
||||
output_shapes = [max_pool_ouput, conv1_output, conv2_output]
|
||||
sym_outputs = func(*shape_inputs)
|
||||
|
||||
for node, output_shape in zip(nodes, output_shapes):
|
||||
for node, output_shape in zip(nodes, node_output_sizes):
|
||||
output_type_sizes = node.output().type().symbolic_sizes()
|
||||
for i, sym_shape in enumerate(output_type_sizes):
|
||||
if sym_shape >= 0:
|
||||
self.assertEqual(sym_shape, output_shape.size(i))
|
||||
self.assertEqual(sym_shape, output_shape[i])
|
||||
else:
|
||||
sym_shape_index = sym_shape_to_index[sym_shape]
|
||||
self.assertEqual(sym_outputs[sym_shape_index], output_shape.size(i))
|
||||
self.assertEqual(sym_outputs[sym_shape_index], output_shape[i])
|
||||
|
||||
def test_partial_eval_stitching(self):
|
||||
conv1 = torch.nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
|
||||
max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
|
||||
conv2 = nn.Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
|
||||
|
||||
mod = torch.jit.freeze(torch.jit.script(nn.Sequential(conv1, max_pool, conv2).eval()))
|
||||
|
||||
conv1_output = conv1(torch.rand(1, 3, 224, 224))
|
||||
max_pool_output = max_pool(conv1_output)
|
||||
conv2_output = conv2(max_pool_output)
|
||||
|
||||
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(mod.graph)
|
||||
nodes = [mod.graph.findNode("aten::max_pool2d")] + list(mod.graph.findAllNodes("aten::conv2d"))
|
||||
output_shapes = [max_pool_output.size(), conv1_output.size(), conv2_output.size()]
|
||||
self.checkSymShapeCompute(shape_compute_graph, nodes, output_shapes, ([1, 3, 224, 224],))
|
||||
|
||||
def test_refinement_through_graph_stitching(self):
|
||||
class TwoConvs(torch.nn.Module):
|
||||
|
|
@ -386,3 +389,20 @@ class TestSymbolicShapeAnalysis(JitTestCase):
|
|||
func = torch._C._create_function_from_graph("partial_eval_graph", g)
|
||||
output_shape = func(tensor.size())
|
||||
self.assertEqual(list(output_shape), list(mod(tensor)[0].size()))
|
||||
|
||||
def test_stitching_concat(self):
|
||||
@torch.jit.script
|
||||
def foo(a, b, x, y):
|
||||
return (a / b) + torch.cat([x, y])
|
||||
|
||||
g = foo.graph
|
||||
for inp in foo.graph.inputs():
|
||||
inp.setType(inp.type().with_sizes([None, None]))
|
||||
|
||||
shape_compute_graph = torch._C._jit_pass_propagate_shapes_on_graph_and_build_compute(foo.graph)
|
||||
nodes = [g.findNode("aten::div")] + [g.findNode("aten::add")] + [g.findNode("aten::cat")]
|
||||
|
||||
inps = [1, 10], [20, 10], [15, 1], [5, 1]
|
||||
output_shapes = [[20, 10], [20, 10], [20, 1]]
|
||||
|
||||
self.checkSymShapeCompute(shape_compute_graph, nodes, output_shapes, inps)
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@
|
|||
#include <torch/csrc/jit/runtime/exception_message.h>
|
||||
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
|
||||
#include <torch/csrc/utils/memory.h>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <unordered_map>
|
||||
|
|
@ -151,6 +152,11 @@ bool isListOfInts(const TypePtr& type) {
|
|||
type->cast<ListType>()->getElementType()->cast<IntType>();
|
||||
}
|
||||
|
||||
bool isListOfTensors(const TypePtr& type) {
|
||||
return type->cast<ListType>() &&
|
||||
type->cast<ListType>()->getElementType()->cast<TensorType>();
|
||||
}
|
||||
|
||||
c10::optional<size_t> normIndex(int64_t index, size_t len) {
|
||||
if (index < 0) {
|
||||
index = index + len;
|
||||
|
|
@ -237,9 +243,10 @@ struct SymbolicShapeNodeAnalyzer {
|
|||
|
||||
if (auto tt = type->castRaw<TensorType>()) {
|
||||
addTensorInputMetaData(node_->input(node_index), graph_index);
|
||||
} else if (
|
||||
type->cast<ListType>() &&
|
||||
type->cast<ListType>()->getElementType()->cast<TensorType>()) {
|
||||
} else if (isListOfTensors(type)) {
|
||||
// waiting for more use cases to decide on best generalization
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
node_->kind() == aten::cat, "TODO: generalize logic");
|
||||
// When we have partially evaluate a list of Tensors like cat(tensor[])
|
||||
// We have a few problems:
|
||||
// - optimizing out calls to the length of the list: len(tensors)
|
||||
|
|
@ -635,6 +642,19 @@ struct SymbolicShapeGraphAnalyzer {
|
|||
if (curr->kind() == prim::Constant) {
|
||||
continue;
|
||||
}
|
||||
// TODO: generalize logic to for other tensor input ops when they are
|
||||
// added
|
||||
if (curr->kind() == prim::ListConstruct) {
|
||||
auto uses = curr->output()->uses();
|
||||
if (!std::all_of(uses.begin(), uses.end(), [](const Use& use) {
|
||||
return use.user->kind() == aten::cat;
|
||||
})) {
|
||||
GRAPH_DEBUG("Non cat list use ", getHeader(curr));
|
||||
return c10::nullopt;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!partial_evaluated_graphs.count(curr)) {
|
||||
GRAPH_DEBUG("No graph ", getHeader(curr));
|
||||
return c10::nullopt;
|
||||
|
|
@ -757,31 +777,48 @@ struct SymbolicShapeGraphAnalyzer {
|
|||
// cleanup the graph and remove the unneeded complete shapes as outputs,
|
||||
// leaving us only compute for calculating the runtime value of symbolic
|
||||
// dimensions
|
||||
// leaving us only compute for calculating the runtime value of symbolic
|
||||
// dimensions
|
||||
|
||||
std::vector<Value*> inputs;
|
||||
// NB: node can have more inputs than the shape graph
|
||||
// so iterate on the # of shape graph inputs
|
||||
for (size_t i = 0; i < partial_eval_graph->inputs().size(); ++i) {
|
||||
auto node_input = curr->input(i);
|
||||
std::vector<Value*> node_inputs;
|
||||
// TODO: generalize logic
|
||||
if (curr->kind() == aten::cat) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
curr->input(0)->node()->kind() == prim::ListConstruct);
|
||||
for (Value* v : curr->input(0)->node()->inputs()) {
|
||||
node_inputs.push_back(v);
|
||||
}
|
||||
node_inputs.push_back(curr->namedInput("dim"));
|
||||
} else {
|
||||
for (size_t i = 0; i < partial_eval_graph->inputs().size(); ++i) {
|
||||
node_inputs.push_back(curr->input(i));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Value*> partial_eval_inputs;
|
||||
for (size_t i = 0; i < node_inputs.size(); ++i) {
|
||||
auto node_input = node_inputs[i];
|
||||
auto existing_graph_mapping =
|
||||
enclosing_graph_value_to_shape_graph_input_.find(curr->input(i));
|
||||
enclosing_graph_value_to_shape_graph_input_.find(node_input);
|
||||
if (existing_graph_mapping !=
|
||||
enclosing_graph_value_to_shape_graph_input_.end()) {
|
||||
inputs.push_back(existing_graph_mapping->second);
|
||||
partial_eval_inputs.push_back(existing_graph_mapping->second);
|
||||
} else {
|
||||
Value* shape_graph_input =
|
||||
stitched_shape_compute_graph->addInput()->copyMetadata(
|
||||
partial_eval_graph->inputs().at(i));
|
||||
enclosing_graph_value_to_shape_graph_input_[node_input] =
|
||||
shape_graph_input;
|
||||
inputs.push_back(shape_graph_input);
|
||||
partial_eval_inputs.push_back(shape_graph_input);
|
||||
}
|
||||
}
|
||||
|
||||
WithInsertPoint guard(stitched_shape_compute_graph->block());
|
||||
std::unordered_map<Value*, Value*> value_map;
|
||||
insertGraph(
|
||||
*stitched_shape_compute_graph, *partial_eval_graph, inputs, value_map);
|
||||
*stitched_shape_compute_graph,
|
||||
*partial_eval_graph,
|
||||
partial_eval_inputs,
|
||||
value_map);
|
||||
|
||||
for (size_t i = 0; i < curr->outputs().size(); ++i) {
|
||||
Value* new_list_output = value_map[partial_eval_graph->outputs().at(i)];
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user