mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: The GoogleTest `TEST` macro is non-compliant with the `cppcoreguidelines-avoid-non-const-global-variables` check, as is `DEFINE_DISPATCH`. All changes except the ones to `.clang-tidy` were generated using the following script: ``` for i in `find . -type f -iname "*.c*" -or -iname "*.h"|xargs grep cppcoreguidelines-avoid-non-const-global-variables|cut -f1 -d:|sort|uniq`; do sed -i "/\/\/ NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)/d" $i; done ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/62008 Reviewed By: driazati, r-barnes Differential Revision: D29838584 Pulled By: malfet fbshipit-source-id: 1b2f8602c945bd4ce50a9bfdd204755556e31d13
92 lines
3.0 KiB
C++
92 lines
3.0 KiB
C++
#include "caffe2/onnx/offline_tensor.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
#ifndef C10_MOBILE
|
|
namespace {

// ONNXIFI data type codes, listed in ascending numeric order.
// These values need to be kept aligned with onnxifi.h.
constexpr uint64_t kONNXIFI_DATATYPE_FLOAT32 = 1;
constexpr uint64_t kONNXIFI_DATATYPE_UINT8 = 2;
constexpr uint64_t kONNXIFI_DATATYPE_INT8 = 3;
constexpr uint64_t kONNXIFI_DATATYPE_UINT16 = 4;
constexpr uint64_t kONNXIFI_DATATYPE_INT16 = 5;
constexpr uint64_t kONNXIFI_DATATYPE_INT32 = 6;
constexpr uint64_t kONNXIFI_DATATYPE_INT64 = 7;
constexpr uint64_t kONNXIFI_DATATYPE_FLOAT16 = 10;

} // namespace
|
|
|
|
// NOTE(review): presumably makes OfflineTensor known to the TypeMeta type
// system so that TypeMeta::Id<OfflineTensor>() used below resolves — confirm
// against the macro definition.
CAFFE_KNOWN_TYPE(OfflineTensor);
|
|
|
|
// Returns true when `id` equals the type identifier of OfflineTensor
// (TypeMeta::Id<OfflineTensor>()), false otherwise.
bool OfflineTensorShapeFunctions::IsSameMetaType(TypeIdentifier id) {
  const auto offline_tensor_id = TypeMeta::Id<OfflineTensor>();
  return id == offline_tensor_id;
}
|
|
|
|
// Returns the type identifier of OfflineTensor, i.e. the same value that
// IsSameMetaType() compares against.
TypeIdentifier OfflineTensorShapeFunctions::GetTypeMetaId() {
  TypeIdentifier offline_tensor_id = TypeMeta::Id<OfflineTensor>();
  return offline_tensor_id;
}
|
|
|
|
// Returns the dtype of the shape tensor held by the OfflineTensor that `c`
// points to.
//
// `c` is dereferenced without any check: it is assumed to be a non-null
// pointer to an OfflineTensor (NOTE(review): confirm callers guarantee this).
TypeMeta OfflineTensorShapeFunctions::GetExternalTensorType(const void* c) {
  // void* -> T* is a static_cast conversion; reinterpret_cast is not needed.
  const auto* offline_tensor = static_cast<const OfflineTensor*>(c);
  return offline_tensor->shape_tensor.dtype();
}
|
|
|
|
// Forwards to GetTensorInfo() for the shape tensor held by the OfflineTensor
// that `c` points to, and returns its result.
//
// `c` is dereferenced without any check: it is assumed to be a non-null
// pointer to an OfflineTensor (NOTE(review): confirm callers guarantee this).
// `capacity` and `device` are passed through to GetTensorInfo() unchanged.
vector<int64_t> OfflineTensorShapeFunctions::GetExternalTensorInfo(
    const void* c,
    size_t* capacity,
    DeviceOption* device) {
  // void* -> T* is a static_cast conversion; reinterpret_cast is not needed.
  const auto* offline_tensor = static_cast<const OfflineTensor*>(c);
  return GetTensorInfo(&(offline_tensor->shape_tensor), capacity, device);
}
|
|
|
|
void OfflineTensorShapeFunctions::SetupExternalTensorDescriptor(
|
|
const Blob* blob,
|
|
std::vector<std::vector<uint64_t>>* shapes,
|
|
std::vector<std::vector<float>>* /* unused */,
|
|
std::vector<std::vector<int32_t>>* /* unused */,
|
|
ExternalTensorDescriptor* desc) {
|
|
const auto& offline_tensor = blob->template Get<OfflineTensor>();
|
|
const Tensor& shape_tensor = offline_tensor.shape_tensor;
|
|
|
|
if (shape_tensor.template IsType<float>()) {
|
|
desc->dataType = kONNXIFI_DATATYPE_FLOAT32;
|
|
} else if (shape_tensor.template IsType<int32_t>()) {
|
|
desc->dataType = kONNXIFI_DATATYPE_INT32;
|
|
} else if (shape_tensor.template IsType<int8_t>()) {
|
|
desc->dataType = kONNXIFI_DATATYPE_INT8;
|
|
} else if (shape_tensor.template IsType<uint8_t>()) {
|
|
desc->dataType = kONNXIFI_DATATYPE_UINT8;
|
|
} else if (shape_tensor.template IsType<int64_t>()) {
|
|
desc->dataType = kONNXIFI_DATATYPE_INT64;
|
|
} else if (shape_tensor.template IsType<int16_t>()) {
|
|
desc->dataType = kONNXIFI_DATATYPE_INT16;
|
|
} else if (shape_tensor.template IsType<c10::Half>()) {
|
|
desc->dataType = kONNXIFI_DATATYPE_FLOAT16;
|
|
} else if (shape_tensor.template IsType<uint16_t>()) {
|
|
desc->dataType = kONNXIFI_DATATYPE_UINT16;
|
|
} else {
|
|
CAFFE_THROW("Unsupported tensor type: ", shape_tensor.dtype().name());
|
|
}
|
|
desc->buffer = 0;
|
|
|
|
desc->quantizationParams = 0;
|
|
desc->quantizationAxis = 0;
|
|
|
|
// Set up dim and shape
|
|
const auto shape = shape_tensor.sizes();
|
|
desc->dimensions = shape.size();
|
|
shapes->emplace_back(shape.cbegin(), shape.cend());
|
|
desc->shape = shapes->back().data();
|
|
|
|
// It is an offline tensor
|
|
desc->isOffline = 1;
|
|
}
|
|
|
|
// NOTE(review): presumably registers OfflineTensorShapeFunctions as the
// external-tensor function set for the OfflineTensor type id — confirm
// against the macro definition.
REGISTER_EXTERNAL_TENSOR_FUNCTIONS(
|
|
(TypeMeta::Id<OfflineTensor>()),
|
|
OfflineTensorShapeFunctions);
|
|
#endif
|
|
} // namespace caffe2
|