mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Torch Native Runtime RFC: https://github.com/pytorch/rfcs/pull/72 This diff moves `TensorMeta.cpp` and `TensorMeta.h` to PyTorch core under `torch/nativert/graph/`. The existing `torch::_export::TensorMeta` in `torch/csrc/utils/generated_serialization_types.h` is auto-generated from the export serde schema and therefore contains only the most basic serializable types. We need the newly added `TensorMeta.cpp` to deserialize the metadata into an in-memory class with c10 types so that it can be consumed by the runtime later. Test Plan: Added test under `test/cpp/nativert/test_tensor_meta.cpp` Differential Revision: D73820548 Pull Request resolved: https://github.com/pytorch/pytorch/pull/152475 Approved by: https://github.com/albanD
63 lines
2.1 KiB
C++
63 lines
2.1 KiB
C++
#include <gtest/gtest.h>
|
|
#include <torch/nativert/graph/TensorMeta.h>
|
|
|
|
namespace torch::nativert {
|
|
TEST(TensorMetaTest, ScalarTypeConversion) {
|
|
EXPECT_EQ(
|
|
convertJsonScalarType(torch::_export::ScalarType::FLOAT),
|
|
c10::ScalarType::Float);
|
|
EXPECT_EQ(
|
|
convertJsonScalarType(torch::_export::ScalarType::INT),
|
|
c10::ScalarType::Int);
|
|
EXPECT_EQ(
|
|
convertJsonScalarType(torch::_export::ScalarType::HALF),
|
|
c10::ScalarType::Half);
|
|
EXPECT_EQ(
|
|
convertJsonScalarType(torch::_export::ScalarType::COMPLEXHALF),
|
|
c10::ScalarType::ComplexHalf);
|
|
EXPECT_EQ(
|
|
convertJsonScalarType(torch::_export::ScalarType::BFLOAT16),
|
|
c10::ScalarType::BFloat16);
|
|
EXPECT_THROW(
|
|
convertJsonScalarType(static_cast<torch::_export::ScalarType>(100)),
|
|
c10::Error);
|
|
}
|
|
TEST(TensorMetaTest, MemoryFormatConversion) {
|
|
EXPECT_EQ(
|
|
convertJsonMemoryFormat(torch::_export::MemoryFormat::ContiguousFormat),
|
|
c10::MemoryFormat::Contiguous);
|
|
EXPECT_EQ(
|
|
convertJsonMemoryFormat(torch::_export::MemoryFormat::ChannelsLast),
|
|
c10::MemoryFormat::ChannelsLast);
|
|
EXPECT_EQ(
|
|
convertJsonMemoryFormat(torch::_export::MemoryFormat::PreserveFormat),
|
|
c10::MemoryFormat::Preserve);
|
|
EXPECT_THROW(
|
|
convertJsonMemoryFormat(static_cast<torch::_export::MemoryFormat>(100)),
|
|
c10::Error);
|
|
}
|
|
|
|
TEST(TensorMetaTest, LayoutConversion) {
|
|
EXPECT_EQ(
|
|
convertJsonLayout(torch::_export::Layout::Strided), c10::Layout::Strided);
|
|
EXPECT_EQ(
|
|
convertJsonLayout(torch::_export::Layout::SparseCsr),
|
|
c10::Layout::SparseCsr);
|
|
EXPECT_EQ(
|
|
convertJsonLayout(torch::_export::Layout::_mkldnn), c10::Layout::Mkldnn);
|
|
EXPECT_THROW(
|
|
convertJsonLayout(static_cast<torch::_export::Layout>(100)), c10::Error);
|
|
}
|
|
TEST(TensorMetaTest, DeviceConversion) {
|
|
torch::_export::Device cpu_device;
|
|
cpu_device.set_type("cpu");
|
|
EXPECT_EQ(convertJsonDevice(cpu_device), c10::Device(c10::DeviceType::CPU));
|
|
torch::_export::Device cuda_device;
|
|
cuda_device.set_type("cuda");
|
|
cuda_device.set_index(0);
|
|
EXPECT_EQ(
|
|
convertJsonDevice(cuda_device), c10::Device(c10::DeviceType::CUDA, 0));
|
|
}
|
|
|
|
} // namespace torch::nativert
|