Add RefineTypes JIT pass for Tuple (#76919)

Consider the following JIT graph, where the types of the inputs `%a` and `%b` are out of sync with the declared type of the tuple `%c`.
Before:
```
graph(%a : Float(123), %b : Float(4, 5, 6)):
    %c : (Tensor, Tensor) = prim::TupleConstruct(%a, %b)
    return (%c)
```
After:
```
graph(%a : Float(123), %b : Float(4, 5, 6)):
    %c : (Float(123), Float(4, 5, 6)) = prim::TupleConstruct(%a, %b)
    return (%c)
```
This PR adds a pass `RefineTypes(...)` to update all such instances with the correct type. This is also available via Python by using `torch._C._jit_pass_refine_types(...)`.

A unit test has been added for unnamed tuples, but no test exists for `NamedTuple` (though it was tested manually) since it isn't supported by the parser:
```
RuntimeError:
unknown type specifier:

        graph(%a : Float(123), %b : Float(4, 5, 6)):
          %c : NamedTuple(Tensor : Tuple, Tensor : Tuple) = prim::TupleConstruct(%a, %b)
               ~~~~~~~~~~ <--- HERE
          return (%c)
```

cc: @ke1337 @antoniojkim @wconstab @eellison
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76919
Approved by: https://github.com/eellison
This commit is contained in:
Henry Tu 2022-05-12 00:48:39 +00:00 committed by PyTorch MergeBot
parent 2881e0ea17
commit f6eb811786
5 changed files with 74 additions and 0 deletions

View File

@ -11252,6 +11252,21 @@ dedent """
self.run_pass("erase_number_types", graph)
FileCheck().check_not("int = prim::Constant").run(str(graph))
def test_refine_tuple_types(self):
# TupleConstruct output type is not correct here.
graph_str = """
graph(%a : Float(123), %b : Float(4, 5, 6)):
%c : (Tensor, Tensor) = prim::TupleConstruct(%a, %b)
return (%c)
"""
graph = parse_ir(graph_str)
torch._C._jit_pass_refine_tuple_types(graph)
# After the pass, the output type should've been updated.
self.assertTrue('(Float(123), Float(4, 5, 6))' in str(graph.findNode('prim::TupleConstruct').output()))
# TODO(henrytu): Add test for RefineTypes for NamedTuple when it's supported by IR parser.
def test_remove_dropout(self):
weight_0_shape = (20, 5)
weight_1_shape = (20, 20)

View File

@ -257,6 +257,7 @@ core_sources_full_mobile_no_backend_interface = [
"torch/csrc/jit/passes/peephole.cpp",
"torch/csrc/jit/passes/peephole_non_tensor.cpp",
"torch/csrc/jit/passes/create_functional_graphs.cpp",
"torch/csrc/jit/passes/refine_tuple_types.cpp",
"torch/csrc/jit/passes/remove_mutation.cpp",
"torch/csrc/jit/passes/prepack_folding.cpp",
"torch/csrc/jit/passes/fold_conv_bn.cpp",

View File

@ -0,0 +1,42 @@
#include <torch/csrc/jit/passes/refine_tuple_types.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>
#include <ATen/core/type_factory.h>
namespace torch {
namespace jit {
namespace {
// Rewrites the output type of a prim::TupleConstruct node so that the
// tuple's contained types match the (possibly more refined) types of its
// input values.
// Note: `static` is redundant inside an anonymous namespace, so it is
// dropped here.
void VisitTupleNode(Node* node) {
  TORCH_CHECK(
      node->outputs().size() == 1, "Tuple must have exactly one output!");
  Value* output = node->output();
  // Bind by const reference to avoid copying the TupleType.
  const auto& tuple_type = output->type()->expectRef<TupleType>();
  TORCH_CHECK(
      tuple_type.containedTypes().size() == node->inputs().size(),
      "Number of contained types does not match number of inputs!");
  // Extract updated types from input values.
  std::vector<c10::TypePtr> types;
  types.reserve(node->inputs().size());
  for (const Value* input : node->inputs()) {
    types.push_back(input->type());
  }
  // Construct a new tuple type based on the input types.
  output->setType(tuple_type.withContained(std::move(types)));
}
} // anonymous namespace
// Walks every node of `graph` depth-first and refines the output type of
// each prim::TupleConstruct to agree with its current input types.
void RefineTupleTypes(std::shared_ptr<Graph>& graph) {
  DepthFirstGraphNodeIterator graph_it(graph);
  Node* current = graph_it.next();
  while (current != nullptr) {
    if (current->kind() == prim::TupleConstruct) {
      VisitTupleNode(current);
    }
    current = graph_it.next();
  }
}
} // namespace jit
} // namespace torch

View File

@ -0,0 +1,12 @@
#pragma once
#include <torch/csrc/jit/ir/ir.h>
namespace torch {
namespace jit {
// Refines the output type of every prim::TupleConstruct in `graph` so that
// the tuple's contained types match the current types of its inputs
// (e.g. a declared (Tensor, Tensor) becomes (Float(123), Float(4, 5, 6))).
TORCH_API void RefineTupleTypes(std::shared_ptr<Graph>& graph);
} // namespace jit
} // namespace torch

View File

@ -58,6 +58,7 @@
#include <torch/csrc/jit/passes/quantization/insert_observers.h>
#include <torch/csrc/jit/passes/quantization/insert_quant_dequant.h>
#include <torch/csrc/jit/passes/quantization/quantization_type.h>
#include <torch/csrc/jit/passes/refine_tuple_types.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/remove_expands.h>
#include <torch/csrc/jit/passes/remove_inplace_ops.h>
@ -899,6 +900,9 @@ void initJITBindings(PyObject* module) {
.def(
"_jit_pass_remove_dropout",
[](script::Module& module) { return removeDropout(module); })
.def(
"_jit_pass_refine_tuple_types",
[](std::shared_ptr<Graph>& graph) { return RefineTupleTypes(graph); })
.def(
"_jit_pass_transform_conv1d_to_conv2d",
[](std::shared_ptr<Graph>& graph) {