mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Since caffe2 and torch have been consolidated, CAFFE2_API should be merged with TORCH_API. This addresses a TODO. The following references to the removed `CAFFE2_API` were edited manually: * `CONTRIBUTING.md` * `caffe2/proto/CMakeLists.txt` * `cmake/ProtoBuf.cmake` * `c10/macros/Export.h` * `torch/csrc/WindowsTorchApiMacro.h` Pull Request resolved: https://github.com/pytorch/pytorch/pull/49496 Reviewed By: malfet, samestep Differential Revision: D25600726 Pulled By: janeyx99 fbshipit-source-id: 7e068d959e397ac183c097d7e9a9afeca5ddd782
54 lines
1.6 KiB
C++
54 lines
1.6 KiB
C++
#pragma once
|
|
|
|
#include <c10/core/Storage.h>
|
|
#include "caffe2/core/operator.h"
|
|
#include "caffe2/core/tensor.h"
|
|
|
|
namespace caffe2 {
|
|
|
|
#ifndef C10_MOBILE
|
|
struct TORCH_API OfflineTensor {
|
|
// A shell tensor to record shape and dtype
|
|
Tensor shape_tensor{CPU};
|
|
|
|
void setShapeAndType(
|
|
const std::vector<int>& sizes,
|
|
at::Device device,
|
|
caffe2::TypeMeta data_type) {
|
|
shape_tensor.unsafeGetTensorImpl()->set_storage_and_dtype(
|
|
at::Storage::create_legacy(device), data_type);
|
|
shape_tensor.Resize(sizes);
|
|
CHECK(!shape_tensor.storage_initialized());
|
|
CHECK(shape_tensor.dtype_initialized());
|
|
}
|
|
};
|
|
|
|
class OfflineTensorShapeFunctions : public ExternalTensorFunctionsBase {
|
|
public:
|
|
explicit OfflineTensorShapeFunctions() : ExternalTensorFunctionsBase() {}
|
|
~OfflineTensorShapeFunctions() override {}
|
|
bool isQuantized() const override {
|
|
return false;
|
|
}
|
|
bool IsSameMetaType(TypeIdentifier id) override;
|
|
void SetupExternalTensorDescriptor(
|
|
const Blob* blob,
|
|
std::vector<std::vector<uint64_t>>* shapes,
|
|
std::vector<std::vector<float>>* all_scales,
|
|
std::vector<std::vector<int32_t>>* all_offsets,
|
|
ExternalTensorDescriptor* desc) override;
|
|
void LoadInfoOfBlob(
|
|
const Blob* /* unused */,
|
|
std::vector<float>* /* unused */,
|
|
std::vector<float>* /* unused */,
|
|
uint32_t* /* unused */) override {}
|
|
TypeIdentifier GetTypeMetaId() override;
|
|
TypeMeta GetExternalTensorType(const void* c) override;
|
|
vector<int64_t> GetExternalTensorInfo(
|
|
const void* c,
|
|
size_t* capacity,
|
|
DeviceOption* device) override;
|
|
};
|
|
#endif
|
|
} // namespace caffe2
|