[XLA:HLO Diff] Consolidate the diff enums.

The `DiffCode` and `DiffType` enums are now consolidated into a single `DiffType` enum.

This change improves the code's clarity and reduces redundancies.

PiperOrigin-RevId: 803112761
This commit is contained in:
Daniel Chen 2025-09-04 11:42:58 -07:00 committed by TensorFlower Gardener
parent 227679f7fd
commit 98a6b660cd
4 changed files with 39 additions and 43 deletions

View File

@ -45,33 +45,33 @@ bool IsChangedInstruction(const HloInstructionNode* left_node,
void DiffResult::AddUnchangedInstruction(const HloInstruction* left,
const HloInstruction* right) {
unchanged_instructions[left] = right;
left_diff_codes[left] = DiffCode::kUnchanged;
right_diff_codes[right] = DiffCode::kUnchanged;
left_diff_codes[left] = DiffType::kUnchanged;
right_diff_codes[right] = DiffType::kUnchanged;
}
void DiffResult::AddChangedInstruction(const HloInstruction* left,
const HloInstruction* right) {
changed_instructions[left] = right;
left_diff_codes[left] = DiffCode::kChanged;
right_diff_codes[right] = DiffCode::kChanged;
left_diff_codes[left] = DiffType::kChanged;
right_diff_codes[right] = DiffType::kChanged;
}
void DiffResult::AddMovedInstruction(const HloInstruction* left,
const HloInstruction* right) {
moved_instructions[left] = right;
left_diff_codes[left] = DiffCode::kUnchanged;
right_diff_codes[right] = DiffCode::kUnchanged;
left_diff_codes[left] = DiffType::kUnchanged;
right_diff_codes[right] = DiffType::kUnchanged;
}
void DiffResult::AddUnmatchedInstruction(const HloInstruction* left,
const HloInstruction* right) {
if (left != nullptr) {
left_module_unmatched_instructions.insert(left);
left_diff_codes[left] = DiffCode::kUnmatched;
left_diff_codes[left] = DiffType::kUnmatched;
}
if (right != nullptr) {
right_module_unmatched_instructions.insert(right);
right_diff_codes[right] = DiffCode::kUnmatched;
right_diff_codes[right] = DiffType::kUnmatched;
}
}

View File

@ -33,11 +33,7 @@
namespace xla {
namespace hlo_diff {
enum DiffCode : uint8_t {
kUnchanged,
kChanged,
kUnmatched,
};
enum DiffType : uint8_t { kUnchanged, kChanged, kUnmatched, kMoved };
// Result of diff'ng the left and right HLO modules. Contains the matched and
// unmatched instructions in the two modules.
@ -56,8 +52,8 @@ struct DiffResult {
right_module_unmatched_instructions;
// Diff codes.
absl::flat_hash_map<const HloInstruction*, DiffCode> left_diff_codes;
absl::flat_hash_map<const HloInstruction*, DiffCode> right_diff_codes;
absl::flat_hash_map<const HloInstruction*, DiffType> left_diff_codes;
absl::flat_hash_map<const HloInstruction*, DiffType> right_diff_codes;
// Debug info.
absl::flat_hash_map<std::pair<const HloInstruction*, const HloInstruction*>,

View File

@ -106,19 +106,19 @@ ENTRY entry {
EXPECT_THAT(diff_result->left_diff_codes,
UnorderedElementsAre(
Pair(Pointee(Property(&HloInstruction::name, "parameter.0")),
DiffCode::kChanged),
DiffType::kChanged),
Pair(Pointee(Property(&HloInstruction::name, "parameter.1")),
DiffCode::kUnchanged),
DiffType::kUnchanged),
Pair(Pointee(Property(&HloInstruction::name, "add.0")),
DiffCode::kChanged)));
DiffType::kChanged)));
EXPECT_THAT(diff_result->right_diff_codes,
UnorderedElementsAre(
Pair(Pointee(Property(&HloInstruction::name, "parameter.0")),
DiffCode::kChanged),
DiffType::kChanged),
Pair(Pointee(Property(&HloInstruction::name, "parameter.1")),
DiffCode::kUnchanged),
DiffType::kUnchanged),
Pair(Pointee(Property(&HloInstruction::name, "add.0")),
DiffCode::kChanged)));
DiffType::kChanged)));
}
TEST_F(HloDiffTest, MatchedDifferentFingerprintMarkAsChanged) {
@ -190,19 +190,19 @@ ENTRY entry {
EXPECT_THAT(diff_result->left_diff_codes,
UnorderedElementsAre(
Pair(Pointee(Property(&HloInstruction::name, "parameter.0")),
DiffCode::kChanged),
DiffType::kChanged),
Pair(Pointee(Property(&HloInstruction::name, "parameter.1")),
DiffCode::kChanged),
DiffType::kChanged),
Pair(Pointee(Property(&HloInstruction::name, "add.0")),
DiffCode::kChanged)));
DiffType::kChanged)));
EXPECT_THAT(diff_result->right_diff_codes,
UnorderedElementsAre(
Pair(Pointee(Property(&HloInstruction::name, "parameter.0")),
DiffCode::kChanged),
DiffType::kChanged),
Pair(Pointee(Property(&HloInstruction::name, "parameter.1")),
DiffCode::kChanged),
DiffType::kChanged),
Pair(Pointee(Property(&HloInstruction::name, "add.0")),
DiffCode::kChanged)));
DiffType::kChanged)));
}
TEST_F(HloDiffTest, UnmatchedInstructionsMarkAsUnmatched) {
@ -264,19 +264,19 @@ ENTRY entry {
EXPECT_THAT(diff_result->left_diff_codes,
UnorderedElementsAre(
Pair(Pointee(Property(&HloInstruction::name, "parameter.0")),
DiffCode::kUnmatched),
DiffType::kUnmatched),
Pair(Pointee(Property(&HloInstruction::name, "parameter.1")),
DiffCode::kUnmatched),
DiffType::kUnmatched),
Pair(Pointee(Property(&HloInstruction::name, "add.0")),
DiffCode::kUnchanged)));
DiffType::kUnchanged)));
EXPECT_THAT(diff_result->right_diff_codes,
UnorderedElementsAre(
Pair(Pointee(Property(&HloInstruction::name, "parameter.0")),
DiffCode::kUnmatched),
DiffType::kUnmatched),
Pair(Pointee(Property(&HloInstruction::name, "parameter.1")),
DiffCode::kUnmatched),
DiffType::kUnmatched),
Pair(Pointee(Property(&HloInstruction::name, "add.0")),
DiffCode::kUnchanged)));
DiffType::kUnchanged)));
}
TEST_F(HloDiffTest, ShortFormConstantsMatched) {
@ -354,20 +354,20 @@ ENTRY entry {
diff_result->left_diff_codes,
UnorderedElementsAre(
Pair(Pointee(Property(&HloInstruction::name, "constant.2958")),
DiffCode::kUnchanged),
DiffType::kUnchanged),
Pair(Pointee(Property(&HloInstruction::name, "parameter.0")),
DiffCode::kUnchanged),
DiffType::kUnchanged),
Pair(Pointee(Property(&HloInstruction::name, "add.0")),
DiffCode::kUnchanged)));
DiffType::kUnchanged)));
EXPECT_THAT(
diff_result->right_diff_codes,
UnorderedElementsAre(
Pair(Pointee(Property(&HloInstruction::name, "constant.2958")),
DiffCode::kUnchanged),
DiffType::kUnchanged),
Pair(Pointee(Property(&HloInstruction::name, "parameter.0")),
DiffCode::kUnchanged),
DiffType::kUnchanged),
Pair(Pointee(Property(&HloInstruction::name, "add.0")),
DiffCode::kUnchanged)));
DiffType::kUnchanged)));
}
TEST_F(HloDiffTest, DiffResultToAndFromProtoWorks) {

View File

@ -123,16 +123,16 @@ struct DiffFingerprint {
DiffFingerprint ComputationDiffFingerprint(
const xla::HloComputation* computation,
const absl::flat_hash_map<const HloInstruction*, DiffCode>& diff_codes) {
const absl::flat_hash_map<const HloInstruction*, DiffType>& diff_codes) {
absl::flat_hash_map<const HloInstruction*, uint64_t> subgraph_fingerprint;
bool all_unchanged = true;
for (auto* instruction : computation->MakeInstructionPostOrder()) {
uint64_t fp = static_cast<uint64_t>(instruction->opcode());
uint64_t diff_type_fp = DiffCode::kUnchanged;
uint64_t diff_type_fp = DiffType::kUnchanged;
if (auto it = diff_codes.find(instruction); it != diff_codes.end()) {
diff_type_fp = it->second;
}
all_unchanged = all_unchanged && (diff_type_fp == DiffCode::kUnchanged);
all_unchanged = all_unchanged && (diff_type_fp == DiffType::kUnchanged);
fp = tsl::FingerprintCat64(fp, diff_type_fp);
for (const HloInstruction* operand : instruction->operands()) {
fp = tsl::FingerprintCat64(fp, subgraph_fingerprint.at(operand));
@ -262,7 +262,7 @@ std::vector<ComputationDiffPattern> FindComputationDiffPatterns(
// Summarizes all computations in the given graph.
ComputationSummaryMap SummarizeAllComputationsInGraph(
const HloModule& module, const InstructionBimap& mapping,
const absl::flat_hash_map<const HloInstruction*, DiffCode>& diff_codes,
const absl::flat_hash_map<const HloInstruction*, DiffType>& diff_codes,
DiffSide side) {
ComputationSummaryMap result;
for (const HloComputation* computation : module.computations()) {