mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add RefineTypes JIT pass for Tuple (#76919)
Consider the following JIT graph, where the type of `%a` and `%b` are out of sync with tuple `%c`.
Before:
```
graph(%a : Float(123), %b : Float(4, 5, 6)):
c : (Tensor, Tensor) = prim::TupleConstruct(%a, %b)
return (%c)
```
After:
```
graph(%a : Float(123), %b : Float(4, 5, 6)):
c : (Float(123), Float(4, 5, 6)) = prim::TupleConstruct(%a, %b)
return (%c)
```
This PR adds a pass `RefineTypes(...)` to update all such instances with the correct type. This is also available via Python by using `torch._C._jit_pass_refine_types(...)`.
A unit test has been added for unnamed tuples, but no test exists for `NamedTuple` (though it was tested manually) since it isn't supported by the parser:
```
RuntimeError:
unknown type specifier:
graph(%a : Float(123), %b : Float(4, 5, 6)):
%c : NamedTuple(Tensor : Tuple, Tensor : Tuple) = prim::TupleConstruct(%a, %b)
~~~~~~~~~~ <--- HERE
return (%c)
```
cc: @ke1337 @antoniojkim @wconstab @eellison
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76919
Approved by: https://github.com/eellison
This commit is contained in:
parent
2881e0ea17
commit
f6eb811786
|
|
@ -11252,6 +11252,21 @@ dedent """
|
|||
self.run_pass("erase_number_types", graph)
|
||||
FileCheck().check_not("int = prim::Constant").run(str(graph))
|
||||
|
||||
def test_refine_tuple_types(self):
|
||||
# TupleConstruct output type is not correct here.
|
||||
graph_str = """
|
||||
graph(%a : Float(123), %b : Float(4, 5, 6)):
|
||||
%c : (Tensor, Tensor) = prim::TupleConstruct(%a, %b)
|
||||
return (%c)
|
||||
"""
|
||||
graph = parse_ir(graph_str)
|
||||
torch._C._jit_pass_refine_tuple_types(graph)
|
||||
|
||||
# After the pass, the output type should've been updated.
|
||||
self.assertTrue('(Float(123), Float(4, 5, 6))' in str(graph.findNode('prim::TupleConstruct').output()))
|
||||
|
||||
# TODO(henrytu): Add test for RefineTypes for NamedTuple when it's supported by IR parser.
|
||||
|
||||
def test_remove_dropout(self):
|
||||
weight_0_shape = (20, 5)
|
||||
weight_1_shape = (20, 20)
|
||||
|
|
|
|||
|
|
@ -257,6 +257,7 @@ core_sources_full_mobile_no_backend_interface = [
|
|||
"torch/csrc/jit/passes/peephole.cpp",
|
||||
"torch/csrc/jit/passes/peephole_non_tensor.cpp",
|
||||
"torch/csrc/jit/passes/create_functional_graphs.cpp",
|
||||
"torch/csrc/jit/passes/refine_tuple_types.cpp",
|
||||
"torch/csrc/jit/passes/remove_mutation.cpp",
|
||||
"torch/csrc/jit/passes/prepack_folding.cpp",
|
||||
"torch/csrc/jit/passes/fold_conv_bn.cpp",
|
||||
|
|
|
|||
42
torch/csrc/jit/passes/refine_tuple_types.cpp
Normal file
42
torch/csrc/jit/passes/refine_tuple_types.cpp
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
#include <torch/csrc/jit/passes/refine_tuple_types.h>
|
||||
#include <torch/csrc/jit/runtime/graph_iterator.h>
|
||||
|
||||
#include <ATen/core/type_factory.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
namespace {
|
||||
static void VisitTupleNode(Node* node) {
|
||||
TORCH_CHECK(
|
||||
node->outputs().size() == 1, "Tuple must have exactly one output!");
|
||||
|
||||
Value* output = node->outputs()[0];
|
||||
auto tuple_type = output->type()->expectRef<TupleType>();
|
||||
|
||||
TORCH_CHECK(
|
||||
tuple_type.containedTypes().size() == node->inputs().size(),
|
||||
"Number of contained types does not match number of inputs!");
|
||||
|
||||
// Extract updated types from input values.
|
||||
std::vector<c10::TypePtr> types;
|
||||
for (const Value* input : node->inputs()) {
|
||||
types.push_back(input->type());
|
||||
}
|
||||
|
||||
// Construct new tuple type based on input types.
|
||||
output->setType(tuple_type.withContained(types));
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
void RefineTupleTypes(std::shared_ptr<Graph>& graph) {
|
||||
DepthFirstGraphNodeIterator it(graph);
|
||||
for (auto* node = it.next(); node != nullptr; node = it.next()) {
|
||||
if (node->kind() == prim::TupleConstruct) {
|
||||
VisitTupleNode(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
12
torch/csrc/jit/passes/refine_tuple_types.h
Normal file
12
torch/csrc/jit/passes/refine_tuple_types.h
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
#pragma once
|
||||
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
// updates the types of tuples according to the type of their current inputs.
|
||||
TORCH_API void RefineTupleTypes(std::shared_ptr<Graph>& graph);
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
|
@ -58,6 +58,7 @@
|
|||
#include <torch/csrc/jit/passes/quantization/insert_observers.h>
|
||||
#include <torch/csrc/jit/passes/quantization/insert_quant_dequant.h>
|
||||
#include <torch/csrc/jit/passes/quantization/quantization_type.h>
|
||||
#include <torch/csrc/jit/passes/refine_tuple_types.h>
|
||||
#include <torch/csrc/jit/passes/remove_dropout.h>
|
||||
#include <torch/csrc/jit/passes/remove_expands.h>
|
||||
#include <torch/csrc/jit/passes/remove_inplace_ops.h>
|
||||
|
|
@ -899,6 +900,9 @@ void initJITBindings(PyObject* module) {
|
|||
.def(
|
||||
"_jit_pass_remove_dropout",
|
||||
[](script::Module& module) { return removeDropout(module); })
|
||||
.def(
|
||||
"_jit_pass_refine_tuple_types",
|
||||
[](std::shared_ptr<Graph>& graph) { return RefineTupleTypes(graph); })
|
||||
.def(
|
||||
"_jit_pass_transform_conv1d_to_conv2d",
|
||||
[](std::shared_ptr<Graph>& graph) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user