pytorch/caffe2/operators/map_ops.cc
Kittipat Virochsiri 2b134c72e6 Add interface to provide blob types to shape&type inference (#9643)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9643

Current map interface assumes float data type, which is not always correct.

Reviewed By: kennyhorror

Differential Revision: D8455784

fbshipit-source-id: b94a31267760f7f97c15aa4b03008affc347fd10
2018-07-24 11:58:05 -07:00

81 lines
2.4 KiB
C++

#include "caffe2/operators/map_ops.h"
namespace caffe2 {
// Map blob types for the four supported integer key/value width combinations
// (int64/int32 keys x int64/int32 values).
//
// Each map type is given a single-identifier alias before being handed to
// CAFFE_KNOWN_TYPE: a template-id such as MapTypeTraits<int64_t, int64_t>
// contains a top-level comma, which the preprocessor would otherwise treat as
// a macro-argument separator.
// NOTE(review): CAFFE_KNOWN_TYPE presumably makes TypeMeta::Id<T>() available
// for T (it is used that way in the serializer registrations in this file) and
// may record the alias spelling as the type's name -- confirm before renaming
// any alias.
using MapType64To64 = MapTypeTraits<int64_t, int64_t>::MapType;
CAFFE_KNOWN_TYPE(MapType64To64);
using MapType64To32 = MapTypeTraits<int64_t, int32_t>::MapType;
CAFFE_KNOWN_TYPE(MapType64To32);
using MapType32To32 = MapTypeTraits<int32_t, int32_t>::MapType;
CAFFE_KNOWN_TYPE(MapType32To32);
using MapType32To64 = MapTypeTraits<int32_t, int64_t>::MapType;
CAFFE_KNOWN_TYPE(MapType32To64);
namespace {
// Blob serializer registrations: one per supported map type.
// Each entry pairs the TypeMeta id of a map type with the MapSerializer
// instantiation that has the same key/value types, e.g. MapType64To32 (int64
// keys, int32 values) with MapSerializer<int64_t, int32_t>.
// Keep the alias and the template arguments of each pair in sync.
REGISTER_BLOB_SERIALIZER(
TypeMeta::Id<MapType64To64>(),
MapSerializer<int64_t, int64_t>);
REGISTER_BLOB_SERIALIZER(
TypeMeta::Id<MapType64To32>(),
MapSerializer<int64_t, int32_t>);
REGISTER_BLOB_SERIALIZER(
TypeMeta::Id<MapType32To32>(),
MapSerializer<int32_t, int32_t>);
REGISTER_BLOB_SERIALIZER(
TypeMeta::Id<MapType32To64>(),
MapSerializer<int32_t, int64_t>);
// Blob deserializer registrations: one per supported map type.
// The first argument is a parenthesized type spelling; the parentheses keep the
// comma inside the template argument list from splitting the macro argument.
// NOTE(review): the key presumably has to match, character for character
// (parentheses and spacing included), the type string that the corresponding
// MapSerializer writes into the serialized blob -- confirm against map_ops.h
// before reformatting these lines.
REGISTER_BLOB_DESERIALIZER(
(std::unordered_map<int64_t, int64_t>),
MapDeserializer<int64_t, int64_t>);
REGISTER_BLOB_DESERIALIZER(
(std::unordered_map<int64_t, int32_t>),
MapDeserializer<int64_t, int32_t>);
REGISTER_BLOB_DESERIALIZER(
(std::unordered_map<int32_t, int32_t>),
MapDeserializer<int32_t, int32_t>);
REGISTER_BLOB_DESERIALIZER(
(std::unordered_map<int32_t, int64_t>),
MapDeserializer<int32_t, int64_t>);
// CPU registrations for the three map operators. Each operator name is bound
// to the CPUContext instantiation of its implementation class template; the
// matching schemas are declared with OPERATOR_SCHEMA in this file.
REGISTER_CPU_OPERATOR(CreateMap, CreateMapOp<CPUContext>);
REGISTER_CPU_OPERATOR(KeyValueToMap, KeyValueToMapOp<CPUContext>);
REGISTER_CPU_OPERATOR(MapToKeyValue, MapToKeyValueOp<CPUContext>);
// Schema for CreateMap: no inputs, one output (the newly created, empty map
// blob). The key and value element types are selected through the optional
// "key_dtype" / "value_dtype" arguments, documented below as defaulting to
// INT32.
OPERATOR_SCHEMA(CreateMap)
.NumInputs(0)
.NumOutputs(1)
.SetDoc("Create an empty map blob")
.Arg("key_dtype", "Key's TensorProto::DataType (default INT32)")
.Arg("value_dtype", "Value's TensorProto::DataType (default INT32)")
.Output(0, "map blob", "Blob reference to the map")
// NOTE(review): presumably marks the output's scalar type as undefined so
// that shape/type inference does not assume a tensor element type for the map
// blob -- confirm against OpSchema::ScalarType.
.ScalarType(TensorProto_DataType_UNDEFINED);
// Schema for KeyValueToMap: two inputs (input 0 = key blob, input 1 = value
// blob) and one output (the map blob built from them).
// NOTE(review): unlike CreateMap, this schema does not set ScalarType even
// though it also outputs a map blob -- confirm whether that is intentional.
OPERATOR_SCHEMA(KeyValueToMap)
.NumInputs(2)
.NumOutputs(1)
.SetDoc("Convert key and value blob pairs into a map blob")
.Input(0, "key blob", "Blob reference to the key")
.Input(1, "value blob", "Blob reference to the value")
.Output(0, "map blob", "Blob reference to the map");
// Schema for MapToKeyValue, the inverse direction of KeyValueToMap: one input
// (the map blob) and two outputs (output 0 = key blob, output 1 = value blob).
OPERATOR_SCHEMA(MapToKeyValue)
.NumInputs(1)
.NumOutputs(2)
.SetDoc("Convert a map blob into key and value blob pairs")
.Input(0, "map blob", "Blob reference to the map")
.Output(0, "key blob", "Blob reference to the key")
.Output(1, "value blob", "Blob reference to the value");
}
} // namespace caffe2