Refactor: Move common SymbolicMapTest setup to the fixture.

This change moves the initialization of commonly used `SymbolicExpr` and a sample `SymbolicMap` into the `SymbolicMapTest` fixture to reduce code duplication across tests.

PiperOrigin-RevId: 826161168
This commit is contained in:
A. Unique TensorFlower 2025-10-30 13:23:41 -07:00 committed by TensorFlower Gardener
parent 7736af79a6
commit 8f60516a86

View File

@ -30,34 +30,44 @@ using ::testing::ElementsAre;
struct SymbolicMapTest : public ::testing::Test {
mlir::MLIRContext mlir_context;
SymbolicExprContext ctx{&mlir_context};
SymbolicExprContext ctx;
SymbolicExpr d0;
SymbolicExpr d1;
static constexpr int kSampleDims = 2;
SymbolicExpr s0;
SymbolicExpr s1;
static constexpr int kSampleSymbols = 2;
SymbolicExpr c2;
SymbolicExpr c10;
SymbolicMap sample_map;
SymbolicMapTest()
: ctx(&mlir_context),
d0(CreateDimExpr(&ctx, 0)),
d1(CreateDimExpr(&ctx, 1)),
s0(CreateSymbolExpr(&ctx, 0, kSampleDims)),
s1(CreateSymbolExpr(&ctx, 1, kSampleDims)),
c2(ctx.CreateConstant(2)),
c10(ctx.CreateConstant(10)),
sample_map(SymbolicMap::Get(&ctx, kSampleDims, kSampleSymbols,
{d0 + s0, d1 * s1})) {}
};
TEST_F(SymbolicMapTest, GetSymbolAndDimExpressions) {
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
SymbolicExpr s0 = CreateSymbolExpr(&ctx, /*symbol_id=*/0, /*num_dims=*/2);
SymbolicExpr s1 = CreateSymbolExpr(&ctx, /*symbol_id=*/1, /*num_dims=*/2);
SymbolicMap map = SymbolicMap::Get(&ctx, 2, 2, {d0 + s0, d1 * s1});
EXPECT_EQ(map.GetSymbolExpression(0), s0);
EXPECT_EQ(map.GetSymbolExpression(1), s1);
EXPECT_EQ(map.GetDimExpression(0), d0);
EXPECT_EQ(map.GetDimExpression(1), d1);
EXPECT_EQ(sample_map.GetSymbolExpression(0), s0);
EXPECT_EQ(sample_map.GetSymbolExpression(1), s1);
EXPECT_EQ(sample_map.GetDimExpression(0), d0);
EXPECT_EQ(sample_map.GetDimExpression(1), d1);
}
TEST_F(SymbolicMapTest, ToString) {
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
SymbolicExpr s0 = CreateSymbolExpr(&ctx, /*symbol_id=*/0, /*num_dims=*/2);
SymbolicExpr s1 = CreateSymbolExpr(&ctx, /*symbol_id=*/1, /*num_dims=*/2);
SymbolicMap map = SymbolicMap::Get(&ctx, 2, 2, {d0 + s0, d1 * s1});
EXPECT_EQ(map.ToString(), "(d0, d1)[s0, s1] -> ((d0 + s0), (d1 * s1))");
EXPECT_EQ(sample_map.ToString(),
"(d0, d1)[s0, s1] -> ((d0 + s0), (d1 * s1))");
SymbolicMap empty_map = SymbolicMap::Get(&ctx, 0, 0, {});
EXPECT_EQ(empty_map.ToString(), "()[] -> ()");
SymbolicMap dims_only = SymbolicMap::Get(&ctx, 2, 0, {d0, d1});
SymbolicMap dims_only = SymbolicMap::Get(&ctx, kSampleDims, 0, {d0, d1});
EXPECT_EQ(dims_only.ToString(), "(d0, d1)[] -> (d0, d1)");
SymbolicExpr s0_no_dims =
@ -65,7 +75,7 @@ TEST_F(SymbolicMapTest, ToString) {
SymbolicExpr s1_no_dims =
CreateSymbolExpr(&ctx, /*symbol_id=*/1, /*num_dims=*/0);
SymbolicMap symbols_only =
SymbolicMap::Get(&ctx, 0, 2, {s0_no_dims, s1_no_dims});
SymbolicMap::Get(&ctx, 0, kSampleSymbols, {s0_no_dims, s1_no_dims});
EXPECT_EQ(symbols_only.ToString(), "()[s0, s1] -> (s0, s1)");
}
@ -120,18 +130,11 @@ TEST_F(SymbolicMapTest, GetConstantResults) {
}
TEST_F(SymbolicMapTest, ReplaceDimsAndSymbols) {
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
SymbolicExpr s0 = CreateSymbolExpr(&ctx, /*symbol_id=*/0, /*num_dims=*/2);
SymbolicExpr s1 = CreateSymbolExpr(&ctx, /*symbol_id=*/1, /*num_dims=*/2);
SymbolicExpr c1 = ctx.CreateConstant(10);
SymbolicExpr c2 = ctx.CreateConstant(20);
SymbolicExpr c3 = ctx.CreateConstant(30);
SymbolicMap map_basic = SymbolicMap::Get(&ctx, 2, 2, {d0 + s0, d1 * s1});
SymbolicMap replaced_basic = map_basic.ReplaceDimsAndSymbols(
{c1, c2}, {c3, d0}, map_basic.GetNumDims(), map_basic.GetNumSymbols());
EXPECT_THAT(replaced_basic.GetResults(), ElementsAre(c1 + c3, c2 * d0));
SymbolicMap replaced_basic = sample_map.ReplaceDimsAndSymbols(
{c10, c2}, {c3, d0}, sample_map.GetNumDims(), sample_map.GetNumSymbols());
EXPECT_THAT(replaced_basic.GetResults(), ElementsAre(c10 + c3, c2 * d0));
SymbolicMap map_empty = SymbolicMap::Get(&ctx, 0, 0, {});
SymbolicMap replaced_empty = map_empty.ReplaceDimsAndSymbols({}, {}, 0, 0);
@ -143,53 +146,28 @@ TEST_F(SymbolicMapTest, ReplaceDimsAndSymbols) {
SymbolicExpr new_d1 = CreateDimExpr(&ctx, 1);
SymbolicExpr new_s0 = CreateSymbolExpr(&ctx, /*symbol_id=*/0, /*num_dims=*/2);
SymbolicMap replaced_change_dims = map_change_dims.ReplaceDimsAndSymbols(
{new_d0 * c1 + new_d1}, {new_s0}, 2, 1);
{new_d0 * c10 + new_d1}, {new_s0}, 2, 1);
EXPECT_EQ(replaced_change_dims.GetNumDims(), 2);
EXPECT_EQ(replaced_change_dims.GetNumSymbols(), 1);
EXPECT_THAT(replaced_change_dims.GetResults(),
ElementsAre((new_d0 * c1 + new_d1) + new_s0 * c2));
ElementsAre((new_d0 * c10 + new_d1) + new_s0 * c2));
}
TEST_F(SymbolicMapTest, ReplaceDimsAndSymbolsOnlyDims) {
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
int num_dims = 2;
SymbolicExpr s0 = CreateSymbolExpr(&ctx, 0, num_dims);
SymbolicExpr s1 = CreateSymbolExpr(&ctx, 1, num_dims);
int num_symbols = 2;
SymbolicExpr c1 = ctx.CreateConstant(10);
SymbolicExpr c2 = ctx.CreateConstant(20);
SymbolicMap map =
SymbolicMap::Get(&ctx, num_dims, num_symbols, {d0 + s0, d1 * s1});
SymbolicMap replaced = map.ReplaceDimsAndSymbols(
/*dim_replacements=*/{c1, c2}, /*sym_replacements=*/{}, num_dims,
num_symbols);
EXPECT_THAT(replaced.GetResults(), ElementsAre(c1 + s0, c2 * s1));
SymbolicMap replaced = sample_map.ReplaceDimsAndSymbols(
/*dim_replacements=*/{c10, c2}, /*sym_replacements=*/{},
sample_map.GetNumDims(), sample_map.GetNumSymbols());
EXPECT_THAT(replaced.GetResults(), ElementsAre(c10 + s0, c2 * s1));
}
TEST_F(SymbolicMapTest, ReplaceDimsAndSymbolsOnlySymbols) {
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
int num_dims = 2;
SymbolicExpr s0 = CreateSymbolExpr(&ctx, 0, num_dims);
SymbolicExpr s1 = CreateSymbolExpr(&ctx, 1, num_dims);
int num_symbols = 2;
SymbolicExpr c1 = ctx.CreateConstant(10);
SymbolicExpr c2 = ctx.CreateConstant(20);
SymbolicMap map =
SymbolicMap::Get(&ctx, num_dims, num_symbols, {d0 + s0, d1 * s1});
SymbolicMap replaced = map.ReplaceDimsAndSymbols(
/*dim_replacements=*/{}, /*sym_replacements=*/{c1, c2}, num_dims,
num_symbols);
EXPECT_THAT(replaced.GetResults(), ElementsAre(d0 + c1, d1 * c2));
SymbolicMap replaced = sample_map.ReplaceDimsAndSymbols(
/*dim_replacements=*/{}, /*sym_replacements=*/{c10, c2},
sample_map.GetNumDims(), sample_map.GetNumSymbols());
EXPECT_THAT(replaced.GetResults(), ElementsAre(d0 + c10, d1 * c2));
}
TEST_F(SymbolicMapTest, Compose) {
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
// Composition without Symbols
SymbolicMap map1_no_symbols = SymbolicMap::Get(&ctx, 1, 0, {d0 * 2});
SymbolicMap map2_no_symbols = SymbolicMap::Get(&ctx, 1, 0, {d0 + 5});
@ -243,9 +221,6 @@ TEST_F(SymbolicMapTest, Compose) {
}
TEST_F(SymbolicMapTest, Replace) {
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
SymbolicExpr c2 = ctx.CreateConstant(2);
SymbolicExpr c5 = ctx.CreateConstant(5);
SymbolicExpr expr0 = (d0 + c2) * d1;
@ -264,15 +239,15 @@ TEST_F(SymbolicMapTest, Replace) {
}
TEST_F(SymbolicMapTest, GetUnusedVariables) {
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
[[maybe_unused]] SymbolicExpr d2 = CreateDimExpr(&ctx, 2);
// d2 is unused.
SymbolicExpr s0 = CreateSymbolExpr(&ctx, /*symbol_id=*/0, /*num_dims=*/3);
SymbolicExpr s1 = CreateSymbolExpr(&ctx, /*symbol_id=*/1, /*num_dims=*/3);
SymbolicExpr c2 = ctx.CreateConstant(2);
[[maybe_unused]] SymbolicExpr s0_3dims =
CreateSymbolExpr(&ctx, /*symbol_id=*/0, /*num_dims=*/3);
SymbolicExpr s1_3dims =
CreateSymbolExpr(&ctx, /*symbol_id=*/1, /*num_dims=*/3);
// Map with used and unused dims and symbols.
SymbolicMap map = SymbolicMap::Get(&ctx, 3, 2, {d0 + s1, d1 * c2});
SymbolicMap map = SymbolicMap::Get(&ctx, 3, 2, {d0 + s1_3dims, d1 * c2});
llvm::SmallBitVector unused_dims = GetUnusedDimensionsBitVector(map);
EXPECT_EQ(unused_dims.size(), 3);