ONNX Export Topk with Dynamic k (+ add test cases)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/21104

Differential Revision: D16061592
Pulled By: houseroad
fbshipit-source-id: 855b310a138fdde9c25869ffe9f127189dc2eaf5
Parent: 221af09ca7
Commit: 42c6ea5faa
caffe2/operators/top_k.cc

```diff
@@ -94,6 +94,12 @@ bool TopKOp<T, Context>::RunOnDevice() {
   auto* indices = Output(1);
   auto* flatten_indices = OutputSize() > 2 ? Output(2) : nullptr;
 
+  int64_t k = k_;
+  if (k == -1 && InputSize() == 2) {
+    k = Input(1).template data<int64_t>()[0];
+  }
+  CAFFE_ENFORCE(k >= 1, "k argument must be >= 1");
+
   at::IntArrayRef input_dims = input.sizes();
   if (axis_ == -1) {
     axis_ = input_dims.size() - 1;
```
```diff
@@ -102,7 +108,7 @@ bool TopKOp<T, Context>::RunOnDevice() {
   CAFFE_ENFORCE_LT(axis_, input_dims.size());
 
   std::vector<int64_t> output_dims = input_dims.vec();
-  output_dims[axis_] = k_;
+  output_dims[axis_] = k;
   values->Resize(output_dims);
   indices->Resize(output_dims);
   if (flatten_indices != nullptr) {
@@ -134,7 +140,7 @@ bool TopKOp<T, Context>::RunOnDevice() {
       int64_t(1),
       std::multiplies<int64_t>());
   const int64_t src_offset_stride = input_dims[axis_] * next_size;
-  const int64_t dst_offset_stride = k_ * next_size;
+  const int64_t dst_offset_stride = k * next_size;
   int64_t src_offset = 0;
   int64_t dst_offset = 0;
   for (int64_t i = 0; i < prev_size; ++i) {
@@ -142,7 +148,7 @@ bool TopKOp<T, Context>::RunOnDevice() {
       GetTopK(
           input_data,
           input_dims[axis_],
-          k_,
+          k,
           src_offset + j,
           dst_offset + j,
           next_size,
```
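These RunOnDevice hunks make `k` a run-time value: the `k` argument's default of -1 now means "read k from the second input when one is given", and the `k >= 1` check moves from the constructor (see the top_k.h hunk below) to run time. A minimal sketch of driving the new two-input form from the Caffe2 Python API, assuming a build that contains this patch (blob names are illustrative):

```python
# Sketch only: supplies k as a runtime input blob instead of the "k" argument.
import numpy as np
from caffe2.python import core, workspace

workspace.FeedBlob("X", np.random.rand(3, 5).astype(np.float32))
workspace.FeedBlob("k", np.array([2], dtype=np.int64))  # dynamic k = 2

# No k argument: the operator falls back to reading Input(1).
op = core.CreateOperator("TopK", ["X", "k"], ["Values", "Indices"])
workspace.RunOperatorOnce(op)
print(workspace.FetchBlob("Values").shape)  # expected: (3, 2)
```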
```diff
@@ -209,7 +215,7 @@ REGISTER_CPU_OPERATOR(TopK, TopKOp<float, CPUContext>);
 REGISTER_CPU_OPERATOR(TopKGradient, TopKGradientOp<float, CPUContext>);
 
 OPERATOR_SCHEMA(TopK)
-    .NumInputs(1)
+    .NumInputs(1, 2)
     .NumOutputs(2, 3)
     .TensorInferenceFunction([](const OperatorDef& def,
                                 const vector<TensorShape>& in) {
```
```diff
@@ -235,7 +241,9 @@ OPERATOR_SCHEMA(TopK)
       return out;
     })
     .SetDoc(R"DOC(
-Retrieve the top-K elements of the last dimension. Given an input tensor of shape $(a_1, a_2, ..., a_n, r)$ and integer argument `k`, return up to three outputs:
+Retrieve the top-K elements of the last dimension.
+Given an input tensor of shape $(a_1, a_2, ..., a_n, r)$, `k` can be passed as an integer argument or as a 1D tensor containing a single integer.
+Returns up to three outputs:
 
 1. Value tensor of shape $(a_1, a_2, ..., a_n, k)$ which contains the values of the top k elements along the last dimension
 2. Index tensor of shape $(a_1, a_2, ..., a_n, k)$ which contains the indices of the top k elements (original indices from the input tensor).
```
```diff
@@ -324,6 +332,10 @@ Flattened_indices: [ 1 0 3 4 8 7 10 11 13 14 17 16 20 18 23 22 26 25]
         0,
         "X",
         "(*Tensor`<float>`*): input tensor of shape $(a_1, a_2, ..., a_n, r)$")
+    .Input(
+        1,
+        "k",
+        "(*int*): number of top elements to retrieve")
     .Output(
         0,
         "Values",
```
```diff
@@ -335,8 +347,7 @@ Flattened_indices: [ 1 0 3 4 8 7 10 11 13 14 17 16 20 18 23 22 26 25]
     .Output(
         2,
         "Flattened_indices",
-        "(*Tensor`<int>`*): tensor of indices of shape $(a_1 * a_2 * ... * a_n * k,)$; indices values refer to each element's index in the flattened input tensor `X`")
-    .Arg("k", "(*int*): number of top elements to retrieve");
+        "(*Tensor`<int>`*): tensor of indices of shape $(a_1 * a_2 * ... * a_n * k,)$; indices values refer to each element's index in the flattened input tensor `X`");
 
 OPERATOR_SCHEMA(TopKGradient).NumInputs(3).NumOutputs(1);
```
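Because the schema now declares `NumInputs(1, 2)` and documents `k` as an optional second input, both calling conventions stay valid. A hedged sketch contrasting them, again assuming a Caffe2 build with this patch:

```python
import numpy as np
from caffe2.python import core, workspace

workspace.FeedBlob("X", np.random.rand(4, 6).astype(np.float32))

# Legacy form: k fixed as an operator argument at graph-construction time.
static_op = core.CreateOperator("TopK", ["X"], ["V1", "I1"], k=3)

# New form: k supplied as an input, so it can differ between runs.
workspace.FeedBlob("k", np.array([3], dtype=np.int64))
dynamic_op = core.CreateOperator("TopK", ["X", "k"], ["V2", "I2"])

for op in (static_op, dynamic_op):
    workspace.RunOperatorOnce(op)
print(workspace.FetchBlob("V1").shape, workspace.FetchBlob("V2").shape)  # (4, 3) (4, 3)
```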
caffe2/operators/top_k.h

```diff
@@ -17,7 +17,6 @@ class TopKOp : public Operator<Context> {
       : Operator<Context>(std::forward<Args>(args)...),
         OP_SINGLE_ARG(int, "k", k_, -1),
         OP_SINGLE_ARG(int, "axis", axis_, -1) {
-    CAFFE_ENFORCE(k_ >= 1, "k argument must be >= 1");
   }
 
   ~TopKOp() {}
```
test/onnx/test_onnx_opset.py

```diff
@@ -83,6 +83,23 @@ class TestONNXOpset(TestCase):
         x = torch.arange(1., 6., requires_grad=True)
         check_onnx_opsets_operator(MyModule(), x, ops, opset_versions=[9, 10])
 
+        # test with dynamic k
+        class MyModuleDynamic(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, input, k):
+                return torch.topk(input, k)
+
+        ops_10 = [{"op_name" : "Unsqueeze", "attributes" : [{"name" : "axes", "ints" : [0], "type" : 7}]},
+                  {"op_name" : "TopK", "attributes" : [{"name" : "axis", "i" : -1, "type" : 2}]}]
+        ops = {10 : ops_10}
+        x = torch.arange(1., 6., requires_grad=True)
+        k = torch.tensor(3)
+        module = MyModuleDynamic()
+        example_output = module(x, k)
+        check_onnx_opsets_operator(module, [x, k], ops,
+                                   opset_versions=[10],
+                                   example_outputs=example_output)
+
     def test_maxpool(self):
         module = torch.nn.MaxPool1d(2, stride=1)
```
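The expected opset-10 pattern asserted above is an `Unsqueeze` (axes=[0]) that lifts the scalar `k` to the 1-D tensor ONNX `TopK` takes as its second input. A hedged sketch of checking that structure directly with the `onnx` package, assuming torch/onnx versions contemporary with this patch:

```python
# Sketch: export the dynamic-k module at opset 10 and list the resulting nodes.
import io
import onnx
import torch

class MyModuleDynamic(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, input, k):
        return torch.topk(input, k)

module = MyModuleDynamic()
x = torch.arange(1., 6.)
k = torch.tensor(3)

f = io.BytesIO()
torch.onnx.export(module, (x, k), f, opset_version=10,
                  example_outputs=module(x, k))
model = onnx.load_from_string(f.getvalue())
print([node.op_type for node in model.graph.node])  # e.g. ['Unsqueeze', 'TopK']
```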
test/onnx/test_pytorch_onnx_caffe2.py

```diff
@@ -1724,7 +1724,6 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
         self.run_model_test(NestedTupleModel(), train=False, input=(x, y), batch_size=BATCH_SIZE,
                             example_outputs=x + y[0] + y[1][0] + y[1][1])
 
-    @skipIfUnsupportedOpsetVersion([10])
     def test_topk(self):
         class TopKModel(torch.nn.Module):
             def forward(self, input):
@@ -1733,7 +1732,6 @@ class TestCaffe2Backend_opset9(unittest.TestCase):
         x = torch.arange(1., 6.)
         self.run_model_test(TopKModel(), train=False, input=x, batch_size=BATCH_SIZE)
 
-    @skipIfUnsupportedOpsetVersion([10])
     def test_topk_script(self):
         class TopKModel(torch.jit.ScriptModule):
             @torch.jit.script_method
```
test/onnx/test_pytorch_onnx_onnxruntime.py

```diff
@@ -9,7 +9,7 @@ import torch
 import numpy as np
 import io
 
-from test_pytorch_common import skipIfUnsupportedMinOpsetVersion
+from test_pytorch_common import skipIfUnsupportedMinOpsetVersion, skipIfUnsupportedOpsetVersion
 
 
 class TestONNXRuntime(unittest.TestCase):
```
```diff
@@ -17,7 +17,7 @@ class TestONNXRuntime(unittest.TestCase):
     opset_version = _export_onnx_opset_version
 
     def run_test(self, model, inputs, rtol=1e-05, atol=1e-08):
-        outputs = model(inputs)
+        outputs = model(inputs) if isinstance(inputs, torch.Tensor) else model(*inputs)
 
         # export the model to ONNX
         f = io.BytesIO()
```
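This one-line change lets `run_test` serve both single-tensor and multi-input models: a bare tensor is passed through unchanged, anything else (such as a list) is unpacked into positional arguments. A self-contained sketch of the dispatch (module names are illustrative):

```python
import torch

class Single(torch.nn.Module):
    def forward(self, x):
        return x * 2

class Pair(torch.nn.Module):
    def forward(self, x, k):
        return torch.topk(x, int(k.item()))

def run(model, inputs):
    # Mirrors the updated run_test: unpack unless given a bare tensor.
    return model(inputs) if isinstance(inputs, torch.Tensor) else model(*inputs)

x = torch.arange(1., 6.)
print(run(Single(), x))                    # single input
print(run(Pair(), [x, torch.tensor(3)]))  # list of inputs, unpacked
```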
```diff
@@ -132,6 +132,27 @@ class TestONNXRuntime(unittest.TestCase):
         x = torch.randn(1, 2, 3, 4, requires_grad=True)
         self.run_test(MyModel(), x)
 
+    # TODO: enable for opset 10 when ONNXRuntime version will be updated
+    @skipIfUnsupportedOpsetVersion([10])
+    def test_topk(self):
+        class MyModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.topk(x, 3)
+
+        x = torch.arange(1., 6., requires_grad=True)
+        self.run_test(MyModule(), x)
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_topk_script(self):
+        class MyModuleDynamic(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, x, k):
+                return torch.topk(x, k)
+
+        x = torch.arange(1., 6., requires_grad=True)
+        k = torch.tensor(3)
+        self.run_test(MyModuleDynamic(), [x, k])
+
     def test_layer_norm(self):
         model = torch.nn.LayerNorm([10, 10])
         x = torch.randn(20, 5, 10, 10)
```
torch/onnx/symbolic_opset10.py

```diff
@@ -18,13 +18,15 @@ import torch.onnx.symbolic_opset9
 # release on 04/24/19
 
 
-@parse_args('v', 'i', 'i', 'i', 'i')
+@parse_args('v', 'v', 'i', 'i', 'i')
 def topk(g, self, k, dim, largest, sorted, out=None):
     if out is not None:
         _unimplemented("TopK", "Out parameter is not supported for topk")
     if not largest:
         _unimplemented("TopK", "Ascending TopK is not supported")
-    k = g.op("Constant", value_t=torch.tensor(k, dtype=torch.int64))
+    k = sym_help._maybe_get_const(k, 'i')
+    if not sym_help._is_value(k):
+        k = g.op("Constant", value_t=torch.tensor(k, dtype=torch.int64))
     from torch.onnx.symbolic_opset9 import unsqueeze
     k = unsqueeze(g, k, 0)
     return g.op("TopK", self, k, axis_i=dim, outputs=2)
```
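Taken together: at opset 10 `k` may arrive as a graph value (dynamic) or fold to a Python int via `_maybe_get_const`; a `Constant` node is emitted only in the folded case, and either way `k` is unsqueezed into the 1-D tensor ONNX `TopK` expects. A hedged end-to-end sketch of exporting and running a dynamic-k model (requires an onnxruntime build that supports opset-10 TopK; names are illustrative):

```python
import io
import numpy as np
import onnxruntime
import torch

class DynamicTopK(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, x, k):
        return torch.topk(x, k)

module = DynamicTopK()
x = torch.arange(1., 6.)
k = torch.tensor(3)

f = io.BytesIO()
torch.onnx.export(module, (x, k), f, opset_version=10,
                  example_outputs=module(x, k))

# k is a graph input, so a different value can be fed at inference time.
sess = onnxruntime.InferenceSession(f.getvalue())
names = [i.name for i in sess.get_inputs()]
values, indices = sess.run(None, {names[0]: x.numpy(),
                                  names[1]: np.array(2, dtype=np.int64)})
print(values)  # top-2 values of x
```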