mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Add hashing support for SymbolicMap
This change implements AbslHashValue and llvm::hash_value for xla::gpu::SymbolicMap. This is a prerequisite for correctly implementing AbslHashValue for xla::IndexingMap after its internal migration to use SymbolicMap. Specifically, it needs be used in IndexingMap::AbslHashValue. PiperOrigin-RevId: 826038011
This commit is contained in:
parent
bec8916f32
commit
fd85062199
4
third_party/xla/xla/hlo/analysis/BUILD
vendored
4
third_party/xla/xla/hlo/analysis/BUILD
vendored
|
|
@ -748,6 +748,7 @@ xla_cc_test(
|
|||
deps = [
|
||||
":symbolic_expr",
|
||||
":symbolic_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
|
|
@ -776,8 +777,9 @@ xla_cc_test(
|
|||
name = "symbolic_expr_test",
|
||||
srcs = ["symbolic_expr_test.cc"],
|
||||
deps = [
|
||||
":indexing_test_utils",
|
||||
":symbolic_expr",
|
||||
"//xla/hlo/analysis:indexing_test_utils",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
|
|
|
|||
|
|
@ -147,6 +147,11 @@ inline ::llvm::hash_code hash_value(SymbolicExpr expr) {
|
|||
return ::llvm::hash_value(expr.GetImpl());
|
||||
}
|
||||
|
||||
template <typename H>
|
||||
H AbslHashValue(H h, const SymbolicExpr& expr) {
|
||||
return H::combine(std::move(h), hash_value(expr));
|
||||
}
|
||||
|
||||
class SymbolicExprContext {
|
||||
public:
|
||||
explicit SymbolicExprContext(mlir::MLIRContext* mlir_context);
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
|
@ -320,5 +321,36 @@ TEST_F(SymbolicExprTest, Walk) {
|
|||
"v1", "((v0 + 42) * v1)"));
|
||||
}
|
||||
|
||||
TEST_F(SymbolicExprTest, Hashing) {
|
||||
absl::flat_hash_set<SymbolicExpr> set;
|
||||
|
||||
SymbolicExpr c42_1 = ctx.CreateConstant(42);
|
||||
SymbolicExpr c42_2 = ctx.CreateConstant(42);
|
||||
SymbolicExpr c3 = ctx.CreateConstant(3);
|
||||
|
||||
set.insert(c42_1);
|
||||
set.insert(c42_2);
|
||||
set.insert(c3);
|
||||
EXPECT_EQ(set.size(), 2);
|
||||
|
||||
SymbolicExpr v0_1 = ctx.CreateVariable(0);
|
||||
SymbolicExpr v0_2 = ctx.CreateVariable(0);
|
||||
SymbolicExpr v1 = ctx.CreateVariable(1);
|
||||
|
||||
set.insert(v0_1);
|
||||
set.insert(v0_2);
|
||||
set.insert(v1);
|
||||
EXPECT_EQ(set.size(), 4);
|
||||
|
||||
SymbolicExpr add1 = v0_1 + c42_1;
|
||||
SymbolicExpr add2 = v0_2 + c42_2;
|
||||
SymbolicExpr add3 = v1 + c3;
|
||||
|
||||
set.insert(add1);
|
||||
set.insert(add2);
|
||||
set.insert(add3);
|
||||
EXPECT_EQ(set.size(), 6);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
|
|
|||
18
third_party/xla/xla/hlo/analysis/symbolic_map.h
vendored
18
third_party/xla/xla/hlo/analysis/symbolic_map.h
vendored
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||
#include <string>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "xla/hlo/analysis/symbolic_expr.h"
|
||||
|
|
@ -95,6 +96,18 @@ class SymbolicMap {
|
|||
bool operator==(const SymbolicMap& other) const;
|
||||
bool operator!=(const SymbolicMap& other) const { return !(*this == other); }
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const SymbolicMap& map) {
|
||||
return H::combine(std::move(h), map.num_dimensions_, map.num_symbols_,
|
||||
map.exprs_);
|
||||
}
|
||||
|
||||
friend ::llvm::hash_code hash_value(const SymbolicMap& map) {
|
||||
return ::llvm::hash_combine(
|
||||
map.num_dimensions_, map.num_symbols_,
|
||||
::llvm::hash_combine_range(map.exprs_.begin(), map.exprs_.end()));
|
||||
}
|
||||
|
||||
template <typename Sink>
|
||||
friend void AbslStringify(Sink& sink, const SymbolicMap& map) {
|
||||
sink.Append(map.ToString());
|
||||
|
|
@ -128,6 +141,11 @@ SymbolicMap CompressDims(const SymbolicMap& map,
|
|||
SymbolicMap CompressSymbols(const SymbolicMap& map,
|
||||
const llvm::SmallBitVector& unused_symbols);
|
||||
|
||||
template <typename H>
|
||||
H AbslHashValue(H h, const llvm::SmallVector<SymbolicExpr>& vec) {
|
||||
return H::combine(std::move(h), absl::MakeSpan(vec));
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // XLA_HLO_ANALYSIS_SYMBOLIC_MAP_H_
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "llvm/ADT/SmallBitVector.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "xla/hlo/analysis/symbolic_expr.h"
|
||||
|
|
@ -328,6 +329,27 @@ TEST_F(SymbolicMapTest, CompressSymbols) {
|
|||
"Attempting to compress a used symbol: 2");
|
||||
}
|
||||
|
||||
TEST_F(SymbolicMapTest, Hashing) {
|
||||
absl::flat_hash_set<SymbolicMap> set;
|
||||
|
||||
SymbolicExpr d0 = ctx.CreateVariable(0);
|
||||
SymbolicExpr d1 = ctx.CreateVariable(1);
|
||||
SymbolicExpr s0 = ctx.CreateVariable(2);
|
||||
SymbolicExpr c42 = ctx.CreateConstant(42);
|
||||
SymbolicExpr c99 = ctx.CreateConstant(99);
|
||||
|
||||
SymbolicMap map1 = SymbolicMap::Get(&ctx, 2, 1, {d0 + s0, d1 * c42});
|
||||
SymbolicMap map2 = SymbolicMap::Get(&ctx, 2, 1, {d0 + s0, d1 * c42});
|
||||
SymbolicMap map3 = SymbolicMap::Get(&ctx, 2, 1, {d0 + s0, d1 * c99});
|
||||
|
||||
set.insert(map1);
|
||||
EXPECT_EQ(set.size(), 1);
|
||||
set.insert(map2);
|
||||
EXPECT_EQ(set.size(), 1);
|
||||
set.insert(map3);
|
||||
EXPECT_EQ(set.size(), 2);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user