diff --git a/caffe2/operators/box_with_nms_limit_op.cc b/caffe2/operators/box_with_nms_limit_op.cc index c1890de4c60..8d438c7901d 100644 --- a/caffe2/operators/box_with_nms_limit_op.cc +++ b/caffe2/operators/box_with_nms_limit_op.cc @@ -131,7 +131,7 @@ bool BoxWithNMSLimitOp::RunOnDevice() { // Limit to max_per_image detections *over all classes* if (detections_per_im_ > 0 && total_keep_count > detections_per_im_) { // merge all scores (represented by indices) together and sort - auto get_all_scores_sorted = [&scores, &keeps, total_keep_count]() { + auto get_all_scores_sorted = [&]() { // flatten keeps[i][j] to [pair(i, keeps[i][j]), ...] // first: class index (1 ~ keeps.size() - 1), // second: values in keeps[first] @@ -139,19 +139,19 @@ bool BoxWithNMSLimitOp::RunOnDevice() { vector ret(total_keep_count); int ret_idx = 0; - for (int i = 1; i < keeps.size(); i++) { - auto& cur_keep = keeps[i]; + for (int j = 1; j < num_classes; j++) { + auto& cur_keep = keeps[j]; for (auto& ckv : cur_keep) { - ret[ret_idx++] = {i, ckv}; + ret[ret_idx++] = {j, ckv}; } } std::sort( ret.data(), ret.data() + ret.size(), - [&scores](const KeepIndex& lhs, const KeepIndex& rhs) { - return scores(lhs.second, lhs.first) > - scores(rhs.second, rhs.first); + [this, &scores](const KeepIndex& lhs, const KeepIndex& rhs) { + return scores(lhs.second, this->get_score_cls_index(lhs.first)) > + scores(rhs.second, this->get_score_cls_index(rhs.first)); }); return ret; diff --git a/caffe2/python/operator_test/box_with_nms_limit_op_test.py b/caffe2/python/operator_test/box_with_nms_limit_op_test.py index cd869908953..5c6fd368d86 100644 --- a/caffe2/python/operator_test/box_with_nms_limit_op_test.py +++ b/caffe2/python/operator_test/box_with_nms_limit_op_test.py @@ -122,6 +122,7 @@ class TestBoxWithNMSLimitOp(serial.SerializedTestCase): @given( num_classes=st.integers(2, 10), + det_per_im=st.integers(1, 4), cls_agnostic_bbox_reg=st.booleans(), input_boxes_include_bg_cls=st.booleans(), output_classes_include_bg_cls=st.booleans(), @@ -130,6 +131,7 @@ class TestBoxWithNMSLimitOp(serial.SerializedTestCase): def test_multiclass( self, num_classes, + det_per_im, cls_agnostic_bbox_reg, input_boxes_include_bg_cls, output_classes_include_bg_cls, @@ -145,9 +147,12 @@ class TestBoxWithNMSLimitOp(serial.SerializedTestCase): if cls_agnostic_bbox_reg: # only leave one class boxes = boxes[:, :4] + # randomize un-used scores for background class + scores_bg_class_id = 0 if input_boxes_include_bg_cls else -1 + scores[:, scores_bg_class_id] = np.random.rand(scores.shape[0]).astype(np.float32) - gt_centers = [(20, 20), (0, 0), (50, 50)] - gt_scores = [0.85, 0.7, 0.6] + gt_centers = [(20, 20), (0, 0), (50, 50)][:det_per_im] + gt_scores = [0.85, 0.7, 0.6][:det_per_im] gt_boxes, gt_scores = gen_multiple_boxes(gt_centers, gt_scores, 1, 1) # [1, 1, 1, 2, 2, 2, 3, 3, 3, ...] gt_classes = np.tile( @@ -164,7 +169,7 @@ class TestBoxWithNMSLimitOp(serial.SerializedTestCase): { "score_thresh": 0.5, "nms": 0.9, - "detections_per_im": 100, + "detections_per_im": (num_classes - 1) * det_per_im, "cls_agnostic_bbox_reg": cls_agnostic_bbox_reg, "input_boxes_include_bg_cls": input_boxes_include_bg_cls, "output_classes_include_bg_cls": output_classes_include_bg_cls