BoxWithNMSLimit: support int `batch_splits` input (#47504)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47504

Allow the `batch_splits` input to be of int type (in addition to float).

Test Plan:
```
buck test caffe2/caffe2/python/operator_test:torch_integration_test -- test_box_with_nms_limits
```

Reviewed By: jackm321

Differential Revision: D24629522

fbshipit-source-id: 61cb132e792bddd8f9f1bca5b808f1a9131808f0
This commit is contained in:
Shiyan Deng 2020-11-06 23:32:17 -08:00 committed by Facebook GitHub Bot
parent 9d0c6e9469
commit c19eb4ad73
3 changed files with 22 additions and 9 deletions

View File

@ -5,8 +5,9 @@
namespace caffe2 {
template <>
bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {
const auto& tscores = Input(0);
template <typename T>
bool BoxWithNMSLimitOp<CPUContext>::DoRunWithType() {
const auto& tscores = Input(0);
const auto& tboxes = Input(1);
const int box_dim = rotated_ ? 5 : 4;
@ -35,18 +36,19 @@ bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {
int num_boxes_classes = get_box_cls_index(num_classes - 1) + 1;
CAFFE_ENFORCE_EQ(num_boxes_classes * box_dim, tboxes.size(1));
// Default value for batch_size and batch_splits
int batch_size = 1;
vector<float> batch_splits_default(1, tscores.size(0));
const float* batch_splits_data = batch_splits_default.data();
vector<T> batch_splits_default(1, tscores.size(0));
const T* batch_splits_data = batch_splits_default.data();
if (InputSize() > 2) {
// tscores and tboxes have items from multiple images in a batch. Get the
// corresponding batch splits from input.
const auto& tbatch_splits = Input(2);
CAFFE_ENFORCE_EQ(tbatch_splits.dim(), 1);
batch_size = tbatch_splits.size(0);
batch_splits_data = tbatch_splits.data<float>();
batch_splits_data = tbatch_splits.data<T>();
}
Eigen::Map<const EArrXf> batch_splits(batch_splits_data, batch_size);
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>> batch_splits(batch_splits_data, batch_size);
CAFFE_ENFORCE_EQ(batch_splits.sum(), N);
auto* out_scores = Output(0, {0}, at::dtype<float>());
@ -65,7 +67,7 @@ bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {
vector<int> total_keep_per_batch(batch_size);
int offset = 0;
for (int b = 0; b < batch_splits.size(); ++b) {
int num_boxes = batch_splits(b);
int num_boxes = batch_splits[b];
Eigen::Map<const ERArrXXf> scores(
tscores.data<float>() + offset * tscores.size(1),
num_boxes,

View File

@ -60,7 +60,16 @@ class BoxWithNMSLimitOp final : public Operator<Context> {
~BoxWithNMSLimitOp() {}
bool RunOnDevice() override;
// Runs the operator, dispatching on the element type of the optional third
// input (batch_splits): int or float. When no batch_splits input is
// provided, falls back to the float instantiation.
bool RunOnDevice() override {
if (InputSize() > 2) {
return DispatchHelper<TensorTypes<int, float>>::call(this, Input(2));
} else {
return DoRunWithType<float>();
}
}
// Type-parameterized implementation; T is the batch_splits element type.
// Defined in the .cc file for CPUContext.
template <typename T>
bool DoRunWithType();
protected:
// TEST.SCORE_THRESH

View File

@ -180,6 +180,7 @@ class TorchIntegration(hu.HypothesisTestCase):
rotated=st.booleans(),
angle_bound_on=st.booleans(),
clip_angle_thresh=st.sampled_from([-1.0, 1.0]),
batch_splits_dtype=st.sampled_from([torch.float32, torch.int32]),
**hu.gcs_cpu_only
)
def test_box_with_nms_limits(
@ -189,6 +190,7 @@ class TorchIntegration(hu.HypothesisTestCase):
rotated,
angle_bound_on,
clip_angle_thresh,
batch_splits_dtype,
gc,
dc,
):
@ -250,7 +252,7 @@ class TorchIntegration(hu.HypothesisTestCase):
outputs = torch.ops._caffe2.BoxWithNMSLimit(
torch.tensor(class_prob),
torch.tensor(pred_bbox),
torch.tensor(batch_splits),
torch.tensor(batch_splits, dtype=batch_splits_dtype),
score_thresh=float(score_thresh),
nms=float(nms_thresh),
detections_per_im=int(topk_per_image),