mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Allow empty dimension list in SymbolicMap::ReplaceDimsAndSymbols
I originally assumed the caller was always providing a full list of replacements but IndexingMap have some uses where the dim_replacement list is empty, resulting in a CHECK-fail. So, I'm allowing the user to provide either dim or symbol empty lists to ReplaceDimsAndSymbols. In that case, the dims/symbols won't be replaced. PiperOrigin-RevId: 826138814
This commit is contained in:
parent
175774337e
commit
7e7b1a3015
18
third_party/xla/xla/hlo/analysis/symbolic_map.cc
vendored
18
third_party/xla/xla/hlo/analysis/symbolic_map.cc
vendored
|
|
@ -127,13 +127,27 @@ SymbolicMap SymbolicMap::ReplaceDimsAndSymbols(
|
|||
absl::Span<const SymbolicExpr> dim_replacements,
|
||||
absl::Span<const SymbolicExpr> sym_replacements, int64_t num_result_dims,
|
||||
int64_t num_result_symbols) const {
|
||||
CHECK_EQ(dim_replacements.size(), num_dimensions_);
|
||||
CHECK_EQ(sym_replacements.size(), num_symbols_);
|
||||
CHECK(dim_replacements.empty() || dim_replacements.size() == num_dimensions_);
|
||||
CHECK(sym_replacements.empty() || sym_replacements.size() == num_symbols_);
|
||||
|
||||
llvm::SmallVector<SymbolicExpr> all_replacements;
|
||||
all_replacements.reserve(num_dimensions_ + num_symbols_);
|
||||
|
||||
if (!dim_replacements.empty()) {
|
||||
absl::c_copy(dim_replacements, std::back_inserter(all_replacements));
|
||||
} else {
|
||||
for (int i = 0; i < num_dimensions_; ++i) {
|
||||
all_replacements.push_back(ctx_->CreateVariable(i));
|
||||
}
|
||||
}
|
||||
|
||||
if (!sym_replacements.empty()) {
|
||||
absl::c_copy(sym_replacements, std::back_inserter(all_replacements));
|
||||
} else {
|
||||
for (int i = 0; i < num_symbols_; ++i) {
|
||||
all_replacements.push_back(ctx_->CreateVariable(num_dimensions_ + i));
|
||||
}
|
||||
}
|
||||
|
||||
llvm::SmallVector<SymbolicExpr> new_exprs;
|
||||
new_exprs.reserve(exprs_.size());
|
||||
|
|
|
|||
|
|
@ -150,6 +150,42 @@ TEST_F(SymbolicMapTest, ReplaceDimsAndSymbols) {
|
|||
ElementsAre((new_d0 * c1 + new_d1) + new_s0 * c2));
|
||||
}
|
||||
|
||||
TEST_F(SymbolicMapTest, ReplaceDimsAndSymbolsOnlyDims) {
|
||||
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
|
||||
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
|
||||
int num_dims = 2;
|
||||
SymbolicExpr s0 = CreateSymbolExpr(&ctx, 0, num_dims);
|
||||
SymbolicExpr s1 = CreateSymbolExpr(&ctx, 1, num_dims);
|
||||
int num_symbols = 2;
|
||||
SymbolicExpr c1 = ctx.CreateConstant(10);
|
||||
SymbolicExpr c2 = ctx.CreateConstant(20);
|
||||
|
||||
SymbolicMap map =
|
||||
SymbolicMap::Get(&ctx, num_dims, num_symbols, {d0 + s0, d1 * s1});
|
||||
SymbolicMap replaced = map.ReplaceDimsAndSymbols(
|
||||
/*dim_replacements=*/{c1, c2}, /*sym_replacements=*/{}, num_dims,
|
||||
num_symbols);
|
||||
EXPECT_THAT(replaced.GetResults(), ElementsAre(c1 + s0, c2 * s1));
|
||||
}
|
||||
|
||||
TEST_F(SymbolicMapTest, ReplaceDimsAndSymbolsOnlySymbols) {
|
||||
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
|
||||
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
|
||||
int num_dims = 2;
|
||||
SymbolicExpr s0 = CreateSymbolExpr(&ctx, 0, num_dims);
|
||||
SymbolicExpr s1 = CreateSymbolExpr(&ctx, 1, num_dims);
|
||||
int num_symbols = 2;
|
||||
SymbolicExpr c1 = ctx.CreateConstant(10);
|
||||
SymbolicExpr c2 = ctx.CreateConstant(20);
|
||||
|
||||
SymbolicMap map =
|
||||
SymbolicMap::Get(&ctx, num_dims, num_symbols, {d0 + s0, d1 * s1});
|
||||
SymbolicMap replaced = map.ReplaceDimsAndSymbols(
|
||||
/*dim_replacements=*/{}, /*sym_replacements=*/{c1, c2}, num_dims,
|
||||
num_symbols);
|
||||
EXPECT_THAT(replaced.GetResults(), ElementsAre(d0 + c1, d1 * c2));
|
||||
}
|
||||
|
||||
TEST_F(SymbolicMapTest, Compose) {
|
||||
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
|
||||
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
|
||||
|
|
@ -207,8 +243,8 @@ TEST_F(SymbolicMapTest, Compose) {
|
|||
}
|
||||
|
||||
TEST_F(SymbolicMapTest, Replace) {
|
||||
SymbolicExpr d0 = ctx.CreateVariable(0);
|
||||
SymbolicExpr d1 = ctx.CreateVariable(1);
|
||||
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
|
||||
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
|
||||
SymbolicExpr c2 = ctx.CreateConstant(2);
|
||||
SymbolicExpr c5 = ctx.CreateConstant(5);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user