Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[ONNX] Handle sequence output for models (#50599)

Summary: Duplicate of https://github.com/pytorch/pytorch/issues/46542
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50599
Reviewed By: SplitInfinity
Differential Revision: D25928897
Pulled By: bzinodev
fbshipit-source-id: a898cef7b2d15a287aedd9798ce1423cebf378d4

Parent: c082e2184d
Commit: 137f2a385a
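In short, this change lets the exporter emit ONNX Sequence outputs for models that return lists of tensors, instead of requiring a fixed number of tensor outputs. A minimal sketch of the user-facing effect (the model name and split sizes are illustrative, not taken from the PR):

```python
import io
import torch

class SplitModel(torch.nn.Module):
    def forward(self, x):
        # Under scripting, split with a list of sizes returns List[Tensor],
        # which this change exports as an ONNX Sequence output.
        return x.split([2, 1, 2])

model = torch.jit.script(SplitModel())
x = torch.randn(5, 4, 3)
f = io.BytesIO()
# example_outputs was required when exporting a ScriptModule at the time
# of this PR; opset >= 11 is needed for SplitToSequence.
torch.onnx.export(model, (x,), f, opset_version=11,
                  example_outputs=model(x))
```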
@@ -56,6 +56,7 @@ pytest "${args[@]}" \
   --ignore "$top_dir/test/onnx/test_custom_ops.py" \
   --ignore "$top_dir/test/onnx/test_models_onnxruntime.py" \
   --ignore "$top_dir/test/onnx/test_utility_funs.py" \
+  --ignore "$top_dir/test/onnx/test_pytorch_onnx_caffe2.py" \
   --ignore "$top_dir/test/onnx/test_pytorch_onnx_shape_inference.py" \
   "${test_paths[@]}"
@@ -68,7 +69,8 @@ if [[ "$BUILD_ENVIRONMENT" == *ort_test1* ]]; then
   "$top_dir/test/onnx/test_pytorch_onnx_onnxruntime.py::TestONNXRuntime" \
   "$top_dir/test/onnx/test_custom_ops.py" \
   "$top_dir/test/onnx/test_models_onnxruntime.py" \
-  "$top_dir/test/onnx/test_utility_funs.py"
+  "$top_dir/test/onnx/test_utility_funs.py" \
+  "$top_dir/test/onnx/test_pytorch_onnx_caffe2.py"
 fi
 if [[ "$BUILD_ENVIRONMENT" == *ort_test2* ]]; then
   # Update the loop for new opsets
@@ -146,7 +146,7 @@ graph {
       elem_type: 1
       shape {
         dim {
-          dim_param: "13_0"
+          dim_param: "Range13_dim_0"
         }
       }
     }
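These expected-output updates reflect a new naming scheme for dynamic dimensions: instead of a bare `<value>_<index>` name such as `13_0`, the exporter now prefixes the producing op type and inserts `_dim_`, yielding names like `Range13_dim_0`. A sketch of how to observe this after export (the model and printed names are illustrative; exact names depend on the exporter version):

```python
import io
import onnx
import torch

class UniqueModel(torch.nn.Module):
    def forward(self, x):
        # Output length depends on the data, so the exported graph
        # gets a symbolic (dynamic) output dimension.
        return torch.unique(x)

f = io.BytesIO()
torch.onnx.export(UniqueModel(), (torch.tensor([1., 1., 2., 3.]),), f,
                  opset_version=11)
m = onnx.load_from_string(f.getvalue())
for out in m.graph.output:
    dims = [d.dim_param or str(d.dim_value)
            for d in out.type.tensor_type.shape.dim]
    print(out.name, dims)  # e.g. something like ['Unique1_dim_0']
```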
@@ -47,10 +47,10 @@ graph {
       elem_type: 1
       shape {
         dim {
-          dim_param: "2_0"
+          dim_param: "ConstantOfShape2_dim_0"
         }
         dim {
-          dim_param: "2_1"
+          dim_param: "ConstantOfShape2_dim_1"
         }
       }
     }
@@ -137,10 +137,10 @@ graph {
       elem_type: 1
       shape {
         dim {
-          dim_param: "10_0"
+          dim_param: "ConstantOfShape10_dim_0"
         }
         dim {
-          dim_param: "10_1"
+          dim_param: "ConstantOfShape10_dim_1"
         }
       }
     }
@@ -47,10 +47,10 @@ graph {
       elem_type: 1
       shape {
         dim {
-          dim_param: "2_0"
+          dim_param: "ConstantOfShape2_dim_0"
         }
         dim {
-          dim_param: "2_1"
+          dim_param: "ConstantOfShape2_dim_1"
         }
       }
     }
@@ -47,10 +47,10 @@ graph {
       elem_type: 1
       shape {
         dim {
-          dim_param: "2_0"
+          dim_param: "ConstantOfShape2_dim_0"
         }
         dim {
-          dim_param: "2_1"
+          dim_param: "ConstantOfShape2_dim_1"
         }
       }
     }
@@ -67,7 +67,7 @@ graph {
       elem_type: 1
       shape {
         dim {
-          dim_param: "4_0"
+          dim_param: "TopK4_dim_0"
         }
       }
     }
@@ -80,7 +80,7 @@ graph {
       elem_type: 7
       shape {
         dim {
-          dim_param: "5_0"
+          dim_param: "TopK5_dim_0"
         }
       }
     }
@@ -77,7 +77,7 @@ graph {
       elem_type: 1
       shape {
         dim {
-          dim_param: "4_0"
+          dim_param: "TopK4_dim_0"
         }
       }
     }
@@ -90,7 +90,7 @@ graph {
       elem_type: 7
       shape {
         dim {
-          dim_param: "5_0"
+          dim_param: "TopK5_dim_0"
         }
       }
     }
@@ -51,7 +51,7 @@ graph {
       elem_type: 1
       shape {
         dim {
-          dim_param: "1_0"
+          dim_param: "Unique1_dim_0"
         }
         dim {
           dim_value: 3
@@ -73,7 +73,7 @@ graph {
       elem_type: 7
       shape {
         dim {
-          dim_param: "4_0"
+          dim_param: "Unique4_dim_0"
         }
       }
     }
@@ -50,16 +50,16 @@ graph {
       elem_type: 1
       shape {
         dim {
-          dim_param: "4_0"
+          dim_param: "Upsample4_dim_0"
         }
         dim {
-          dim_param: "4_1"
+          dim_param: "Upsample4_dim_1"
         }
         dim {
-          dim_param: "4_2"
+          dim_param: "Upsample4_dim_2"
         }
         dim {
-          dim_param: "4_3"
+          dim_param: "Upsample4_dim_3"
         }
       }
     }
@@ -50,16 +50,16 @@ graph {
       elem_type: 1
       shape {
         dim {
-          dim_param: "4_0"
+          dim_param: "Upsample4_dim_0"
         }
         dim {
-          dim_param: "4_1"
+          dim_param: "Upsample4_dim_1"
         }
         dim {
-          dim_param: "4_2"
+          dim_param: "Upsample4_dim_2"
         }
         dim {
-          dim_param: "4_3"
+          dim_param: "Upsample4_dim_3"
         }
       }
     }
@@ -47,10 +47,10 @@ graph {
       elem_type: 1
       shape {
         dim {
-          dim_param: "2_0"
+          dim_param: "ConstantOfShape2_dim_0"
         }
         dim {
-          dim_param: "2_1"
+          dim_param: "ConstantOfShape2_dim_1"
         }
       }
     }
@@ -58,6 +58,13 @@ def convert_to_onnx(model, input=None, opset_version=9, example_outputs=None,
     ort_sess = onnxruntime.InferenceSession(f.getvalue())
     return ort_sess
 
+
+def inline_flatten_list(inputs, res_list):
+    for i in inputs:
+        res_list.append(i) if not isinstance(i, (list, tuple)) else inline_flatten_list(i, res_list)
+    return res_list
+
+
 def run_ort(ort_sess, input):
     input_copy = copy.deepcopy(input)
     input, _ = torch.jit._flatten(input_copy)
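The new helper flattens arbitrarily nested lists and tuples of ONNX Runtime outputs so they can be compared element by element against the flattened PyTorch outputs. A quick self-contained check of its behavior (restating the helper verbatim so the snippet runs on its own):

```python
def inline_flatten_list(inputs, res_list):
    for i in inputs:
        res_list.append(i) if not isinstance(i, (list, tuple)) else inline_flatten_list(i, res_list)
    return res_list

nested = [1, [2, 3, (4, 5)], 6]
assert inline_flatten_list(nested, []) == [1, 2, 3, 4, 5, 6]
```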
@@ -66,7 +73,8 @@ def run_ort(ort_sess, input):
     ort_inputs = dict((ort_sess.get_inputs()[i].name, input) for i, input in enumerate(inputs))
     ort_outs = ort_sess.run(None, ort_inputs)
 
-    return ort_outs
+    return inline_flatten_list(ort_outs, [])
 
 
 def ort_compare_with_pytorch(ort_outs, output, rtol, atol):
     output, _ = torch.jit._flatten(output)
@@ -115,7 +123,7 @@ def run_model_test(self, model, batch_size=2, state_dict=None,
         output_names=output_names, fixed_batch_size=fixed_batch_size, training=None,
         onnx_shape_inference=self.onnx_shape_inference,
         use_new_jit_passes=self.use_new_jit_passes)
-
+    # compute onnxruntime output prediction
     ort_outs = run_ort(ort_sess, input)
     ort_compare_with_pytorch(ort_outs, output, rtol, atol)
@@ -3591,27 +3599,24 @@ class TestONNXRuntime(unittest.TestCase):
     def test_split(self):
         class SplitModel(torch.nn.Module):
             def forward(self, input):
-                out1, out2, out3 = input.split([2, 1, 2])
-                return out1, out2, out3
+                return input.split([2, 1, 2]), input.split([3, 2])[0]
 
         x = torch.randn(5, 4, 3)
         self.run_test(SplitModel(), x)
 
         class SplitModel2(torch.nn.Module):
             def forward(self, input):
-                out1, out2, out3 = input.split([2, 1, 1], -2)
-                return out1, out2, out3
+                return input.split([2, 1, 1], -2), input.split([2, 2], -2)[-1]
 
         x = torch.randn(5, 4, 3)
         self.run_test(SplitModel2(), x)
 
         class SplitModel3(torch.nn.Module):
             def forward(self, input):
-                out1, out2, out3 = input.split([2, 1, 2])
-                return out3, out1
+                return input.split([2, 1, 2])
 
         x = torch.randn(5, 4, 3)
-        self.run_test(torch.jit.script(SplitModel3()), x)
+        self.run_test(SplitModel3(), x)
 
     @skipIfUnsupportedOpsetVersion([13])
     @skipIfUnsupportedMinOpsetVersion(11)
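These tests now return the result of split(...) directly: under scripting, Tensor.split with a list of sizes yields a List[Tensor], so the exporter must produce a sequence output rather than a fixed tuple of tensors. A small illustration (the function name is hypothetical):

```python
import torch

@torch.jit.script
def split_fn(x: torch.Tensor):
    return x.split([2, 1, 2])

# The TorchScript graph types the output as Tensor[] (a list), which this PR
# maps to an ONNX Sequence instead of a fixed number of tensor outputs.
print(split_fn.graph)
```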
@@ -3769,7 +3774,7 @@ class TestONNXRuntime(unittest.TestCase):
                 res2 += 1
                 res3 = res3 + [arr[i].sum(0, False)]
                 res4 += [arr[-1 - i].sum(0, False)]
-            return torch.stack(res), torch.stack(res1), res2, torch.stack(res3), torch.stack(res4)
+            return res, res1, res2, torch.stack(res3), torch.stack(res4)
 
         model = ListLoopModel()
         inputs = torch.randn(16)
@@ -5508,7 +5513,6 @@ class TestONNXRuntime(unittest.TestCase):
         self.assertRaises(RuntimeError, check_proto)
 
-    @disableScriptTest() # dtype mismatch
     def test_split_tensor_scalar(self):
         class SplitModel(torch.nn.Module):
             def forward(self, x):
@@ -5861,6 +5865,37 @@ class TestONNXRuntime(unittest.TestCase):
         [np.testing.assert_allclose(ort_out1, ort_out2, atol=1e-7, rtol=0.001) for ort_out1, ort_out2 in
          zip(ort_outs1, ort_outs2)]
 
+    def test_script_custom_class_error(self):
+        class BoxCoder(object):
+            def __init__(self, bbox_xform_clip):
+                # type: (float) -> None
+                self.bbox_xform_clip = bbox_xform_clip
+
+            def decode(self, rel_codes, boxes):
+                # type: (Tensor, List[Tensor]) -> Tensor
+                boxes = torch.cat(boxes, dim=0)
+                pred_ctr_x = torch.clamp(rel_codes[:, 0::4], max=self.bbox_xform_clip) * boxes[:, 2]
+                return pred_ctr_x
+
+        class MyModule(torch.nn.Module):
+            __annotations__ = {
+                'box_coder': BoxCoder,
+            }
+
+            def __init__(self):
+                super(MyModule, self).__init__()
+                self.box_coder = BoxCoder(1.4)
+
+            def forward(self, box_regression: torch.Tensor, proposals: List[torch.Tensor]):
+                return self.box_coder.decode(box_regression, proposals)
+
+        model = torch.jit.script(MyModule())
+        box_regression = torch.randn([4, 4])
+        proposal = [torch.randn(2, 4), torch.randn(2, 4)]
+
+        with self.assertRaises(RuntimeError) as cm:
+            self.run_test(model, (box_regression, proposal))
+
     @skipIfUnsupportedOpsetVersion([13])
     def test_initializer_sequence(self):
         class MyModule(torch.nn.Module):
@@ -257,7 +257,7 @@ def _replace_overloaded_method_decl(overload_decl: Decl, implementation_def: Def
 def _jit_pass_lower_all_tuples(graph: Graph) -> None: ...
 def _jit_pass_onnx_set_dynamic_input_shape(graph: Graph, dynamic_axes: Dict[str, Dict[_int, str]], input_names: List[str]) -> None: ...
 def _jit_pass_onnx_graph_shape_type_inference(graph: Graph, opset_version: _int) -> None: ...
-def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], onnx_shape_inference: _bool = False) -> None: ...
+def _jit_pass_onnx_assign_output_shape(graph: Graph, tensors: List[Tensor], desc: IODescriptor, onnx_shape_inference: _bool = False) -> None: ...
 def _jit_pass_onnx_remove_inplace_ops_for_onnx(graph: Graph) -> None: ...
 def _jit_pass_remove_inplace_ops(graph: Graph) -> None: ...
 def _jit_pass_canonicalize_graph_fuser_ops(graph: Graph) -> None: ...
@@ -34,9 +34,10 @@ std::deque<std::string> findSubModuleAttr(
     if (node->kind() == prim::GetAttr) {
       moduleNames.push_front(node->s(attr::name));
       node = node->inputs()[0]->node();
+    } else {
+      return moduleNames;
     }
   }
 
   // Assign the inner module to attrModule.
   for (auto& moduleName : moduleNames) {
     attrModule = attrModule.attr(moduleName).toModule();
@@ -127,24 +128,16 @@ std::vector<IValue> getParamAttributes(
       paramConst = addParamAsArgument(function_, fullName, attr);
     } else if (
         attr.isObject() && !attr.toObjectRef().type()->is_module()) {
-      // Only below registered torch classes are supported.
-      auto type = attr.type();
-      TORCH_CHECK(
-          (type ==
-           getCustomClass(
-               "__torch__.torch.classes.quantized.Conv2dPackedParamsBase")) ||
-              (type ==
-               getCustomClass(
-                   "__torch__.torch.classes.quantized.Conv3dPackedParamsBase")) ||
-              (type ==
-               getCustomClass(
-                   "__torch__.torch.classes.quantized.LinearPackedParamsBase")),
-          "Unknown type ",
-          type->repr_str(),
-          " encountered in handling model params. This type is not supported in ONNX export.");
-      attrValues.emplace_back(
-          script::Object(attr.toObject()).run_method("__getstate__"));
-      paramConst = addParamAsArgument(function_, fullName, attr);
+      try {
+        attrValues.emplace_back(
+            script::Object(attr.toObject()).run_method("__getstate__"));
+        paramConst = addParamAsArgument(function_, fullName, attr);
+      } catch (const std::exception&) {
+        auto type = attr.type();
+        throw ErrorReport(n->sourceRange())
+            << "Unknown type " << type->repr_str()
+            << " encountered in handling model params. This class type does not extend __getstate__ method.";
+      }
     } else if (attr.isNone() || name == "training") {
       auto attrVal = tryInsertConstant(*graph, attr);
       paramConst = *attrVal;
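Instead of allow-listing three quantized packed-params classes, the pass now simply tries `__getstate__` on any non-module object attribute and raises a descriptive error when the class does not define it. A hedged sketch of a TorchScript class that should satisfy this check when held as a module attribute (the class name and fields are illustrative):

```python
from typing import Tuple
import torch

@torch.jit.script
class Stateful(object):
    def __init__(self, val: torch.Tensor):
        self.val = val

    # Defining __getstate__/__setstate__ lets the export pass serialize the
    # object's state instead of rejecting the class outright.
    def __getstate__(self) -> Tuple[torch.Tensor]:
        return (self.val,)

    def __setstate__(self, state: Tuple[torch.Tensor]):
        self.val = state[0]
```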
@@ -4,6 +4,7 @@
 #include <torch/csrc/jit/passes/onnx/fold_if_node.h>
 #include <torch/csrc/jit/passes/onnx/helper.h>
 #include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>
+#include <torch/csrc/jit/python/python_arg_flatten.h>
 #include <torch/csrc/jit/serialization/export.h>
 #include <torch/csrc/jit/serialization/onnx.h>
@@ -552,19 +553,141 @@ void ONNXSetDynamicInputShape(
   }
 }
 
+bool HasSequenceTypeOutput(Node* node) {
+  if (node->kind() == ::c10::onnx::SplitToSequence ||
+      node->kind() == ::c10::onnx::SequenceInsert ||
+      node->kind() == ::c10::onnx::SequenceEmpty ||
+      node->kind() == ::c10::onnx::SequenceErase ||
+      node->kind() == ::c10::onnx::SequenceConstruct)
+    return true;
+  return false;
+}
+
+void ONNXUpdateTypeFromTensor(
+    Value* graph_output,
+    const at::Tensor& output,
+    bool onnx_shape_inference) {
+  if (onnx_shape_inference) {
+    graph_output->setType(
+        MergeInferredType(TensorType::create(output), graph_output->type()));
+  } else {
+    graph_output->inferTypeFrom(output);
+  }
+}
+
 void ONNXAssignOutputShape(
     std::shared_ptr<Graph>& graph,
     at::ArrayRef<at::Tensor> outputs,
+    const python::IODescriptor& desc,
     bool onnx_shape_inference) {
-  TORCH_INTERNAL_ASSERT(graph->outputs().size() == outputs.size());
-  for (size_t i = 0; i < outputs.size(); ++i) {
-    if (onnx_shape_inference) {
-      graph->outputs()[i]->setType(MergeInferredType(
-          TensorType::create(outputs[i]), graph->outputs()[i]->type()));
-    } else {
-      graph->outputs()[i]->inferTypeFrom(outputs[i]);
-    }
+  size_t outputs_index = 0;
+  PyObject* py_obj = unflatten(outputs, desc);
+  TORCH_INTERNAL_ASSERT(PyTuple_Check(py_obj));
+
+  for (size_t i = 0; i < PyTuple_GET_SIZE(py_obj); ++i) {
+    PyObject* elem = PyTuple_GET_ITEM(py_obj, i);
+
+    if (PyList_Check(elem)) {
+      size_t list_len = PyList_GET_SIZE(elem);
+      if (HasSequenceTypeOutput(graph->outputs()[outputs_index]->node())) {
+        if (list_len > 0) {
+          auto& var =
+              reinterpret_cast<THPVariable*>(PyList_GET_ITEM(elem, 0))->cdata;
+          for (size_t j = 1; j < list_len; ++j) {
+            PyObject* list_elem = PyList_GET_ITEM(elem, j);
+            TORCH_INTERNAL_ASSERT(THPVariable_Check(list_elem));
+            auto& new_var = reinterpret_cast<THPVariable*>(list_elem)->cdata;
+            TORCH_CHECK(
+                var.scalar_type() == new_var.scalar_type(),
+                "Unsupported sequence type in model outputs. ONNX supports sequences of elements of the same data type.");
+          }
+          auto elem_type = graph->outputs()[outputs_index]
+                               ->type()
+                               ->cast<ListType>()
+                               ->getElementType()
+                               ->cast<TensorType>();
+          elem_type = elem_type->withScalarType(var.scalar_type());
+          graph->outputs()[outputs_index]->setType(MergeInferredType(
+              graph->outputs()[outputs_index]->type(),
+              ListType::create(elem_type)));
+          outputs_index++;
+          TORCH_INTERNAL_ASSERT(
+              outputs_index <= graph->outputs().size(),
+              "Incorrect number of elements provided as example outputs.");
+        }
+      } else { // When torch output is a list type, but ONNX node is not a
+               // sequence type. Like prim::ListConstruct
+        size_t list_len = PyList_GET_SIZE(elem);
+        if (list_len > 0) {
+          for (size_t j = 0; j < list_len; ++j) {
+            PyObject* list_elem = PyList_GET_ITEM(elem, j);
+            TORCH_INTERNAL_ASSERT(THPVariable_Check(list_elem));
+            auto& var = reinterpret_cast<THPVariable*>(list_elem)->cdata;
+            graph->outputs()[outputs_index + j]->setType(MergeInferredType(
+                graph->outputs()[outputs_index + j]->type(),
+                TensorType::create(var)));
+          }
+          outputs_index += list_len;
+          TORCH_INTERNAL_ASSERT(
+              outputs_index <= graph->outputs().size(),
+              "Incorrect number of elements provided as example outputs.");
+        }
+      }
+    } else if (PyTuple_Check(elem)) {
+      size_t tuple_len = PyTuple_GET_SIZE(elem);
+      if (tuple_len > 0) {
+        for (size_t j = 0; j < tuple_len; ++j) {
+          PyObject* tuple_elem = PyTuple_GET_ITEM(elem, j);
+          TORCH_INTERNAL_ASSERT(THPVariable_Check(tuple_elem));
+          auto& var = reinterpret_cast<THPVariable*>(tuple_elem)->cdata;
+          graph->outputs()[outputs_index + j]->setType(MergeInferredType(
+              graph->outputs()[outputs_index + j]->type(),
+              TensorType::create(var)));
+        }
+        outputs_index += tuple_len;
+        TORCH_INTERNAL_ASSERT(
+            outputs_index <= graph->outputs().size(),
+            "Incorrect number of elements provided as example outputs.");
+      }
+    } else if (THPVariable_Check(elem)) {
+      at::Tensor var = reinterpret_cast<THPVariable*>(elem)->cdata;
+      ONNXUpdateTypeFromTensor(
+          graph->outputs()[outputs_index], var, onnx_shape_inference);
+      outputs_index++;
+      TORCH_INTERNAL_ASSERT(
+          outputs_index <= graph->outputs().size(),
+          "Incorrect number of elements provided as example outputs.");
+    } else { // Dict
+      // Support for dict data type is limited to fixed size dictionaries in
+      // ONNX.
+      // Dictionary values are unrolled and keys are not preserved.
+      TORCH_INTERNAL_ASSERT(PyDict_Check(elem));
+      auto unrolled_dict = py::reinterpret_borrow<py::list>(PyDict_Items(elem));
+      TORCH_INTERNAL_ASSERT(PyList_Check(unrolled_dict.ptr()));
+      for (size_t j = 0; j < unrolled_dict.size(); ++j) {
+        PyObject* tuple_elem = PyList_GET_ITEM(unrolled_dict.ptr(), j);
+        TORCH_INTERNAL_ASSERT(PyTuple_Check(tuple_elem));
+        TORCH_INTERNAL_ASSERT(PyTuple_GET_SIZE(tuple_elem) == 2);
+        auto& var =
+            reinterpret_cast<THPVariable*>(PyTuple_GET_ITEM(tuple_elem, 1))
+                ->cdata;
+        graph->outputs()[outputs_index + j]->setType(MergeInferredType(
+            graph->outputs()[outputs_index + j]->type(),
+            TensorType::create(var)));
+      }
+      outputs_index += unrolled_dict.size();
+      TORCH_INTERNAL_ASSERT(
+          outputs_index <= graph->outputs().size(),
+          "Incorrect number of elements provided as example outputs.");
+    }
   }
+
+  TORCH_INTERNAL_ASSERT(
+      outputs_index == graph->outputs().size(),
+      "Incorrect number of elements provided as example outputs.");
+
+  Py_DECREF(py_obj);
 }
 
 void ONNXShapeTypeInference(std::shared_ptr<Graph>& graph, int opset_version) {
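The key change: instead of assuming one tensor per graph output, the pass now receives the IODescriptor produced when the example outputs were flattened, unflattens them back into the original nested Python structure, and walks lists, tuples, tensors, and dicts alongside the graph outputs. A rough Python-side view of the descriptor mechanism (torch.jit._flatten is a private API, shown only for illustration):

```python
import torch

outs = (torch.ones(2), [torch.zeros(3), torch.zeros(3)])
flat, desc = torch.jit._flatten(outs)
print(len(flat))  # 3 tensors, in flattened order
print(desc)       # descriptor object consumed by unflatten() in the C++ pass;
                  # its repr depends on the build
```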
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/python/python_arg_flatten.h>
 
 namespace torch {
 namespace jit {
@@ -26,6 +27,7 @@ TORCH_API void ONNXSetDynamicInputShape(
 TORCH_API void ONNXAssignOutputShape(
     std::shared_ptr<Graph>& graph,
     at::ArrayRef<at::Tensor> outputs,
+    const python::IODescriptor& desc,
     bool onnx_shape_inference);
 
 // Utilize ONNX Shape Inference for node.
@@ -150,8 +150,9 @@ void initJITBindings(PyObject* module) {
           "_jit_pass_onnx_assign_output_shape",
           [](std::shared_ptr<Graph>& graph,
              const std::vector<at::Tensor>& tensors,
+             const python::IODescriptor& desc,
              bool onnx_shape_inference = false) {
-            ONNXAssignOutputShape(graph, tensors, onnx_shape_inference);
+            ONNXAssignOutputShape(graph, tensors, desc, onnx_shape_inference);
           })
       .def("_jit_pass_lower_all_tuples", LowerAllTuples)
       .def("_jit_pass_onnx_function_substitution", ONNXFunctionCallSubstitution)
@@ -331,7 +331,8 @@ void EncoderBase::EncodeValueInfo(
         std::unordered_map<int64_t, std::string>>& dynamic_axes) {
   std::string name = n->debugName();
   v->set_name(name);
-  auto tensorTypeToONNXType = [&dynamic_axes, &name, this](
+
+  auto tensorTypeToONNXType = [&dynamic_axes, &name, n, this](
       const TensorTypePtr& t,
       onnx::TypeProto_Tensor* tensor_type) {
     if (t->dim()) {
@@ -349,7 +350,13 @@ void EncoderBase::EncodeValueInfo(
         shape->mutable_dim(i)->set_dim_value(sizes[i].static_size());
       } else {
         if (symbol_dim_map_.find(sizes[i]) == symbol_dim_map_.end()) {
-          symbol_dim_map_[sizes[i]] = name + "_" + std::to_string(i);
+          if (n->node()->kind() == prim::Param) {
+            symbol_dim_map_[sizes[i]] = name + "_dim_" + std::to_string(i);
+          } else {
+            std::string op_type = n->node()->kind().toUnqualString();
+            symbol_dim_map_[sizes[i]] =
+                op_type + name + "_dim_" + std::to_string(i);
+          }
         }
         shape->mutable_dim(i)->set_dim_param(symbol_dim_map_[sizes[i]]);
       }
@@ -437,7 +437,7 @@ def _model_to_graph(model, args, verbose=False,
         args = (args, )
 
     if isinstance(example_outputs, torch.Tensor):
-        example_outputs = [example_outputs]
+        example_outputs = (example_outputs,)
 
     graph, params, torch_out = _create_jit_graph(model, args,
                                                  _retain_param_name,
@@ -456,14 +456,22 @@ def _model_to_graph(model, args, verbose=False,
     if isinstance(model, torch.jit.ScriptModule) or isinstance(model, torch.jit.ScriptFunction):
         assert example_outputs is not None, "example_outputs must be provided when exporting a ScriptModule or " \
                                             "ScriptFunction."
-        out_vars, _ = torch.jit._flatten(tuple(example_outputs))
-        torch._C._jit_pass_onnx_assign_output_shape(graph, out_vars, _onnx_shape_inference)
+        if isinstance(example_outputs, list):
+            example_outputs = [example_outputs]
+
+        out_vars, desc = torch.jit._flatten(tuple(example_outputs))
+        torch._C._jit_pass_onnx_assign_output_shape(graph, out_vars, desc, _onnx_shape_inference)
 
     # NB: ONNX requires complete information about output types, which might be
     # erased by some optimizations, so we need to set it explicitly again.
     if torch_out is not None:
-        output_tensors, _ = torch._C._jit_flatten(torch_out)
-        torch._C._jit_pass_onnx_assign_output_shape(graph, output_tensors, _onnx_shape_inference)
+        if not (isinstance(torch_out, list) or isinstance(torch_out, tuple)):
+            output_wrapped = [torch_out]  # type: ignore
+        else:
+            output_wrapped = torch_out  # type: ignore
+
+        output_tensors, out_desc = torch._C._jit_flatten(tuple(output_wrapped))
+        torch._C._jit_pass_onnx_assign_output_shape(graph, output_tensors, out_desc, _onnx_shape_inference)
 
     _set_input_and_output_names(graph, input_names, output_names)
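Net effect at the export entry point: example_outputs and torch_out are normalized so that flattening always sees a top-level tuple, letting a single tensor, a tuple, or a list of tensors round-trip through the descriptor correctly. A hedged sketch of that normalization in isolation (the helper name is mine, not from the PR):

```python
import torch

def wrap_for_flatten(out):
    # Illustrative only: mirrors how _model_to_graph wraps outputs before
    # torch.jit._flatten so the descriptor describes a top-level tuple.
    if isinstance(out, torch.Tensor):
        return (out,)        # a bare tensor becomes a 1-tuple
    if isinstance(out, list):
        return tuple([out])  # a bare list is one sequence output
    return tuple(out)

flat, desc = torch.jit._flatten(wrap_for_flatten([torch.ones(2), torch.ones(2)]))
print(len(flat), desc)  # 2 tensors; the descriptor records the nested list
```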