Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
ONNX Export Topk with Dynamic k (+ add test cases)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21104

Differential Revision: D16061592

Pulled By: houseroad

fbshipit-source-id: 855b310a138fdde9c25869ffe9f127189dc2eaf5
This commit is contained in:
parent 221af09ca7
commit 42c6ea5faa
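For orientation, a minimal end-to-end sketch of what this change enables, modeled on the new test cases in the diff below; the class name, variable names, and the onnxruntime calls are illustrative assumptions, not part of the commit. Because `k` is passed as a tensor, it becomes a real input of the exported ONNX graph (opset 10) and can be changed at run time.

    import io
    import numpy as np
    import torch
    import onnxruntime  # assumed: a build that supports opset 10

    # Scripted module mirroring the new dynamic-k tests: k is a tensor argument.
    class TopKDynamic(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x, k):
            return torch.topk(x, k)

    x, k = torch.arange(1., 6.), torch.tensor(3)
    module = TopKDynamic()

    f = io.BytesIO()
    torch.onnx.export(module, (x, k), f, opset_version=10,
                      example_outputs=module(x, k))

    # k is an ordinary graph input, so a different value can be fed than at export time.
    sess = onnxruntime.InferenceSession(f.getvalue())
    names = [i.name for i in sess.get_inputs()]
    values, indices = sess.run(None, {names[0]: x.numpy(),
                                      names[1]: np.array(2, dtype=np.int64)})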
@@ -94,6 +94,12 @@ bool TopKOp<T, Context>::RunOnDevice() {
   auto* indices = Output(1);
   auto* flatten_indices = OutputSize() > 2 ? Output(2) : nullptr;

+  int64_t k = k_;
+  if (k == -1 && InputSize() == 2) {
+    k = Input(1).template data<int64_t>()[0];
+  }
+  CAFFE_ENFORCE(k >= 1, "k argument must be >= 1");
+
   at::IntArrayRef input_dims = input.sizes();
   if (axis_ == -1) {
     axis_ = input_dims.size() - 1;
@@ -102,7 +108,7 @@ bool TopKOp<T, Context>::RunOnDevice() {
   CAFFE_ENFORCE_LT(axis_, input_dims.size());

   std::vector<int64_t> output_dims = input_dims.vec();
-  output_dims[axis_] = k_;
+  output_dims[axis_] = k;
   values->Resize(output_dims);
   indices->Resize(output_dims);
   if (flatten_indices != nullptr) {
@@ -134,7 +140,7 @@ bool TopKOp<T, Context>::RunOnDevice() {
       int64_t(1),
       std::multiplies<int64_t>());
   const int64_t src_offset_stride = input_dims[axis_] * next_size;
-  const int64_t dst_offset_stride = k_ * next_size;
+  const int64_t dst_offset_stride = k * next_size;
   int64_t src_offset = 0;
   int64_t dst_offset = 0;
   for (int64_t i = 0; i < prev_size; ++i) {
@@ -142,7 +148,7 @@ bool TopKOp<T, Context>::RunOnDevice() {
       GetTopK(
           input_data,
           input_dims[axis_],
-          k_,
+          k,
           src_offset + j,
           dst_offset + j,
           next_size,
@@ -209,7 +215,7 @@ REGISTER_CPU_OPERATOR(TopK, TopKOp<float, CPUContext>);
 REGISTER_CPU_OPERATOR(TopKGradient, TopKGradientOp<float, CPUContext>);

 OPERATOR_SCHEMA(TopK)
-    .NumInputs(1)
+    .NumInputs(1, 2)
     .NumOutputs(2, 3)
     .TensorInferenceFunction([](const OperatorDef& def,
                                 const vector<TensorShape>& in) {
@@ -235,7 +241,9 @@ OPERATOR_SCHEMA(TopK)
       return out;
     })
     .SetDoc(R"DOC(
-Retrieve the top-K elements of the last dimension. Given an input tensor of shape $(a_1, a_2, ..., a_n, r)$ and integer argument `k`, return up to three outputs:
+Retrieve the top-K elements of the last dimension.
+Given an input tensor of shape $(a_1, a_2, ..., a_n, r)$, `k` can be passed as an integer argument or as a 1D tensor containing a single integer.
+Returns up to three outputs:

 1. Value tensor of shape $(a_1, a_2, ..., a_n, k)$ which contains the values of the top k elements along the last dimension
 2. Index tensor of shape $(a_1, a_2, ..., a_n, k)$ which contains the indices of the top k elements (original indices from the input tensor).
@@ -324,6 +332,10 @@ Flattened_indices: [ 1 0 3 4 8 7 10 11 13 14 17 16 20 18 23 22 26 25]
         0,
         "X",
         "(*Tensor`<float>`*): input tensor of shape $(a_1, a_2, ..., a_n, r)$")
+    .Input(
+        1,
+        "k",
+        "(*int*): number of top elements to retrieve")
     .Output(
         0,
         "Values",
@@ -335,8 +347,7 @@ Flattened_indices: [ 1 0 3 4 8 7 10 11 13 14 17 16 20 18 23 22 26 25]
     .Output(
         2,
         "Flattened_indices",
-        "(*Tensor`<int>`*): tensor of indices of shape $(a_1 * a_2 * ... * a_n * k,)$; indices values refer to each element's index in the flattened input tensor `X`")
-    .Arg("k", "(*int*): number of top elements to retrieve");
+        "(*Tensor`<int>`*): tensor of indices of shape $(a_1 * a_2 * ... * a_n * k,)$; indices values refer to each element's index in the flattened input tensor `X`");

 OPERATOR_SCHEMA(TopKGradient).NumInputs(3).NumOutputs(1);

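A rough sketch of how the schema change above could be exercised from the caffe2 Python workspace API (the blob names and values are assumptions for illustration; the commit itself only changes the operator): with `.NumInputs(1, 2)`, `k` may be supplied as a second input blob holding a single int64 instead of via the old `k` argument.

    import numpy as np
    from caffe2.python import core, workspace

    workspace.FeedBlob("X", np.random.rand(3, 4).astype(np.float32))
    # Dynamic k: a 1D int64 tensor with a single element, read by Input(1) above.
    workspace.FeedBlob("k", np.array([2], dtype=np.int64))

    op = core.CreateOperator("TopK", ["X", "k"], ["Values", "Indices"])
    workspace.RunOperatorOnce(op)
    print(workspace.FetchBlob("Values"), workspace.FetchBlob("Indices"))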
@@ -17,7 +17,6 @@ class TopKOp : public Operator<Context> {
       : Operator<Context>(std::forward<Args>(args)...),
         OP_SINGLE_ARG(int, "k", k_, -1),
         OP_SINGLE_ARG(int, "axis", axis_, -1) {
-    CAFFE_ENFORCE(k_ >= 1, "k argument must be >= 1");
   }

   ~TopKOp() {}
@@ -83,6 +83,23 @@ class TestONNXOpset(TestCase):
         x = torch.arange(1., 6., requires_grad=True)
         check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])

+        # test with dynamic k
+        class MyModuleDynamic(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, input, k):
+                return torch.topk(input, k)
+
+        ops_10 = [{"op_name" : "Unsqueeze", "attributes" : [{"name" : "axes", "ints" : [0], "type" : 7}]},
+                  {"op_name" : "TopK", "attributes" : [{"name" : "axis", "i" : -1, "type" : 2}]}]
+        ops = {10 : ops_10}
+        x = torch.arange(1., 6., requires_grad=True)
+        k = torch.tensor(3)
+        module = MyModuleDynamic()
+        example_output = module(x, k)
+        check_onnx_opsets_operator(module, [x, k], ops,
+                                   opset_versions=[10],
+                                   example_outputs=example_output)
+
     def test_maxpool(self):
         module = torch.nn.MaxPool1d(2, stride=1)
|
|||
|
|
@ -1724,7 +1724,6 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
|
|||
self.run_model_test(NestedTupleModel(), train=False, input=(x, y), batch_size=BATCH_SIZE,
|
||||
example_outputs=x + y[0] + y[1][0] + y[1][1])
|
||||
|
||||
@skipIfUnsupportedOpsetVersion([10])
|
||||
def test_topk(self):
|
||||
class TopKModel(torch.nn.Module):
|
||||
def forward(self, input):
|
||||
|
|
@@ -1733,7 +1732,6 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
         x = torch.arange(1., 6.)
         self.run_model_test(TopKModel(), train=False, input=x, batch_size=BATCH_SIZE)

-    @skipIfUnsupportedOpsetVersion([10])
     def test_topk_script(self):
         class TopKModel(torch.jit.ScriptModule):
             @torch.jit.script_method
@@ -9,7 +9,7 @@ import torch
 import numpy as np
 import io

-from test_pytorch_common import skipIfUnsupportedMinOpsetVersion
+from test_pytorch_common import skipIfUnsupportedMinOpsetVersion, skipIfUnsupportedOpsetVersion


 class TestONNXRuntime(unittest.TestCase):
@@ -17,7 +17,7 @@ class TestONNXRuntime(unittest.TestCase):
     opset_version = _export_onnx_opset_version

     def run_test(self, model, inputs, rtol=1e-05, atol=1e-08):
-        outputs = model(inputs)
+        outputs = model(inputs) if isinstance(inputs, torch.Tensor) else model(*inputs)

         # export the model to ONNX
         f = io.BytesIO()
@@ -132,6 +132,27 @@ class TestONNXRuntime(unittest.TestCase):
         x = torch.randn(1, 2, 3, 4, requires_grad=True)
         self.run_test(MyModel(), x)

+    # TODO: enable for opset 10 when ONNXRuntime version will be updated
+    @skipIfUnsupportedOpsetVersion([10])
+    def test_topk(self):
+        class MyModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.topk(x, 3)
+
+        x = torch.arange(1., 6., requires_grad=True)
+        self.run_test(MyModule(), x)
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_topk_script(self):
+        class MyModuleDynamic(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x, k):
+                return torch.topk(x, k)
+
+        x = torch.arange(1., 6., requires_grad=True)
+        k = torch.tensor(3)
+        self.run_test(MyModuleDynamic(), [x, k])
+
     def test_layer_norm(self):
         model = torch.nn.LayerNorm([10, 10])
         x = torch.randn(20, 5, 10, 10)
@@ -18,12 +18,14 @@ import torch.onnx.symbolic_opset9
 # release on 04/24/19


-@parse_args('v', 'i', 'i', 'i', 'i')
+@parse_args('v', 'v', 'i', 'i', 'i')
 def topk(g, self, k, dim, largest, sorted, out=None):
     if out is not None:
         _unimplemented("TopK", "Out parameter is not supported for topk")
     if not largest:
         _unimplemented("TopK", "Ascending TopK is not supported")
+    k = sym_help._maybe_get_const(k, 'i')
+    if not sym_help._is_value(k):
+        k = g.op("Constant", value_t=torch.tensor(k, dtype=torch.int64))
     from torch.onnx.symbolic_opset9 import unsqueeze
     k = unsqueeze(g, k, 0)
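Note on the symbolic change above: in ONNX opset 10, TopK takes `k` as a second 1-D input tensor rather than as an attribute, which is why `parse_args` now treats `k` as a value ('v') instead of an int ('i'). A statically known `k` is folded into a Constant node, and in either case `k` is unsqueezed to shape [1] before being fed to the TopK node; this matches the `Unsqueeze` followed by `TopK` op sequence asserted in the updated test_onnx_opset test.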