mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
BoxWithNMSLimit support int batch_splits input (#47504)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47504 allow int type input of `batch_splits` Test Plan: ``` buck test caffe2/caffe2/python/operator_test:torch_integration_test -- test_box_with_nms_limits ``` Reviewed By: jackm321 Differential Revision: D24629522 fbshipit-source-id: 61cb132e792bddd8f9f1bca5b808f1a9131808f0
This commit is contained in:
parent
9d0c6e9469
commit
c19eb4ad73
|
|
@ -5,8 +5,9 @@
|
|||
namespace caffe2 {
|
||||
|
||||
template <>
|
||||
bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {
|
||||
const auto& tscores = Input(0);
|
||||
template <typename T>
|
||||
bool BoxWithNMSLimitOp<CPUContext>::DoRunWithType() {
|
||||
const auto& tscores = Input(0);
|
||||
const auto& tboxes = Input(1);
|
||||
|
||||
const int box_dim = rotated_ ? 5 : 4;
|
||||
|
|
@ -35,18 +36,19 @@ bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {
|
|||
int num_boxes_classes = get_box_cls_index(num_classes - 1) + 1;
|
||||
CAFFE_ENFORCE_EQ(num_boxes_classes * box_dim, tboxes.size(1));
|
||||
|
||||
// Default value for batch_size and batch_splits
|
||||
int batch_size = 1;
|
||||
vector<float> batch_splits_default(1, tscores.size(0));
|
||||
const float* batch_splits_data = batch_splits_default.data();
|
||||
vector<T> batch_splits_default(1, tscores.size(0));
|
||||
const T* batch_splits_data = batch_splits_default.data();
|
||||
if (InputSize() > 2) {
|
||||
// tscores and tboxes have items from multiple images in a batch. Get the
|
||||
// corresponding batch splits from input.
|
||||
const auto& tbatch_splits = Input(2);
|
||||
CAFFE_ENFORCE_EQ(tbatch_splits.dim(), 1);
|
||||
batch_size = tbatch_splits.size(0);
|
||||
batch_splits_data = tbatch_splits.data<float>();
|
||||
batch_splits_data = tbatch_splits.data<T>();
|
||||
}
|
||||
Eigen::Map<const EArrXf> batch_splits(batch_splits_data, batch_size);
|
||||
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>> batch_splits(batch_splits_data, batch_size);
|
||||
CAFFE_ENFORCE_EQ(batch_splits.sum(), N);
|
||||
|
||||
auto* out_scores = Output(0, {0}, at::dtype<float>());
|
||||
|
|
@ -65,7 +67,7 @@ bool BoxWithNMSLimitOp<CPUContext>::RunOnDevice() {
|
|||
vector<int> total_keep_per_batch(batch_size);
|
||||
int offset = 0;
|
||||
for (int b = 0; b < batch_splits.size(); ++b) {
|
||||
int num_boxes = batch_splits(b);
|
||||
int num_boxes = batch_splits[b];
|
||||
Eigen::Map<const ERArrXXf> scores(
|
||||
tscores.data<float>() + offset * tscores.size(1),
|
||||
num_boxes,
|
||||
|
|
|
|||
|
|
@ -60,7 +60,16 @@ class BoxWithNMSLimitOp final : public Operator<Context> {
|
|||
|
||||
~BoxWithNMSLimitOp() {}
|
||||
|
||||
bool RunOnDevice() override;
|
||||
bool RunOnDevice() override {
|
||||
if (InputSize() > 2) {
|
||||
return DispatchHelper<TensorTypes<int, float>>::call(this, Input(2));
|
||||
} else {
|
||||
return DoRunWithType<float>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool DoRunWithType();
|
||||
|
||||
protected:
|
||||
// TEST.SCORE_THRESH
|
||||
|
|
|
|||
|
|
@ -180,6 +180,7 @@ class TorchIntegration(hu.HypothesisTestCase):
|
|||
rotated=st.booleans(),
|
||||
angle_bound_on=st.booleans(),
|
||||
clip_angle_thresh=st.sampled_from([-1.0, 1.0]),
|
||||
batch_splits_dtype=st.sampled_from([torch.float32, torch.int32]),
|
||||
**hu.gcs_cpu_only
|
||||
)
|
||||
def test_box_with_nms_limits(
|
||||
|
|
@ -189,6 +190,7 @@ class TorchIntegration(hu.HypothesisTestCase):
|
|||
rotated,
|
||||
angle_bound_on,
|
||||
clip_angle_thresh,
|
||||
batch_splits_dtype,
|
||||
gc,
|
||||
dc,
|
||||
):
|
||||
|
|
@ -250,7 +252,7 @@ class TorchIntegration(hu.HypothesisTestCase):
|
|||
outputs = torch.ops._caffe2.BoxWithNMSLimit(
|
||||
torch.tensor(class_prob),
|
||||
torch.tensor(pred_bbox),
|
||||
torch.tensor(batch_splits),
|
||||
torch.tensor(batch_splits, dtype=batch_splits_dtype),
|
||||
score_thresh=float(score_thresh),
|
||||
nms=float(nms_thresh),
|
||||
detections_per_im=int(topk_per_image),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user