mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Allow empty dimension list in SymbolicMap::ReplaceDimsAndSymbols
I originally assumed the caller was always providing a full list of replacements but IndexingMap have some uses where the dim_replacement list is empty, resulting in a CHECK-fail. So, I'm allowing the user to provide either dim or symbol empty lists to ReplaceDimsAndSymbols. In that case, the dims/symbols won't be replaced. PiperOrigin-RevId: 826138814
This commit is contained in:
parent
175774337e
commit
7e7b1a3015
22
third_party/xla/xla/hlo/analysis/symbolic_map.cc
vendored
22
third_party/xla/xla/hlo/analysis/symbolic_map.cc
vendored
|
|
@ -127,13 +127,27 @@ SymbolicMap SymbolicMap::ReplaceDimsAndSymbols(
|
||||||
absl::Span<const SymbolicExpr> dim_replacements,
|
absl::Span<const SymbolicExpr> dim_replacements,
|
||||||
absl::Span<const SymbolicExpr> sym_replacements, int64_t num_result_dims,
|
absl::Span<const SymbolicExpr> sym_replacements, int64_t num_result_dims,
|
||||||
int64_t num_result_symbols) const {
|
int64_t num_result_symbols) const {
|
||||||
CHECK_EQ(dim_replacements.size(), num_dimensions_);
|
CHECK(dim_replacements.empty() || dim_replacements.size() == num_dimensions_);
|
||||||
CHECK_EQ(sym_replacements.size(), num_symbols_);
|
CHECK(sym_replacements.empty() || sym_replacements.size() == num_symbols_);
|
||||||
|
|
||||||
llvm::SmallVector<SymbolicExpr> all_replacements;
|
llvm::SmallVector<SymbolicExpr> all_replacements;
|
||||||
all_replacements.reserve(num_dimensions_ + num_symbols_);
|
all_replacements.reserve(num_dimensions_ + num_symbols_);
|
||||||
absl::c_copy(dim_replacements, std::back_inserter(all_replacements));
|
|
||||||
absl::c_copy(sym_replacements, std::back_inserter(all_replacements));
|
if (!dim_replacements.empty()) {
|
||||||
|
absl::c_copy(dim_replacements, std::back_inserter(all_replacements));
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < num_dimensions_; ++i) {
|
||||||
|
all_replacements.push_back(ctx_->CreateVariable(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!sym_replacements.empty()) {
|
||||||
|
absl::c_copy(sym_replacements, std::back_inserter(all_replacements));
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < num_symbols_; ++i) {
|
||||||
|
all_replacements.push_back(ctx_->CreateVariable(num_dimensions_ + i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
llvm::SmallVector<SymbolicExpr> new_exprs;
|
llvm::SmallVector<SymbolicExpr> new_exprs;
|
||||||
new_exprs.reserve(exprs_.size());
|
new_exprs.reserve(exprs_.size());
|
||||||
|
|
|
||||||
|
|
@ -150,6 +150,42 @@ TEST_F(SymbolicMapTest, ReplaceDimsAndSymbols) {
|
||||||
ElementsAre((new_d0 * c1 + new_d1) + new_s0 * c2));
|
ElementsAre((new_d0 * c1 + new_d1) + new_s0 * c2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(SymbolicMapTest, ReplaceDimsAndSymbolsOnlyDims) {
|
||||||
|
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
|
||||||
|
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
|
||||||
|
int num_dims = 2;
|
||||||
|
SymbolicExpr s0 = CreateSymbolExpr(&ctx, 0, num_dims);
|
||||||
|
SymbolicExpr s1 = CreateSymbolExpr(&ctx, 1, num_dims);
|
||||||
|
int num_symbols = 2;
|
||||||
|
SymbolicExpr c1 = ctx.CreateConstant(10);
|
||||||
|
SymbolicExpr c2 = ctx.CreateConstant(20);
|
||||||
|
|
||||||
|
SymbolicMap map =
|
||||||
|
SymbolicMap::Get(&ctx, num_dims, num_symbols, {d0 + s0, d1 * s1});
|
||||||
|
SymbolicMap replaced = map.ReplaceDimsAndSymbols(
|
||||||
|
/*dim_replacements=*/{c1, c2}, /*sym_replacements=*/{}, num_dims,
|
||||||
|
num_symbols);
|
||||||
|
EXPECT_THAT(replaced.GetResults(), ElementsAre(c1 + s0, c2 * s1));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SymbolicMapTest, ReplaceDimsAndSymbolsOnlySymbols) {
|
||||||
|
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
|
||||||
|
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
|
||||||
|
int num_dims = 2;
|
||||||
|
SymbolicExpr s0 = CreateSymbolExpr(&ctx, 0, num_dims);
|
||||||
|
SymbolicExpr s1 = CreateSymbolExpr(&ctx, 1, num_dims);
|
||||||
|
int num_symbols = 2;
|
||||||
|
SymbolicExpr c1 = ctx.CreateConstant(10);
|
||||||
|
SymbolicExpr c2 = ctx.CreateConstant(20);
|
||||||
|
|
||||||
|
SymbolicMap map =
|
||||||
|
SymbolicMap::Get(&ctx, num_dims, num_symbols, {d0 + s0, d1 * s1});
|
||||||
|
SymbolicMap replaced = map.ReplaceDimsAndSymbols(
|
||||||
|
/*dim_replacements=*/{}, /*sym_replacements=*/{c1, c2}, num_dims,
|
||||||
|
num_symbols);
|
||||||
|
EXPECT_THAT(replaced.GetResults(), ElementsAre(d0 + c1, d1 * c2));
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(SymbolicMapTest, Compose) {
|
TEST_F(SymbolicMapTest, Compose) {
|
||||||
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
|
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
|
||||||
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
|
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
|
||||||
|
|
@ -207,8 +243,8 @@ TEST_F(SymbolicMapTest, Compose) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SymbolicMapTest, Replace) {
|
TEST_F(SymbolicMapTest, Replace) {
|
||||||
SymbolicExpr d0 = ctx.CreateVariable(0);
|
SymbolicExpr d0 = CreateDimExpr(&ctx, 0);
|
||||||
SymbolicExpr d1 = ctx.CreateVariable(1);
|
SymbolicExpr d1 = CreateDimExpr(&ctx, 1);
|
||||||
SymbolicExpr c2 = ctx.CreateConstant(2);
|
SymbolicExpr c2 = ctx.CreateConstant(2);
|
||||||
SymbolicExpr c5 = ctx.CreateConstant(5);
|
SymbolicExpr c5 = ctx.CreateConstant(5);
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user