mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Export BatchBucketOneHot Caffe2 Operator to PyTorch
Summary: As titled. Test Plan: ``` buck test caffe2/caffe2/python/operator_test:torch_integration_test -- test_batch_bucket_one_hot_op ``` Reviewed By: yf225 Differential Revision: D23005981 fbshipit-source-id: 1daa8d3e7d6ad75e97e94964db95ccfb58541672
This commit is contained in:
parent
4afbf39737
commit
71dbfc79b3
|
|
@ -357,3 +357,8 @@ NO_GRADIENT(OneHot);
|
|||
NO_GRADIENT(SegmentOneHot);
|
||||
NO_GRADIENT(BucketBatchOneHot);
|
||||
} // namespace caffe2
|
||||
|
||||
C10_EXPORT_CAFFE2_OP_TO_C10_CPU(
|
||||
BatchBucketOneHot,
|
||||
"_caffe2::BatchBucketOneHot(Tensor data, Tensor lengths, Tensor boundaries) -> Tensor output",
|
||||
caffe2::BatchBucketOneHotOp<caffe2::CPUContext>);
|
||||
|
|
|
|||
|
|
@ -2,10 +2,13 @@
|
|||
#define CAFFE_OPERATORS_ONE_HOT_OPS_H_
|
||||
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/export_caffe2_op_to_c10.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/utils/math.h"
|
||||
|
||||
C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(BatchBucketOneHot);
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
template <class Context>
|
||||
|
|
|
|||
|
|
@ -856,6 +856,25 @@ class TorchIntegration(hu.HypothesisTestCase):
|
|||
)
|
||||
torch.testing.assert_allclose(expected_output, actual_output.cpu())
|
||||
|
||||
def test_batch_bucket_one_hot_op(self):
|
||||
data = np.array([[2, 3], [4, 1], [2, 5]]).astype(np.float32)
|
||||
lengths = np.array([2, 3]).astype(np.int32)
|
||||
boundaries = np.array([0.1, 2.5, 1, 3.1, 4.5]).astype(np.float32)
|
||||
|
||||
def _batch_bucket_one_hot_ref(data, lengths, boundaries):
|
||||
ref_op = core.CreateOperator('BatchBucketOneHot', ["data", "lengths", "boundaries"], ["Y"])
|
||||
workspace.FeedBlob("data", data)
|
||||
workspace.FeedBlob("lengths", lengths)
|
||||
workspace.FeedBlob("boundaries", boundaries)
|
||||
workspace.RunOperatorOnce(ref_op)
|
||||
return workspace.FetchBlob("Y")
|
||||
|
||||
expected_output = _batch_bucket_one_hot_ref(data, lengths, boundaries)
|
||||
actual_output = torch.ops._caffe2.BatchBucketOneHot(
|
||||
torch.tensor(data), torch.Tensor(lengths).int(), torch.Tensor(boundaries)
|
||||
)
|
||||
torch.testing.assert_allclose(expected_output, actual_output.cpu())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user