ONNX Export Topk with Dynamic k (+ add test cases)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/21104

Differential Revision: D16061592

Pulled By: houseroad

fbshipit-source-id: 855b310a138fdde9c25869ffe9f127189dc2eaf5
This commit is contained in:
Lara 2019-07-05 23:41:03 -07:00 committed by Facebook Github Bot
parent 221af09ca7
commit 42c6ea5faa
6 changed files with 62 additions and 14 deletions

View File

@ -94,6 +94,12 @@ bool TopKOp<T, Context>::RunOnDevice() {
auto* indices = Output(1);
auto* flatten_indices = OutputSize() > 2 ? Output(2) : nullptr;
int64_t k = k_;
if(k == -1 && InputSize() == 2) {
k = Input(1).template data<int64_t>()[0];
}
CAFFE_ENFORCE(k >= 1, "k argument must be >= 1");
at::IntArrayRef input_dims = input.sizes();
if (axis_ == -1) {
axis_ = input_dims.size() - 1;
@ -102,7 +108,7 @@ bool TopKOp<T, Context>::RunOnDevice() {
CAFFE_ENFORCE_LT(axis_, input_dims.size());
std::vector<int64_t> output_dims = input_dims.vec();
output_dims[axis_] = k_;
output_dims[axis_] = k;
values->Resize(output_dims);
indices->Resize(output_dims);
if (flatten_indices != nullptr) {
@ -134,7 +140,7 @@ bool TopKOp<T, Context>::RunOnDevice() {
int64_t(1),
std::multiplies<int64_t>());
const int64_t src_offset_stride = input_dims[axis_] * next_size;
const int64_t dst_offset_stride = k_ * next_size;
const int64_t dst_offset_stride = k * next_size;
int64_t src_offset = 0;
int64_t dst_offset = 0;
for (int64_t i = 0; i < prev_size; ++i) {
@ -142,7 +148,7 @@ bool TopKOp<T, Context>::RunOnDevice() {
GetTopK(
input_data,
input_dims[axis_],
k_,
k,
src_offset + j,
dst_offset + j,
next_size,
@ -209,7 +215,7 @@ REGISTER_CPU_OPERATOR(TopK, TopKOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(TopKGradient, TopKGradientOp<float, CPUContext>);
OPERATOR_SCHEMA(TopK)
.NumInputs(1)
.NumInputs(1, 2)
.NumOutputs(2, 3)
.TensorInferenceFunction([](const OperatorDef& def,
const vector<TensorShape>& in) {
@ -235,7 +241,9 @@ OPERATOR_SCHEMA(TopK)
return out;
})
.SetDoc(R"DOC(
Retrieve the top-K elements of the last dimension. Given an input tensor of shape $(a_1, a_2, ..., a_n, r)$ and integer argument `k`, return up to three outputs:
Retrieve the top-K elements of the last dimension.
Given an input tensor of shape $(a_1, a_2, ..., a_n, r)$. `k` can be passed as an integer argument or a 1D tensor containing a single integer.
Returns up to three outputs:
1. Value tensor of shape $(a_1, a_2, ..., a_n, k)$ which contains the values of the top k elements along the last dimension
2. Index tensor of shape $(a_1, a_2, ..., a_n, k)$ which contains the indices of the top k elements (original indices from the input tensor).
@ -324,6 +332,10 @@ Flattened_indices: [ 1 0 3 4 8 7 10 11 13 14 17 16 20 18 23 22 26 25]
0,
"X",
"(*Tensor`<float>`*): input tensor of shape $(a_1, a_2, ..., a_n, r)$")
.Input(
1,
"k",
"(*int*): number of top elements to retrieve")
.Output(
0,
"Values",
@ -335,8 +347,7 @@ Flattened_indices: [ 1 0 3 4 8 7 10 11 13 14 17 16 20 18 23 22 26 25]
.Output(
2,
"Flattened_indices",
"(*Tensor`<int>`*): tensor of indices of shape $(a_1 * a_2 * ... * a_n * k,)$; indices values refer to each element's index in the flattened input tensor `X`")
.Arg("k", "(*int*): number of top elements to retrieve");
"(*Tensor`<int>`*): tensor of indices of shape $(a_1 * a_2 * ... * a_n * k,)$; indices values refer to each element's index in the flattened input tensor `X`");
OPERATOR_SCHEMA(TopKGradient).NumInputs(3).NumOutputs(1);

View File

@ -17,7 +17,6 @@ class TopKOp : public Operator<Context> {
: Operator<Context>(std::forward<Args>(args)...),
OP_SINGLE_ARG(int, "k", k_, -1),
OP_SINGLE_ARG(int, "axis", axis_, -1) {
CAFFE_ENFORCE(k_ >= 1, "k argument must be >= 1");
}
~TopKOp() {}

View File

@ -83,6 +83,23 @@ class TestONNXOpset(TestCase):
x = torch.arange(1., 6., requires_grad=True)
check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])
# test with dynamic k
class MyModuleDynamic(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, input, k):
return torch.topk(input, k)
ops_10 = [{"op_name" : "Unsqueeze", "attributes" : [{"name" : "axes", "ints" : [0], "type" : 7}]},
{"op_name" : "TopK", "attributes" : [{"name" : "axis", "i" : -1, "type" : 2}]}]
ops = {10 : ops_10}
x = torch.arange(1., 6., requires_grad=True)
k = torch.tensor(3)
module = MyModuleDynamic()
example_output = module(x, k)
check_onnx_opsets_operator(module, [x, k], ops,
opset_versions=[10],
example_outputs=example_output)
def test_maxpool(self):
module = torch.nn.MaxPool1d(2, stride=1)

View File

@ -1724,7 +1724,6 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
self.run_model_test(NestedTupleModel(), train=False, input=(x, y), batch_size=BATCH_SIZE,
example_outputs=x + y[0] + y[1][0] + y[1][1])
@skipIfUnsupportedOpsetVersion([10])
def test_topk(self):
class TopKModel(torch.nn.Module):
def forward(self, input):
@ -1733,7 +1732,6 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
x = torch.arange(1., 6.)
self.run_model_test(TopKModel(), train=False, input=x, batch_size=BATCH_SIZE)
@skipIfUnsupportedOpsetVersion([10])
def test_topk_script(self):
class TopKModel(torch.jit.ScriptModule):
@torch.jit.script_method

View File

@ -9,7 +9,7 @@ import torch
import numpy as np
import io
from test_pytorch_common import skipIfUnsupportedMinOpsetVersion
from test_pytorch_common import skipIfUnsupportedMinOpsetVersion, skipIfUnsupportedOpsetVersion
class TestONNXRuntime(unittest.TestCase):
@ -17,7 +17,7 @@ class TestONNXRuntime(unittest.TestCase):
opset_version = _export_onnx_opset_version
def run_test(self, model, inputs, rtol=1e-05, atol=1e-08):
outputs = model(inputs)
outputs = model(inputs) if isinstance(inputs, torch.Tensor) else model(*inputs)
# export the model to ONNX
f = io.BytesIO()
@ -132,6 +132,27 @@ class TestONNXRuntime(unittest.TestCase):
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.run_test(MyModel(), x)
# TODO: enable for opset 10 when ONNXRuntime version will be updated
@skipIfUnsupportedOpsetVersion([10])
def test_topk(self):
class MyModule(torch.nn.Module):
def forward(self, x):
return torch.topk(x, 3)
x = torch.arange(1., 6., requires_grad=True)
self.run_test(MyModule(), x)
@skipIfUnsupportedMinOpsetVersion(10)
def test_topk_script(self):
class MyModuleDynamic(torch.jit.ScriptModule):
@torch.jit.script_method
def forward(self, x, k):
return torch.topk(x, k)
x = torch.arange(1., 6., requires_grad=True)
k = torch.tensor(3)
self.run_test(MyModuleDynamic(), [x, k])
def test_layer_norm(self):
model = torch.nn.LayerNorm([10, 10])
x = torch.randn(20, 5, 10, 10)

View File

@ -18,12 +18,14 @@ import torch.onnx.symbolic_opset9
# release on 04/24/19
@parse_args('v', 'i', 'i', 'i', 'i')
@parse_args('v', 'v', 'i', 'i', 'i')
def topk(g, self, k, dim, largest, sorted, out=None):
if out is not None:
_unimplemented("TopK", "Out parameter is not supported for topk")
if not largest:
_unimplemented("TopK", "Ascending TopK is not supported")
k = sym_help._maybe_get_const(k, 'i')
if not sym_help._is_value(k):
k = g.op("Constant", value_t=torch.tensor(k, dtype=torch.int64))
from torch.onnx.symbolic_opset9 import unsqueeze
k = unsqueeze(g, k, 0)