Allow empty dimension list in SymbolicMap::ReplaceDimsAndSymbols

I originally assumed the caller was always providing a full list of replacements but IndexingMap have some uses where the dim_replacement list is empty, resulting in a CHECK-fail.

So, I'm allowing the user to provide either dim or symbol empty lists to ReplaceDimsAndSymbols. In that case, the dims/symbols won't be replaced.

PiperOrigin-RevId: 826138814
This commit is contained in:
A. Unique TensorFlower 2025-10-30 12:24:13 -07:00 committed by TensorFlower Gardener
parent 175774337e
commit 7e7b1a3015
2 changed files with 56 additions and 6 deletions

View File

@ -127,13 +127,27 @@ SymbolicMap SymbolicMap::ReplaceDimsAndSymbols(
absl::Span<const SymbolicExpr> dim_replacements, absl::Span<const SymbolicExpr> dim_replacements,
absl::Span<const SymbolicExpr> sym_replacements, int64_t num_result_dims, absl::Span<const SymbolicExpr> sym_replacements, int64_t num_result_dims,
int64_t num_result_symbols) const { int64_t num_result_symbols) const {
CHECK_EQ(dim_replacements.size(), num_dimensions_); CHECK(dim_replacements.empty() || dim_replacements.size() == num_dimensions_);
CHECK_EQ(sym_replacements.size(), num_symbols_); CHECK(sym_replacements.empty() || sym_replacements.size() == num_symbols_);
llvm::SmallVector<SymbolicExpr> all_replacements; llvm::SmallVector<SymbolicExpr> all_replacements;
all_replacements.reserve(num_dimensions_ + num_symbols_); all_replacements.reserve(num_dimensions_ + num_symbols_);
absl::c_copy(dim_replacements, std::back_inserter(all_replacements));
absl::c_copy(sym_replacements, std::back_inserter(all_replacements)); if (!dim_replacements.empty()) {
absl::c_copy(dim_replacements, std::back_inserter(all_replacements));
} else {
for (int i = 0; i < num_dimensions_; ++i) {
all_replacements.push_back(ctx_->CreateVariable(i));
}
}
if (!sym_replacements.empty()) {
absl::c_copy(sym_replacements, std::back_inserter(all_replacements));
} else {
for (int i = 0; i < num_symbols_; ++i) {
all_replacements.push_back(ctx_->CreateVariable(num_dimensions_ + i));
}
}
llvm::SmallVector<SymbolicExpr> new_exprs; llvm::SmallVector<SymbolicExpr> new_exprs;
new_exprs.reserve(exprs_.size()); new_exprs.reserve(exprs_.size());

View File

@ -150,6 +150,42 @@ TEST_F(SymbolicMapTest, ReplaceDimsAndSymbols) {
ElementsAre((new_d0 * c1 + new_d1) + new_s0 * c2)); ElementsAre((new_d0 * c1 + new_d1) + new_s0 * c2));
} }
TEST_F(SymbolicMapTest, ReplaceDimsAndSymbolsOnlyDims) {
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
int num_dims = 2;
SymbolicExpr s0 = CreateSymbolExpr(&ctx, 0, num_dims);
SymbolicExpr s1 = CreateSymbolExpr(&ctx, 1, num_dims);
int num_symbols = 2;
SymbolicExpr c1 = ctx.CreateConstant(10);
SymbolicExpr c2 = ctx.CreateConstant(20);
SymbolicMap map =
SymbolicMap::Get(&ctx, num_dims, num_symbols, {d0 + s0, d1 * s1});
SymbolicMap replaced = map.ReplaceDimsAndSymbols(
/*dim_replacements=*/{c1, c2}, /*sym_replacements=*/{}, num_dims,
num_symbols);
EXPECT_THAT(replaced.GetResults(), ElementsAre(c1 + s0, c2 * s1));
}
TEST_F(SymbolicMapTest, ReplaceDimsAndSymbolsOnlySymbols) {
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
int num_dims = 2;
SymbolicExpr s0 = CreateSymbolExpr(&ctx, 0, num_dims);
SymbolicExpr s1 = CreateSymbolExpr(&ctx, 1, num_dims);
int num_symbols = 2;
SymbolicExpr c1 = ctx.CreateConstant(10);
SymbolicExpr c2 = ctx.CreateConstant(20);
SymbolicMap map =
SymbolicMap::Get(&ctx, num_dims, num_symbols, {d0 + s0, d1 * s1});
SymbolicMap replaced = map.ReplaceDimsAndSymbols(
/*dim_replacements=*/{}, /*sym_replacements=*/{c1, c2}, num_dims,
num_symbols);
EXPECT_THAT(replaced.GetResults(), ElementsAre(d0 + c1, d1 * c2));
}
TEST_F(SymbolicMapTest, Compose) { TEST_F(SymbolicMapTest, Compose) {
SymbolicExpr d0 = CreateDimExpr(&ctx, 0); SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
SymbolicExpr d1 = CreateDimExpr(&ctx, 1); SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
@ -207,8 +243,8 @@ TEST_F(SymbolicMapTest, Compose) {
} }
TEST_F(SymbolicMapTest, Replace) { TEST_F(SymbolicMapTest, Replace) {
SymbolicExpr d0 = ctx.CreateVariable(0); SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
SymbolicExpr d1 = ctx.CreateVariable(1); SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
SymbolicExpr c2 = ctx.CreateConstant(2); SymbolicExpr c2 = ctx.CreateConstant(2);
SymbolicExpr c5 = ctx.CreateConstant(5); SymbolicExpr c5 = ctx.CreateConstant(5);