mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Refactor: Move common SymbolicMapTest setup to the fixture.
This change moves the initialization of commonly used `SymbolicExpr` and a sample `SymbolicMap` into the `SymbolicMapTest` fixture to reduce code duplication across tests. PiperOrigin-RevId: 826161168
This commit is contained in:
parent
7736af79a6
commit
8f60516a86
|
|
@ -30,34 +30,44 @@ using ::testing::ElementsAre;
|
|||
|
||||
struct SymbolicMapTest : public ::testing::Test {
|
||||
mlir::MLIRContext mlir_context;
|
||||
SymbolicExprContext ctx{&mlir_context};
|
||||
SymbolicExprContext ctx;
|
||||
SymbolicExpr d0;
|
||||
SymbolicExpr d1;
|
||||
static constexpr int kSampleDims = 2;
|
||||
SymbolicExpr s0;
|
||||
SymbolicExpr s1;
|
||||
static constexpr int kSampleSymbols = 2;
|
||||
SymbolicExpr c2;
|
||||
SymbolicExpr c10;
|
||||
SymbolicMap sample_map;
|
||||
|
||||
SymbolicMapTest()
|
||||
: ctx(&mlir_context),
|
||||
d0(CreateDimExpr(&ctx, 0)),
|
||||
d1(CreateDimExpr(&ctx, 1)),
|
||||
s0(CreateSymbolExpr(&ctx, 0, kSampleDims)),
|
||||
s1(CreateSymbolExpr(&ctx, 1, kSampleDims)),
|
||||
c2(ctx.CreateConstant(2)),
|
||||
c10(ctx.CreateConstant(10)),
|
||||
sample_map(SymbolicMap::Get(&ctx, kSampleDims, kSampleSymbols,
|
||||
{d0 + s0, d1 * s1})) {}
|
||||
};
|
||||
|
||||
TEST_F(SymbolicMapTest, GetSymbolAndDimExpressions) {
|
||||
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
|
||||
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
|
||||
SymbolicExpr s0 = CreateSymbolExpr(&ctx, /*symbol_id=*/0, /*num_dims=*/2);
|
||||
SymbolicExpr s1 = CreateSymbolExpr(&ctx, /*symbol_id=*/1, /*num_dims=*/2);
|
||||
SymbolicMap map = SymbolicMap::Get(&ctx, 2, 2, {d0 + s0, d1 * s1});
|
||||
EXPECT_EQ(map.GetSymbolExpression(0), s0);
|
||||
EXPECT_EQ(map.GetSymbolExpression(1), s1);
|
||||
EXPECT_EQ(map.GetDimExpression(0), d0);
|
||||
EXPECT_EQ(map.GetDimExpression(1), d1);
|
||||
EXPECT_EQ(sample_map.GetSymbolExpression(0), s0);
|
||||
EXPECT_EQ(sample_map.GetSymbolExpression(1), s1);
|
||||
EXPECT_EQ(sample_map.GetDimExpression(0), d0);
|
||||
EXPECT_EQ(sample_map.GetDimExpression(1), d1);
|
||||
}
|
||||
|
||||
TEST_F(SymbolicMapTest, ToString) {
|
||||
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
|
||||
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
|
||||
SymbolicExpr s0 = CreateSymbolExpr(&ctx, /*symbol_id=*/0, /*num_dims=*/2);
|
||||
SymbolicExpr s1 = CreateSymbolExpr(&ctx, /*symbol_id=*/1, /*num_dims=*/2);
|
||||
|
||||
SymbolicMap map = SymbolicMap::Get(&ctx, 2, 2, {d0 + s0, d1 * s1});
|
||||
EXPECT_EQ(map.ToString(), "(d0, d1)[s0, s1] -> ((d0 + s0), (d1 * s1))");
|
||||
EXPECT_EQ(sample_map.ToString(),
|
||||
"(d0, d1)[s0, s1] -> ((d0 + s0), (d1 * s1))");
|
||||
|
||||
SymbolicMap empty_map = SymbolicMap::Get(&ctx, 0, 0, {});
|
||||
EXPECT_EQ(empty_map.ToString(), "()[] -> ()");
|
||||
|
||||
SymbolicMap dims_only = SymbolicMap::Get(&ctx, 2, 0, {d0, d1});
|
||||
SymbolicMap dims_only = SymbolicMap::Get(&ctx, kSampleDims, 0, {d0, d1});
|
||||
EXPECT_EQ(dims_only.ToString(), "(d0, d1)[] -> (d0, d1)");
|
||||
|
||||
SymbolicExpr s0_no_dims =
|
||||
|
|
@ -65,7 +75,7 @@ TEST_F(SymbolicMapTest, ToString) {
|
|||
SymbolicExpr s1_no_dims =
|
||||
CreateSymbolExpr(&ctx, /*symbol_id=*/1, /*num_dims=*/0);
|
||||
SymbolicMap symbols_only =
|
||||
SymbolicMap::Get(&ctx, 0, 2, {s0_no_dims, s1_no_dims});
|
||||
SymbolicMap::Get(&ctx, 0, kSampleSymbols, {s0_no_dims, s1_no_dims});
|
||||
EXPECT_EQ(symbols_only.ToString(), "()[s0, s1] -> (s0, s1)");
|
||||
}
|
||||
|
||||
|
|
@ -120,18 +130,11 @@ TEST_F(SymbolicMapTest, GetConstantResults) {
|
|||
}
|
||||
|
||||
TEST_F(SymbolicMapTest, ReplaceDimsAndSymbols) {
|
||||
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
|
||||
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
|
||||
SymbolicExpr s0 = CreateSymbolExpr(&ctx, /*symbol_id=*/0, /*num_dims=*/2);
|
||||
SymbolicExpr s1 = CreateSymbolExpr(&ctx, /*symbol_id=*/1, /*num_dims=*/2);
|
||||
SymbolicExpr c1 = ctx.CreateConstant(10);
|
||||
SymbolicExpr c2 = ctx.CreateConstant(20);
|
||||
SymbolicExpr c3 = ctx.CreateConstant(30);
|
||||
|
||||
SymbolicMap map_basic = SymbolicMap::Get(&ctx, 2, 2, {d0 + s0, d1 * s1});
|
||||
SymbolicMap replaced_basic = map_basic.ReplaceDimsAndSymbols(
|
||||
{c1, c2}, {c3, d0}, map_basic.GetNumDims(), map_basic.GetNumSymbols());
|
||||
EXPECT_THAT(replaced_basic.GetResults(), ElementsAre(c1 + c3, c2 * d0));
|
||||
SymbolicMap replaced_basic = sample_map.ReplaceDimsAndSymbols(
|
||||
{c10, c2}, {c3, d0}, sample_map.GetNumDims(), sample_map.GetNumSymbols());
|
||||
EXPECT_THAT(replaced_basic.GetResults(), ElementsAre(c10 + c3, c2 * d0));
|
||||
|
||||
SymbolicMap map_empty = SymbolicMap::Get(&ctx, 0, 0, {});
|
||||
SymbolicMap replaced_empty = map_empty.ReplaceDimsAndSymbols({}, {}, 0, 0);
|
||||
|
|
@ -143,53 +146,28 @@ TEST_F(SymbolicMapTest, ReplaceDimsAndSymbols) {
|
|||
SymbolicExpr new_d1 = CreateDimExpr(&ctx, 1);
|
||||
SymbolicExpr new_s0 = CreateSymbolExpr(&ctx, /*symbol_id=*/0, /*num_dims=*/2);
|
||||
SymbolicMap replaced_change_dims = map_change_dims.ReplaceDimsAndSymbols(
|
||||
{new_d0 * c1 + new_d1}, {new_s0}, 2, 1);
|
||||
{new_d0 * c10 + new_d1}, {new_s0}, 2, 1);
|
||||
EXPECT_EQ(replaced_change_dims.GetNumDims(), 2);
|
||||
EXPECT_EQ(replaced_change_dims.GetNumSymbols(), 1);
|
||||
EXPECT_THAT(replaced_change_dims.GetResults(),
|
||||
ElementsAre((new_d0 * c1 + new_d1) + new_s0 * c2));
|
||||
ElementsAre((new_d0 * c10 + new_d1) + new_s0 * c2));
|
||||
}
|
||||
|
||||
TEST_F(SymbolicMapTest, ReplaceDimsAndSymbolsOnlyDims) {
|
||||
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
|
||||
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
|
||||
int num_dims = 2;
|
||||
SymbolicExpr s0 = CreateSymbolExpr(&ctx, 0, num_dims);
|
||||
SymbolicExpr s1 = CreateSymbolExpr(&ctx, 1, num_dims);
|
||||
int num_symbols = 2;
|
||||
SymbolicExpr c1 = ctx.CreateConstant(10);
|
||||
SymbolicExpr c2 = ctx.CreateConstant(20);
|
||||
|
||||
SymbolicMap map =
|
||||
SymbolicMap::Get(&ctx, num_dims, num_symbols, {d0 + s0, d1 * s1});
|
||||
SymbolicMap replaced = map.ReplaceDimsAndSymbols(
|
||||
/*dim_replacements=*/{c1, c2}, /*sym_replacements=*/{}, num_dims,
|
||||
num_symbols);
|
||||
EXPECT_THAT(replaced.GetResults(), ElementsAre(c1 + s0, c2 * s1));
|
||||
SymbolicMap replaced = sample_map.ReplaceDimsAndSymbols(
|
||||
/*dim_replacements=*/{c10, c2}, /*sym_replacements=*/{},
|
||||
sample_map.GetNumDims(), sample_map.GetNumSymbols());
|
||||
EXPECT_THAT(replaced.GetResults(), ElementsAre(c10 + s0, c2 * s1));
|
||||
}
|
||||
|
||||
TEST_F(SymbolicMapTest, ReplaceDimsAndSymbolsOnlySymbols) {
|
||||
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
|
||||
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
|
||||
int num_dims = 2;
|
||||
SymbolicExpr s0 = CreateSymbolExpr(&ctx, 0, num_dims);
|
||||
SymbolicExpr s1 = CreateSymbolExpr(&ctx, 1, num_dims);
|
||||
int num_symbols = 2;
|
||||
SymbolicExpr c1 = ctx.CreateConstant(10);
|
||||
SymbolicExpr c2 = ctx.CreateConstant(20);
|
||||
|
||||
SymbolicMap map =
|
||||
SymbolicMap::Get(&ctx, num_dims, num_symbols, {d0 + s0, d1 * s1});
|
||||
SymbolicMap replaced = map.ReplaceDimsAndSymbols(
|
||||
/*dim_replacements=*/{}, /*sym_replacements=*/{c1, c2}, num_dims,
|
||||
num_symbols);
|
||||
EXPECT_THAT(replaced.GetResults(), ElementsAre(d0 + c1, d1 * c2));
|
||||
SymbolicMap replaced = sample_map.ReplaceDimsAndSymbols(
|
||||
/*dim_replacements=*/{}, /*sym_replacements=*/{c10, c2},
|
||||
sample_map.GetNumDims(), sample_map.GetNumSymbols());
|
||||
EXPECT_THAT(replaced.GetResults(), ElementsAre(d0 + c10, d1 * c2));
|
||||
}
|
||||
|
||||
TEST_F(SymbolicMapTest, Compose) {
|
||||
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
|
||||
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
|
||||
|
||||
// Composition without Symbols
|
||||
SymbolicMap map1_no_symbols = SymbolicMap::Get(&ctx, 1, 0, {d0 * 2});
|
||||
SymbolicMap map2_no_symbols = SymbolicMap::Get(&ctx, 1, 0, {d0 + 5});
|
||||
|
|
@ -243,9 +221,6 @@ TEST_F(SymbolicMapTest, Compose) {
|
|||
}
|
||||
|
||||
TEST_F(SymbolicMapTest, Replace) {
|
||||
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
|
||||
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
|
||||
SymbolicExpr c2 = ctx.CreateConstant(2);
|
||||
SymbolicExpr c5 = ctx.CreateConstant(5);
|
||||
|
||||
SymbolicExpr expr0 = (d0 + c2) * d1;
|
||||
|
|
@ -264,15 +239,15 @@ TEST_F(SymbolicMapTest, Replace) {
|
|||
}
|
||||
|
||||
TEST_F(SymbolicMapTest, GetUnusedVariables) {
|
||||
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
|
||||
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
|
||||
[[maybe_unused]] SymbolicExpr d2 = CreateDimExpr(&ctx, 2);
|
||||
// d2 is unused.
|
||||
SymbolicExpr s0 = CreateSymbolExpr(&ctx, /*symbol_id=*/0, /*num_dims=*/3);
|
||||
SymbolicExpr s1 = CreateSymbolExpr(&ctx, /*symbol_id=*/1, /*num_dims=*/3);
|
||||
SymbolicExpr c2 = ctx.CreateConstant(2);
|
||||
[[maybe_unused]] SymbolicExpr s0_3dims =
|
||||
CreateSymbolExpr(&ctx, /*symbol_id=*/0, /*num_dims=*/3);
|
||||
SymbolicExpr s1_3dims =
|
||||
CreateSymbolExpr(&ctx, /*symbol_id=*/1, /*num_dims=*/3);
|
||||
|
||||
// Map with used and unused dims and symbols.
|
||||
SymbolicMap map = SymbolicMap::Get(&ctx, 3, 2, {d0 + s1, d1 * c2});
|
||||
SymbolicMap map = SymbolicMap::Get(&ctx, 3, 2, {d0 + s1_3dims, d1 * c2});
|
||||
|
||||
llvm::SmallBitVector unused_dims = GetUnusedDimensionsBitVector(map);
|
||||
EXPECT_EQ(unused_dims.size(), 3);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user