mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68314 Add a convenience method, numel(), to lazy::Shape for counting the number of elements (by multiplying out the dimensions). This is already a method on Tensor, and as we switch other lazy tensor shape utils to use ATen shape inference, we need numel counts. Test Plan: add unit tests Reviewed By: alanwaketan Differential Revision: D32409138 fbshipit-source-id: 3ae725300f8826d38e45412f46501d5e5f776fb2
102 lines
2.7 KiB
C++
102 lines
2.7 KiB
C++
#include <gtest/gtest.h>
|
|
|
|
#include <sstream>
|
|
|
|
#include <torch/csrc/lazy/core/shape.h>
|
|
|
|
namespace torch {
|
|
namespace lazy {
|
|
|
|
// Checks the observable state of a value-initialized Shape: its string form,
// scalar type, rank, size list, and the error raised by an out-of-range
// dimension lookup.
TEST(ShapeTest, Basic1) {
  Shape empty_shape{};

  const auto rendered = empty_shape.to_string();
  EXPECT_STREQ(rendered.c_str(), "UNKNOWN_SCALAR[]");

  EXPECT_EQ(empty_shape.scalar_type(), c10::ScalarType::Undefined);

  // No dimensions are present, so the size list is empty and asking for
  // dimension 0 is expected to throw.
  EXPECT_EQ(empty_shape.dim(), 0);
  EXPECT_TRUE(empty_shape.sizes().empty());
  EXPECT_THROW(empty_shape.size(0), std::out_of_range);
}
|
|
|
|
// Checks a Float shape built from sizes {1, 2, 3}: element count, string
// form, scalar type, rank, and every per-dimension size through both
// sizes()[i] and size(i).
TEST(ShapeTest, Basic2) {
  auto shape = Shape(c10::ScalarType::Float, {1, 2, 3});

  EXPECT_EQ(shape.numel(), 6);
  EXPECT_STREQ(shape.to_string().c_str(), "Float[1,2,3]");
  EXPECT_EQ(shape.scalar_type(), c10::ScalarType::Float);

  // Fatal assertions: the loop below indexes sizes() with operator[], which
  // is not bounds-checked here, so stop the test if the rank or the number of
  // stored sizes is not what the loop expects.
  ASSERT_EQ(shape.dim(), 3);
  // Unsigned literal: a container's size() is normally size_t, and comparing
  // it with a signed literal can trigger -Wsign-compare inside the macro.
  ASSERT_EQ(shape.sizes().size(), 3u);

  for (int64_t i = 0; i < shape.dim(); i++) {
    // Dimension i was constructed with size i + 1.
    EXPECT_EQ(shape.sizes()[i], i + 1);
    EXPECT_EQ(shape.size(i), i + 1);
  }
}
|
|
|
|
// Checks a Float shape constructed with an empty size list (rank 0).
TEST(ShapeTest, Basic3) {
  Shape scalar_shape(c10::ScalarType::Float, {});

  const auto rendered = scalar_shape.to_string();
  EXPECT_STREQ(rendered.c_str(), "Float[]");
  EXPECT_EQ(scalar_shape.scalar_type(), c10::ScalarType::Float);
  EXPECT_EQ(scalar_shape.dim(), 0);

  // Perhaps unexpected: a rank-0 shape still counts one element, which
  // matches the behavior of 0-D tensors.
  EXPECT_EQ(scalar_shape.numel(), 1);

  EXPECT_TRUE(scalar_shape.sizes().empty());
  EXPECT_THROW(scalar_shape.size(0), std::out_of_range);
}
|
|
|
|
// Checks that set_scalar_type() changes the value reported by scalar_type().
TEST(ShapeTest, SetScalarType) {
  Shape shape{};
  const auto wanted_type = c10::ScalarType::Long;

  shape.set_scalar_type(wanted_type);

  EXPECT_EQ(shape.scalar_type(), wanted_type);
}
|
|
|
|
// Checks set_size(): it throws std::out_of_range on a value-initialized Shape,
// and on a three-dimensional Shape it overwrites the size of the requested
// dimension.
TEST(ShapeTest, SetSize) {
  Shape rankless{};
  EXPECT_THROW(rankless.set_size(0, 0), std::out_of_range);

  Shape ranked(c10::ScalarType::Float, {1, 2, 3});
  ranked.set_size(0, 3);

  // The new value must be visible through both accessors.
  EXPECT_EQ(ranked.size(0), 3);
  EXPECT_EQ(ranked.sizes()[0], 3);
}
|
|
|
|
// Checks operator== on Shape: shapes that differ in sizes and/or scalar type
// compare unequal, while a shape compares equal to itself and to a separately
// constructed shape with the same scalar type and sizes.
TEST(ShapeTest, Equal) {
  auto shape1 = Shape(c10::ScalarType::Float, {});
  auto shape2 = Shape(c10::ScalarType::Float, {1, 2, 3});
  auto shape3 = Shape(c10::ScalarType::Long, {1, 2, 3});
  auto shape4 = Shape(c10::ScalarType::Float, {1, 2, 3});

  // Same scalar type, different sizes.
  EXPECT_FALSE(shape1 == shape2);
  // Same sizes, different scalar type.
  EXPECT_FALSE(shape2 == shape3);
  // Both scalar type and sizes differ.
  EXPECT_FALSE(shape1 == shape3);

  // Reflexive comparison.
  EXPECT_TRUE(shape2 == shape2);
  // Two distinct objects with identical contents. Previously shape4 was
  // constructed but never compared, so this case went untested.
  EXPECT_TRUE(shape2 == shape4);
}
|
|
|
|
// Checks that streaming a Shape with operator<< writes the same text that
// to_string() returns.
TEST(ShapeTest, Ostream) {
  Shape shape{};

  std::ostringstream streamed;
  streamed << shape;

  const auto expected_text = shape.to_string();
  EXPECT_EQ(expected_text, streamed.str());
}
|
|
|
|
// Checks convertShapes(): given parallel lists of scalar types and size
// lists, the result holds one Shape per pair, in order. Empty inputs give an
// empty result.
TEST(ShapeTest, ConvertShapes) {
  auto shape1 = Shape(c10::ScalarType::Long, {1, 2, 3});
  auto shape2 = Shape(c10::ScalarType::Float, {1, 2});

  auto shapes1 = convertShapes({}, {});
  EXPECT_TRUE(shapes1.empty());

  auto shapes2 = convertShapes({c10::ScalarType::Long}, {{1, 2, 3}});
  // Fatal assertion: indexing below would read out of bounds if the result
  // had fewer elements than expected. The literal is unsigned because a
  // container's size() is normally size_t (avoids -Wsign-compare).
  ASSERT_EQ(shapes2.size(), 1u);
  EXPECT_EQ(shapes2[0], shape1);

  auto shapes3 = convertShapes(
      {c10::ScalarType::Long, c10::ScalarType::Float}, {{1, 2, 3}, {1, 2}});
  // Same reasoning as above: guard the element accesses that follow.
  ASSERT_EQ(shapes3.size(), 2u);
  EXPECT_EQ(shapes3[0], shape1);
  EXPECT_EQ(shapes3[1], shape2);
}
|
|
|
|
} // namespace lazy
|
|
} // namespace torch
|