Add hashing support for SymbolicMap

This change implements AbslHashValue and llvm::hash_value for xla::gpu::SymbolicMap.

This is a prerequisite for correctly implementing AbslHashValue for xla::IndexingMap after its internal migration to use SymbolicMap. Specifically, it needs be used in IndexingMap::AbslHashValue.

PiperOrigin-RevId: 826038011
This commit is contained in:
A. Unique TensorFlower 2025-10-30 08:14:08 -07:00 committed by TensorFlower Gardener
parent bec8916f32
commit fd85062199
5 changed files with 80 additions and 1 deletions

View File

@ -748,6 +748,7 @@ xla_cc_test(
deps = [
":symbolic_expr",
":symbolic_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_googletest//:gtest_main",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
@ -776,8 +777,9 @@ xla_cc_test(
name = "symbolic_expr_test",
srcs = ["symbolic_expr_test.cc"],
deps = [
":indexing_test_utils",
":symbolic_expr",
"//xla/hlo/analysis:indexing_test_utils",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_googletest//:gtest_main",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",

View File

@ -147,6 +147,11 @@ inline ::llvm::hash_code hash_value(SymbolicExpr expr) {
return ::llvm::hash_value(expr.GetImpl());
}
template <typename H>
H AbslHashValue(H h, const SymbolicExpr& expr) {
return H::combine(std::move(h), hash_value(expr));
}
class SymbolicExprContext {
public:
explicit SymbolicExprContext(mlir::MLIRContext* mlir_context);

View File

@ -22,6 +22,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/container/flat_hash_set.h"
#include "llvm/ADT/DenseMap.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/MLIRContext.h"
@ -320,5 +321,36 @@ TEST_F(SymbolicExprTest, Walk) {
"v1", "((v0 + 42) * v1)"));
}
TEST_F(SymbolicExprTest, Hashing) {
absl::flat_hash_set<SymbolicExpr> set;
SymbolicExpr c42_1 = ctx.CreateConstant(42);
SymbolicExpr c42_2 = ctx.CreateConstant(42);
SymbolicExpr c3 = ctx.CreateConstant(3);
set.insert(c42_1);
set.insert(c42_2);
set.insert(c3);
EXPECT_EQ(set.size(), 2);
SymbolicExpr v0_1 = ctx.CreateVariable(0);
SymbolicExpr v0_2 = ctx.CreateVariable(0);
SymbolicExpr v1 = ctx.CreateVariable(1);
set.insert(v0_1);
set.insert(v0_2);
set.insert(v1);
EXPECT_EQ(set.size(), 4);
SymbolicExpr add1 = v0_1 + c42_1;
SymbolicExpr add2 = v0_2 + c42_2;
SymbolicExpr add3 = v1 + c3;
set.insert(add1);
set.insert(add2);
set.insert(add3);
EXPECT_EQ(set.size(), 6);
}
} // namespace
} // namespace xla

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <string>
#include "absl/types/span.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "xla/hlo/analysis/symbolic_expr.h"
@ -95,6 +96,18 @@ class SymbolicMap {
bool operator==(const SymbolicMap& other) const;
bool operator!=(const SymbolicMap& other) const { return !(*this == other); }
template <typename H>
friend H AbslHashValue(H h, const SymbolicMap& map) {
return H::combine(std::move(h), map.num_dimensions_, map.num_symbols_,
map.exprs_);
}
friend ::llvm::hash_code hash_value(const SymbolicMap& map) {
return ::llvm::hash_combine(
map.num_dimensions_, map.num_symbols_,
::llvm::hash_combine_range(map.exprs_.begin(), map.exprs_.end()));
}
template <typename Sink>
friend void AbslStringify(Sink& sink, const SymbolicMap& map) {
sink.Append(map.ToString());
@ -128,6 +141,11 @@ SymbolicMap CompressDims(const SymbolicMap& map,
SymbolicMap CompressSymbols(const SymbolicMap& map,
const llvm::SmallBitVector& unused_symbols);
template <typename H>
H AbslHashValue(H h, const llvm::SmallVector<SymbolicExpr>& vec) {
return H::combine(std::move(h), absl::MakeSpan(vec));
}
} // namespace xla
#endif // XLA_HLO_ANALYSIS_SYMBOLIC_MAP_H_

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/container/flat_hash_set.h"
#include "llvm/ADT/SmallBitVector.h"
#include "mlir/IR/MLIRContext.h"
#include "xla/hlo/analysis/symbolic_expr.h"
@ -328,6 +329,27 @@ TEST_F(SymbolicMapTest, CompressSymbols) {
"Attempting to compress a used symbol: 2");
}
TEST_F(SymbolicMapTest, Hashing) {
absl::flat_hash_set<SymbolicMap> set;
SymbolicExpr d0 = ctx.CreateVariable(0);
SymbolicExpr d1 = ctx.CreateVariable(1);
SymbolicExpr s0 = ctx.CreateVariable(2);
SymbolicExpr c42 = ctx.CreateConstant(42);
SymbolicExpr c99 = ctx.CreateConstant(99);
SymbolicMap map1 = SymbolicMap::Get(&ctx, 2, 1, {d0 + s0, d1 * c42});
SymbolicMap map2 = SymbolicMap::Get(&ctx, 2, 1, {d0 + s0, d1 * c42});
SymbolicMap map3 = SymbolicMap::Get(&ctx, 2, 1, {d0 + s0, d1 * c99});
set.insert(map1);
EXPECT_EQ(set.size(), 1);
set.insert(map2);
EXPECT_EQ(set.size(), 1);
set.insert(map3);
EXPECT_EQ(set.size(), 2);
}
} // namespace
} // namespace gpu
} // namespace xla