mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
fix bug of not using get_score_cls_index in BoxWithNMSLimitOp (#20868)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/20868 When `input_boxes_include_bg_cls` is false (which means `input_scores_fg_cls_starting_id` is 0), It doesn't map the class index of score currectly when sorting and limiting the detections over all classes after nms. Reviewed By: newstzpz Differential Revision: D15472706 fbshipit-source-id: dc1e808b63ad09fb4bd95acf866771bb3fa92d69
This commit is contained in:
parent
2fb665a9df
commit
81e70ffa19
|
|
@ -131,7 +131,7 @@ bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {
|
|||
// Limit to max_per_image detections *over all classes*
|
||||
if (detections_per_im_ > 0 && total_keep_count > detections_per_im_) {
|
||||
// merge all scores (represented by indices) together and sort
|
||||
auto get_all_scores_sorted = [&scores, &keeps, total_keep_count]() {
|
||||
auto get_all_scores_sorted = [&]() {
|
||||
// flatten keeps[i][j] to [pair(i, keeps[i][j]), ...]
|
||||
// first: class index (1 ~ keeps.size() - 1),
|
||||
// second: values in keeps[first]
|
||||
|
|
@ -139,19 +139,19 @@ bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {
|
|||
vector<KeepIndex> ret(total_keep_count);
|
||||
|
||||
int ret_idx = 0;
|
||||
for (int i = 1; i < keeps.size(); i++) {
|
||||
auto& cur_keep = keeps[i];
|
||||
for (int j = 1; j < num_classes; j++) {
|
||||
auto& cur_keep = keeps[j];
|
||||
for (auto& ckv : cur_keep) {
|
||||
ret[ret_idx++] = {i, ckv};
|
||||
ret[ret_idx++] = {j, ckv};
|
||||
}
|
||||
}
|
||||
|
||||
std::sort(
|
||||
ret.data(),
|
||||
ret.data() + ret.size(),
|
||||
[&scores](const KeepIndex& lhs, const KeepIndex& rhs) {
|
||||
return scores(lhs.second, lhs.first) >
|
||||
scores(rhs.second, rhs.first);
|
||||
[this, &scores](const KeepIndex& lhs, const KeepIndex& rhs) {
|
||||
return scores(lhs.second, this->get_score_cls_index(lhs.first)) >
|
||||
scores(rhs.second, this->get_score_cls_index(rhs.first));
|
||||
});
|
||||
|
||||
return ret;
|
||||
|
|
|
|||
|
|
@ -122,6 +122,7 @@ class TestBoxWithNMSLimitOp(serial.SerializedTestCase):
|
|||
|
||||
@given(
|
||||
num_classes=st.integers(2, 10),
|
||||
det_per_im=st.integers(1, 4),
|
||||
cls_agnostic_bbox_reg=st.booleans(),
|
||||
input_boxes_include_bg_cls=st.booleans(),
|
||||
output_classes_include_bg_cls=st.booleans(),
|
||||
|
|
@ -130,6 +131,7 @@ class TestBoxWithNMSLimitOp(serial.SerializedTestCase):
|
|||
def test_multiclass(
|
||||
self,
|
||||
num_classes,
|
||||
det_per_im,
|
||||
cls_agnostic_bbox_reg,
|
||||
input_boxes_include_bg_cls,
|
||||
output_classes_include_bg_cls,
|
||||
|
|
@ -145,9 +147,12 @@ class TestBoxWithNMSLimitOp(serial.SerializedTestCase):
|
|||
if cls_agnostic_bbox_reg:
|
||||
# only leave one class
|
||||
boxes = boxes[:, :4]
|
||||
# randomize un-used scores for background class
|
||||
scores_bg_class_id = 0 if input_boxes_include_bg_cls else -1
|
||||
scores[:, scores_bg_class_id] = np.random.rand(scores.shape[0]).astype(np.float32)
|
||||
|
||||
gt_centers = [(20, 20), (0, 0), (50, 50)]
|
||||
gt_scores = [0.85, 0.7, 0.6]
|
||||
gt_centers = [(20, 20), (0, 0), (50, 50)][:det_per_im]
|
||||
gt_scores = [0.85, 0.7, 0.6][:det_per_im]
|
||||
gt_boxes, gt_scores = gen_multiple_boxes(gt_centers, gt_scores, 1, 1)
|
||||
# [1, 1, 1, 2, 2, 2, 3, 3, 3, ...]
|
||||
gt_classes = np.tile(
|
||||
|
|
@ -164,7 +169,7 @@ class TestBoxWithNMSLimitOp(serial.SerializedTestCase):
|
|||
{
|
||||
"score_thresh": 0.5,
|
||||
"nms": 0.9,
|
||||
"detections_per_im": 100,
|
||||
"detections_per_im": (num_classes - 1) * det_per_im,
|
||||
"cls_agnostic_bbox_reg": cls_agnostic_bbox_reg,
|
||||
"input_boxes_include_bg_cls": input_boxes_include_bg_cls,
|
||||
"output_classes_include_bg_cls": output_classes_include_bg_cls
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user