diff --git a/caffe2/operators/one_hot_ops.cc b/caffe2/operators/one_hot_ops.cc index 0dd0aaf9104..c3eaf05db0e 100644 --- a/caffe2/operators/one_hot_ops.cc +++ b/caffe2/operators/one_hot_ops.cc @@ -357,3 +357,8 @@ NO_GRADIENT(OneHot); NO_GRADIENT(SegmentOneHot); NO_GRADIENT(BucketBatchOneHot); } // namespace caffe2 + +C10_EXPORT_CAFFE2_OP_TO_C10_CPU( + BatchBucketOneHot, + "_caffe2::BatchBucketOneHot(Tensor data, Tensor lengths, Tensor boundaries) -> Tensor output", + caffe2::BatchBucketOneHotOp); diff --git a/caffe2/operators/one_hot_ops.h b/caffe2/operators/one_hot_ops.h index d61c0d28580..b5dcd8953e5 100644 --- a/caffe2/operators/one_hot_ops.h +++ b/caffe2/operators/one_hot_ops.h @@ -2,10 +2,13 @@ #define CAFFE_OPERATORS_ONE_HOT_OPS_H_ #include "caffe2/core/context.h" +#include "caffe2/core/export_caffe2_op_to_c10.h" #include "caffe2/core/logging.h" #include "caffe2/core/operator.h" #include "caffe2/utils/math.h" +C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(BatchBucketOneHot); + namespace caffe2 { template diff --git a/caffe2/python/operator_test/torch_integration_test.py b/caffe2/python/operator_test/torch_integration_test.py index f47ad24b259..5a5b3d8802b 100644 --- a/caffe2/python/operator_test/torch_integration_test.py +++ b/caffe2/python/operator_test/torch_integration_test.py @@ -856,6 +856,25 @@ class TorchIntegration(hu.HypothesisTestCase): ) torch.testing.assert_allclose(expected_output, actual_output.cpu()) + def test_batch_bucket_one_hot_op(self): + data = np.array([[2, 3], [4, 1], [2, 5]]).astype(np.float32) + lengths = np.array([2, 3]).astype(np.int32) + boundaries = np.array([0.1, 2.5, 1, 3.1, 4.5]).astype(np.float32) + + def _batch_bucket_one_hot_ref(data, lengths, boundaries): + ref_op = core.CreateOperator('BatchBucketOneHot', ["data", "lengths", "boundaries"], ["Y"]) + workspace.FeedBlob("data", data) + workspace.FeedBlob("lengths", lengths) + workspace.FeedBlob("boundaries", boundaries) + workspace.RunOperatorOnce(ref_op) + return workspace.FetchBlob("Y") + + expected_output = _batch_bucket_one_hot_ref(data, lengths, boundaries) + actual_output = torch.ops._caffe2.BatchBucketOneHot( + torch.tensor(data), torch.Tensor(lengths).int(), torch.Tensor(boundaries) + ) + torch.testing.assert_allclose(expected_output, actual_output.cpu()) + if __name__ == '__main__': unittest.main()