fix bug of not using get_score_cls_index in BoxWithNMSLimitOp (#20868)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20868

When `input_boxes_include_bg_cls` is false (which means `input_scores_fg_cls_starting_id` is 0), It doesn't map the class index of score currectly when sorting and limiting the detections over all classes after nms.

Reviewed By: newstzpz

Differential Revision: D15472706

fbshipit-source-id: dc1e808b63ad09fb4bd95acf866771bb3fa92d69
This commit is contained in:
Yanghan Wang 2019-05-24 22:24:55 -07:00 committed by Facebook Github Bot
parent 2fb665a9df
commit 81e70ffa19
2 changed files with 15 additions and 10 deletions

View File

@ -131,7 +131,7 @@ bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {
// Limit to max_per_image detections *over all classes*
if (detections_per_im_ > 0 && total_keep_count > detections_per_im_) {
// merge all scores (represented by indices) together and sort
auto get_all_scores_sorted = [&scores, &keeps, total_keep_count]() {
auto get_all_scores_sorted = [&]() {
// flatten keeps[i][j] to [pair(i, keeps[i][j]), ...]
// first: class index (1 ~ keeps.size() - 1),
// second: values in keeps[first]
@ -139,19 +139,19 @@ bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {
vector<KeepIndex> ret(total_keep_count);
int ret_idx = 0;
for (int i = 1; i < keeps.size(); i++) {
auto& cur_keep = keeps[i];
for (int j = 1; j < num_classes; j++) {
auto& cur_keep = keeps[j];
for (auto& ckv : cur_keep) {
ret[ret_idx++] = {i, ckv};
ret[ret_idx++] = {j, ckv};
}
}
std::sort(
ret.data(),
ret.data() + ret.size(),
[&scores](const KeepIndex& lhs, const KeepIndex& rhs) {
return scores(lhs.second, lhs.first) >
scores(rhs.second, rhs.first);
[this, &scores](const KeepIndex& lhs, const KeepIndex& rhs) {
return scores(lhs.second, this->get_score_cls_index(lhs.first)) >
scores(rhs.second, this->get_score_cls_index(rhs.first));
});
return ret;

View File

@ -122,6 +122,7 @@ class TestBoxWithNMSLimitOp(serial.SerializedTestCase):
@given(
num_classes=st.integers(2, 10),
det_per_im=st.integers(1, 4),
cls_agnostic_bbox_reg=st.booleans(),
input_boxes_include_bg_cls=st.booleans(),
output_classes_include_bg_cls=st.booleans(),
@ -130,6 +131,7 @@ class TestBoxWithNMSLimitOp(serial.SerializedTestCase):
def test_multiclass(
self,
num_classes,
det_per_im,
cls_agnostic_bbox_reg,
input_boxes_include_bg_cls,
output_classes_include_bg_cls,
@ -145,9 +147,12 @@ class TestBoxWithNMSLimitOp(serial.SerializedTestCase):
if cls_agnostic_bbox_reg:
# only leave one class
boxes = boxes[:, :4]
# randomize un-used scores for background class
scores_bg_class_id = 0 if input_boxes_include_bg_cls else -1
scores[:, scores_bg_class_id] = np.random.rand(scores.shape[0]).astype(np.float32)
gt_centers = [(20, 20), (0, 0), (50, 50)]
gt_scores = [0.85, 0.7, 0.6]
gt_centers = [(20, 20), (0, 0), (50, 50)][:det_per_im]
gt_scores = [0.85, 0.7, 0.6][:det_per_im]
gt_boxes, gt_scores = gen_multiple_boxes(gt_centers, gt_scores, 1, 1)
# [1, 1, 1, 2, 2, 2, 3, 3, 3, ...]
gt_classes = np.tile(
@ -164,7 +169,7 @@ class TestBoxWithNMSLimitOp(serial.SerializedTestCase):
{
"score_thresh": 0.5,
"nms": 0.9,
"detections_per_im": 100,
"detections_per_im": (num_classes - 1) * det_per_im,
"cls_agnostic_bbox_reg": cls_agnostic_bbox_reg,
"input_boxes_include_bg_cls": input_boxes_include_bg_cls,
"output_classes_include_bg_cls": output_classes_include_bg_cls