#pragma once #include #include "caffe2/core/operator.h" #include "caffe2/core/tensor.h" namespace caffe2 { #ifndef C10_MOBILE struct TORCH_API OfflineTensor { // A shell tensor to record shape and dtype Tensor shape_tensor{CPU}; void setShapeAndType( const std::vector& sizes, at::Device device, caffe2::TypeMeta data_type) { shape_tensor.unsafeGetTensorImpl()->set_storage_and_dtype( at::Storage::create_legacy(device), data_type); shape_tensor.Resize(sizes); CHECK(!shape_tensor.storage_initialized()); CHECK(shape_tensor.dtype_initialized()); } }; class OfflineTensorShapeFunctions : public ExternalTensorFunctionsBase { public: explicit OfflineTensorShapeFunctions() : ExternalTensorFunctionsBase() {} ~OfflineTensorShapeFunctions() override {} bool isQuantized() const override { return false; } bool IsSameMetaType(TypeIdentifier id) override; void SetupExternalTensorDescriptor( const Blob* blob, std::vector>* shapes, std::vector>* all_scales, std::vector>* all_offsets, ExternalTensorDescriptor* desc) override; void LoadInfoOfBlob( const Blob* /* unused */, std::vector* /* unused */, std::vector* /* unused */, uint32_t* /* unused */) override {} TypeIdentifier GetTypeMetaId() override; TypeMeta GetExternalTensorType(const void* c) override; vector GetExternalTensorInfo( const void* c, size_t* capacity, DeviceOption* device) override; }; #endif } // namespace caffe2