Support tuples in ScriptModule inputs/outputs (#20784)
Summary:
- [x] Add tests after https://github.com/pytorch/pytorch/pull/20256 is merged
- Support exporting ScriptModules with inputs/outputs of arbitrarily constructed tuples (a minimal export sketch follows below).
- Moved the assignment of output shapes to after the graph's conversion to ONNX has completed. By then, all tuples in the IR have already been lowered by the `_jit_pass_lower_all_tuples` pass. If output shapes had to be assigned before that, we would need to hand-parse the tuple structures in the graph and repeat the same logic as `_jit_pass_lower_all_tuples`. Handling inputs is easier, because all tuple information is encoded within the input tensor type.
- Swapped the order of `_jit_pass_lower_all_tuples` and `_jit_pass_erase_number_types`. Ops like `prim::TupleIndex` rely on the index being a scalar, and `_jit_pass_erase_number_types` converts such scalars to tensors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/20784
Reviewed By: zrphercule
Differential Revision: D15484171
Pulled By: houseroad
fbshipit-source-id: 4767a84038244c929f5662758047af6cb92228d3
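As a quick illustration of what this change enables, here is a minimal export sketch modeled on the new test below. `tuple_model.onnx` is a placeholder path, and `example_outputs` follows the ScriptModule export API of this era:

```python
import torch

class TupleModel(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, a):
        # type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
        return a

x = (torch.randn(3, 4), torch.randn(4, 3))
# example_outputs is required when exporting a ScriptModule.
torch.onnx.export(TupleModel(), (x,), "tuple_model.onnx", example_outputs=(x,))
```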
This commit is contained in:
parent 4c03ac7ac4
commit a3db2844e1
test/onnx/test_pytorch_onnx_caffe2.py

```diff
@@ -1651,6 +1651,29 @@ class TestCaffe2Backend(unittest.TestCase):
         self.run_model_test(MyModel(), train=False, input=lstm_in, batch_size=3, use_gpu=False)
 
+    def test_tuple_input_output(self):
+        class TupleModel(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, a):
+                # type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
+                return a
+
+        x = (torch.randn(3, 4), torch.randn(4, 3))
+        self.run_model_test(TupleModel(), train=False, input=(x,), batch_size=BATCH_SIZE,
+                            example_outputs=(x,))
+
+    def test_nested_tuple_input_output(self):
+        class NestedTupleModel(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, a, b):
+                # type: (Tensor, Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor
+                return a + b[0] + b[1][0] + b[1][1]
+
+        x = torch.randn(4, 5)
+        y = (torch.randn(4, 5), (torch.randn(4, 5), torch.randn(4, 5)))
+        self.run_model_test(NestedTupleModel(), train=False, input=(x, y), batch_size=BATCH_SIZE,
+                            example_outputs=x + y[0] + y[1][0] + y[1][1])
+
     def test_topk(self):
         class TopKModel(torch.nn.Module):
             def forward(self, input):
```
torch/csrc/jit/script/init.cpp

```diff
@@ -178,11 +178,58 @@ static Self moduleSelf(
   };
 }
 
-static void setInputTensorTypes(Graph& g, const Stack& stack) {
-  AT_ASSERT(stack.size() == g.inputs().size());
-  for (size_t i = 0; i < stack.size(); ++i) {
-    g.inputs().at(i)->setType(
-        DimensionedTensorType::create(stack.at(i).toTensor()));
+static TypePtr getTensorType(
+    const at::Tensor& t,
+    const TypeKind type_kind) {
+  switch (type_kind) {
+    case TypeKind::DimensionedTensorType:
+      return DimensionedTensorType::create(t);
+    case TypeKind::CompleteTensorType: {
+      auto scalar_type = t.scalar_type();
+      auto sizes = t.sizes();
+      return CompleteTensorType::create(scalar_type, at::kCPU, sizes);
+    }
+    default:
+      throw std::runtime_error(
+          "Attempted to call getTensorType for type kind other than DimensionedTensorType or CompleteTensorType.");
+  }
+}
+
+static TupleTypePtr getTupleTensorType(
+    const Stack::const_iterator& s_iter,
+    const Stack::const_iterator& s_iter_end,
+    const TypePtr& tupleType,
+    const TypeKind type_kind) {
+  AT_ASSERT(tupleType->kind() == TupleType::Kind);
+  AT_ASSERT(s_iter != s_iter_end);
+
+  std::vector<TypePtr> types;
+  for (const auto& subType : tupleType->containedTypes()) {
+    if (subType->kind() == TupleType::Kind) {
+      types.push_back(getTupleTensorType(s_iter+1, s_iter_end, subType, type_kind));
+    } else {
+      types.push_back(getTensorType(s_iter->toTensor(), type_kind));
+    }
+  }
+  return TupleType::create(types);
+}
+
+static void setInputTensorTypes(
+    Graph& g,
+    const Stack& stack,
+    const TypeKind type_kind = TypeKind::DimensionedTensorType) {
+  at::ArrayRef<Value*> input_values = g.inputs();
+  auto s_iter = stack.begin();
+  for (auto v : input_values) {
+    AT_ASSERT(s_iter != stack.end());
+    if (v->type()->kind() == TupleType::Kind) {
+      AT_ASSERT(v->node()->kind() == prim::Param);
+      v->setType(
+          getTupleTensorType(s_iter, stack.end(), v->type(), type_kind));
+    } else {
+      v->setType(getTensorType(s_iter->toTensor(), type_kind));
+      s_iter++;
+    }
   }
 }
 
```
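For intuition: `getTupleTensorType` recurses into the tuple structure of the declared input type while `setInputTensorTypes` consumes one tensor from the flattened input stack per tuple leaf. A rough Python analogue (hypothetical helper names, not part of the patch):

```python
import torch

def get_tensor_type(t, complete):
    # Analogue of getTensorType: "dimensioned" keeps dtype + rank,
    # "complete" keeps dtype + concrete sizes.
    return (t.dtype, tuple(t.shape)) if complete else (t.dtype, t.dim())

def get_tuple_tensor_type(flat, structure, complete):
    # Analogue of getTupleTensorType: recurse into nested tuples,
    # consuming one flat tensor per leaf, left to right.
    types = []
    for sub in structure:
        if isinstance(sub, tuple):
            types.append(get_tuple_tensor_type(flat, sub, complete))
        else:
            types.append(get_tensor_type(flat.pop(0), complete))
    return tuple(types)

# A nested input (t1, (t2, t3)) arrives flattened as [t1, t2, t3]:
flat = [torch.randn(2, 3), torch.randn(4), torch.randn(5, 6)]
print(get_tuple_tensor_type(flat, ("T", ("T", "T")), complete=True))
```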
```diff
@@ -197,38 +244,32 @@ static std::shared_ptr<Graph> _propagate_shapes(
   return retval;
 }
 
-static std::shared_ptr<Graph> _propagate_and_assign_input_and_output_shapes(
+static std::shared_ptr<Graph> _propagate_and_assign_input_shapes(
     Graph& graph,
-    std::vector<at::Tensor> inputs,
-    std::vector<at::Tensor> outputs,
+    const std::vector<at::Tensor>& inputs,
     bool with_grad = false,
     bool propagate = true) {
   auto retval = graph.copy();
   if (propagate) {
-    setInputTensorTypes(*retval, fmap<IValue>(inputs));
+    setInputTensorTypes(*retval, fmap<IValue>(inputs), TypeKind::DimensionedTensorType);
     PropagateInputShapes(retval);
   }
-  AT_ASSERT(retval->inputs().size() == inputs.size());
-  for (size_t i = 0; i < retval->inputs().size(); ++i) {
-    auto scalar_type = inputs[i].scalar_type();
-    auto sizes = inputs[i].sizes();
-    auto type =
-        torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
-    retval->inputs()[i]->setType(type);
-  }
-  at::ArrayRef<Value*> output_values = retval->outputs();
-  // patch this to still work if we are returning a tuple of multiple values
-  if (output_values.at(0)->type()->kind() == TupleType::Kind) {
-    AT_ASSERT(output_values.at(0)->node()->kind() == prim::TupleConstruct);
-    output_values = output_values.at(0)->node()->inputs();
-  }
-  AT_ASSERT(output_values.size() == outputs.size());
+  setInputTensorTypes(*retval, fmap<IValue>(inputs), TypeKind::CompleteTensorType);
+
+  return retval;
+}
+
+static std::shared_ptr<Graph> _assign_output_shapes(
+    Graph& graph,
+    std::vector<at::Tensor> outputs) {
+  auto retval = graph.copy();
+  AT_ASSERT(retval->outputs().size() == outputs.size());
   for (size_t i = 0; i < outputs.size(); ++i) {
     auto scalar_type = outputs[i].scalar_type();
     auto sizes = outputs[i].sizes();
     auto type =
         torch::jit::CompleteTensorType::create(scalar_type, at::kCPU, sizes);
-    output_values[i]->setType(type);
+    retval->outputs()[i]->setType(type);
   }
   return retval;
 }
 
```
```diff
@@ -679,8 +720,11 @@ void initJitScriptBindings(PyObject* module) {
       debugSetAutodiffSubgraphInlining);
   m.def("_propagate_shapes", _propagate_shapes);
   m.def(
-      "_propagate_and_assign_input_and_output_shapes",
-      _propagate_and_assign_input_and_output_shapes);
+      "_propagate_and_assign_input_shapes",
+      _propagate_and_assign_input_shapes);
+  m.def(
+      "_assign_output_shapes",
+      _assign_output_shapes);
   m.def("_jit_python_print", [](py::object obj) {
     std::ostringstream ss;
     std::vector<at::Tensor> constants;
```
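These bindings make the two-phase flow visible from Python: input shapes are stamped onto the (possibly tuple-typed) graph before conversion, and output shapes are stamped afterwards. A minimal sketch using internal APIs, for illustration only (a single-tensor output keeps the example simple, since `_assign_output_shapes` expects flattened outputs):

```python
import torch

@torch.jit.script
def f(x, y):
    return x + y

args = (torch.randn(2, 3), torch.randn(2, 3))
in_vars, _ = torch.jit._flatten(args)
g = torch._C._propagate_and_assign_input_shapes(f.graph, tuple(in_vars), False, True)
# ... ONNX conversion and _jit_pass_lower_all_tuples would run here ...
out_vars, _ = torch.jit._flatten((f(*args),))
g = torch._C._assign_output_shapes(g, out_vars)
```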
torch/onnx/utils.py

```diff
@@ -16,7 +16,7 @@ import warnings
 from torch._six import string_classes
 from torch.jit import _unique_state_dict
 from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes
-from torch._C import ListType, _propagate_and_assign_input_and_output_shapes
+from torch._C import ListType, _propagate_and_assign_input_shapes, _assign_output_shapes
 
 
 # the flag to tell the user whether it's in the middle of ONNX export or not
```
```diff
@@ -214,10 +214,10 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa
 
     # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
     torch._C._jit_pass_prepare_division_for_onnx(graph)
-    # onnx only supports tensors, so we turn all out number types into tensors
-    torch._C._jit_pass_erase_number_types(graph)
     # onnx does not support tuples, so try to remove them
     torch._C._jit_pass_lower_all_tuples(graph)
+    # onnx only supports tensors, so we turn all out number types into tensors
+    torch._C._jit_pass_erase_number_types(graph)
     torch._C._jit_pass_peephole(graph, True)
     torch._C._jit_pass_lint(graph)
 
```
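The motivation for the reordering can be seen in a scripted graph: constant tuple indexing compiles to `prim::TupleIndex`, whose index input is a plain scalar `int`. If `_jit_pass_erase_number_types` ran first, that scalar would become a tensor and the tuple-lowering pass could no longer resolve the index. A small probe (illustrative; the exact graph text may vary by version):

```python
import torch

@torch.jit.script
def second(t):
    # type: (Tuple[Tensor, Tensor]) -> Tensor
    return t[1]

# Printing the graph should show a prim::TupleIndex node fed by a
# scalar int constant -- the value that _jit_pass_lower_all_tuples
# needs to still be a scalar when it runs.
print(second.graph)
```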
```diff
@@ -288,16 +288,18 @@ def _model_to_graph(model, args, verbose=False, training=False,
         assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule"
         try:
             method_graph, params = model.forward._lowered_graph()
-            graph = _propagate_and_assign_input_and_output_shapes(
-                method_graph, tuple(args) + tuple(params), example_outputs, False, propagate)
+            in_vars, in_desc = torch.jit._flatten(tuple(args) + tuple(params))
+            graph = _propagate_and_assign_input_shapes(
+                method_graph, tuple(in_vars), False, propagate)
         except AttributeError:
             raise RuntimeError('\'forward\' method must be a script method')
     elif isinstance(model, torch.jit.Function):
         assert example_outputs is not None, "example_outputs must be provided when exporting a TorchScript Function"
         method = model
         params = ()
-        graph = _propagate_and_assign_input_and_output_shapes(
-            model.graph, tuple(args), example_outputs, False, propagate)
+        in_vars, in_desc = torch.jit._flatten(tuple(args))
+        graph = _propagate_and_assign_input_shapes(
+            model.graph, tuple(in_vars), False, propagate)
     else:
         graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
         state_dict = _unique_state_dict(model)
```
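`torch.jit._flatten` is what lets the renamed function take only flat tensors: it unrolls an arbitrarily nested tuple into a flat tuple of tensors plus a structure descriptor. A quick illustration (the expected count is an assumption based on the nesting):

```python
import torch

x = torch.randn(4, 5)
y = (torch.randn(4, 5), (torch.randn(4, 5), torch.randn(4, 5)))
flat, desc = torch.jit._flatten((x, y))
print(len(flat))  # 4: one entry per tensor leaf, regardless of nesting
```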
```diff
@@ -313,6 +315,10 @@ def _model_to_graph(model, args, verbose=False, training=False,
     graph = _optimize_graph(graph, operator_export_type,
                             _disable_torch_constant_prop=_disable_torch_constant_prop)
 
+    if isinstance(model, torch.jit.ScriptModule) or isinstance(model, torch.jit.Function):
+        out_vars, _ = torch.jit._flatten(tuple(example_outputs))
+        graph = _assign_output_shapes(graph, out_vars)
+
     # NB: ONNX requires complete information about output types, which might be
     # erased by some optimizations, so we need to set it explicitly again.
     if torch_out is not None:
```