Export BatchBucketOneHot Caffe2 Operator to PyTorch

Summary: As titled.

Test Plan:
```
buck test caffe2/caffe2/python/operator_test:torch_integration_test -- test_batch_bucket_one_hot_op
```

Reviewed By: yf225

Differential Revision: D23005981

fbshipit-source-id: 1daa8d3e7d6ad75e97e94964db95ccfb58541672
This commit is contained in:
Edson Romero 2020-08-11 13:20:31 -07:00 committed by Facebook GitHub Bot
parent 4afbf39737
commit 71dbfc79b3
3 changed files with 27 additions and 0 deletions

View File

@ -357,3 +357,8 @@ NO_GRADIENT(OneHot);
NO_GRADIENT(SegmentOneHot);
NO_GRADIENT(BucketBatchOneHot);
} // namespace caffe2
C10_EXPORT_CAFFE2_OP_TO_C10_CPU(
BatchBucketOneHot,
"_caffe2::BatchBucketOneHot(Tensor data, Tensor lengths, Tensor boundaries) -> Tensor output",
caffe2::BatchBucketOneHotOp<caffe2::CPUContext>);

View File

@ -2,10 +2,13 @@
#define CAFFE_OPERATORS_ONE_HOT_OPS_H_
#include "caffe2/core/context.h"
#include "caffe2/core/export_caffe2_op_to_c10.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(BatchBucketOneHot);
namespace caffe2 {
template <class Context>

View File

@ -856,6 +856,25 @@ class TorchIntegration(hu.HypothesisTestCase):
)
torch.testing.assert_allclose(expected_output, actual_output.cpu())
def test_batch_bucket_one_hot_op(self):
data = np.array([[2, 3], [4, 1], [2, 5]]).astype(np.float32)
lengths = np.array([2, 3]).astype(np.int32)
boundaries = np.array([0.1, 2.5, 1, 3.1, 4.5]).astype(np.float32)
def _batch_bucket_one_hot_ref(data, lengths, boundaries):
ref_op = core.CreateOperator('BatchBucketOneHot', ["data", "lengths", "boundaries"], ["Y"])
workspace.FeedBlob("data", data)
workspace.FeedBlob("lengths", lengths)
workspace.FeedBlob("boundaries", boundaries)
workspace.RunOperatorOnce(ref_op)
return workspace.FetchBlob("Y")
expected_output = _batch_bucket_one_hot_ref(data, lengths, boundaries)
actual_output = torch.ops._caffe2.BatchBucketOneHot(
torch.tensor(data), torch.Tensor(lengths).int(), torch.Tensor(boundaries)
)
torch.testing.assert_allclose(expected_output, actual_output.cpu())
if __name__ == '__main__':
unittest.main()