mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[XLA:HLO Diff] Consolidate the diff enums.
The `DiffCode` and `DiffType` enums are now consolidated into a single `DiffType` enum. This change improves the code's clarity and reduces redundancies. PiperOrigin-RevId: 803112761
This commit is contained in:
parent
227679f7fd
commit
98a6b660cd
|
|
@ -45,33 +45,33 @@ bool IsChangedInstruction(const HloInstructionNode* left_node,
|
|||
void DiffResult::AddUnchangedInstruction(const HloInstruction* left,
|
||||
const HloInstruction* right) {
|
||||
unchanged_instructions[left] = right;
|
||||
left_diff_codes[left] = DiffCode::kUnchanged;
|
||||
right_diff_codes[right] = DiffCode::kUnchanged;
|
||||
left_diff_codes[left] = DiffType::kUnchanged;
|
||||
right_diff_codes[right] = DiffType::kUnchanged;
|
||||
}
|
||||
|
||||
void DiffResult::AddChangedInstruction(const HloInstruction* left,
|
||||
const HloInstruction* right) {
|
||||
changed_instructions[left] = right;
|
||||
left_diff_codes[left] = DiffCode::kChanged;
|
||||
right_diff_codes[right] = DiffCode::kChanged;
|
||||
left_diff_codes[left] = DiffType::kChanged;
|
||||
right_diff_codes[right] = DiffType::kChanged;
|
||||
}
|
||||
|
||||
void DiffResult::AddMovedInstruction(const HloInstruction* left,
|
||||
const HloInstruction* right) {
|
||||
moved_instructions[left] = right;
|
||||
left_diff_codes[left] = DiffCode::kUnchanged;
|
||||
right_diff_codes[right] = DiffCode::kUnchanged;
|
||||
left_diff_codes[left] = DiffType::kUnchanged;
|
||||
right_diff_codes[right] = DiffType::kUnchanged;
|
||||
}
|
||||
|
||||
void DiffResult::AddUnmatchedInstruction(const HloInstruction* left,
|
||||
const HloInstruction* right) {
|
||||
if (left != nullptr) {
|
||||
left_module_unmatched_instructions.insert(left);
|
||||
left_diff_codes[left] = DiffCode::kUnmatched;
|
||||
left_diff_codes[left] = DiffType::kUnmatched;
|
||||
}
|
||||
if (right != nullptr) {
|
||||
right_module_unmatched_instructions.insert(right);
|
||||
right_diff_codes[right] = DiffCode::kUnmatched;
|
||||
right_diff_codes[right] = DiffType::kUnmatched;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -33,11 +33,7 @@
|
|||
namespace xla {
|
||||
namespace hlo_diff {
|
||||
|
||||
enum DiffCode : uint8_t {
|
||||
kUnchanged,
|
||||
kChanged,
|
||||
kUnmatched,
|
||||
};
|
||||
enum DiffType : uint8_t { kUnchanged, kChanged, kUnmatched, kMoved };
|
||||
|
||||
// Result of diff'ng the left and right HLO modules. Contains the matched and
|
||||
// unmatched instructions in the two modules.
|
||||
|
|
@ -56,8 +52,8 @@ struct DiffResult {
|
|||
right_module_unmatched_instructions;
|
||||
|
||||
// Diff codes.
|
||||
absl::flat_hash_map<const HloInstruction*, DiffCode> left_diff_codes;
|
||||
absl::flat_hash_map<const HloInstruction*, DiffCode> right_diff_codes;
|
||||
absl::flat_hash_map<const HloInstruction*, DiffType> left_diff_codes;
|
||||
absl::flat_hash_map<const HloInstruction*, DiffType> right_diff_codes;
|
||||
|
||||
// Debug info.
|
||||
absl::flat_hash_map<std::pair<const HloInstruction*, const HloInstruction*>,
|
||||
|
|
|
|||
|
|
@ -106,19 +106,19 @@ ENTRY entry {
|
|||
EXPECT_THAT(diff_result->left_diff_codes,
|
||||
UnorderedElementsAre(
|
||||
Pair(Pointee(Property(&HloInstruction::name, "parameter.0")),
|
||||
DiffCode::kChanged),
|
||||
DiffType::kChanged),
|
||||
Pair(Pointee(Property(&HloInstruction::name, "parameter.1")),
|
||||
DiffCode::kUnchanged),
|
||||
DiffType::kUnchanged),
|
||||
Pair(Pointee(Property(&HloInstruction::name, "add.0")),
|
||||
DiffCode::kChanged)));
|
||||
DiffType::kChanged)));
|
||||
EXPECT_THAT(diff_result->right_diff_codes,
|
||||
UnorderedElementsAre(
|
||||
Pair(Pointee(Property(&HloInstruction::name, "parameter.0")),
|
||||
DiffCode::kChanged),
|
||||
DiffType::kChanged),
|
||||
Pair(Pointee(Property(&HloInstruction::name, "parameter.1")),
|
||||
DiffCode::kUnchanged),
|
||||
DiffType::kUnchanged),
|
||||
Pair(Pointee(Property(&HloInstruction::name, "add.0")),
|
||||
DiffCode::kChanged)));
|
||||
DiffType::kChanged)));
|
||||
}
|
||||
|
||||
TEST_F(HloDiffTest, MatchedDifferentFingerprintMarkAsChanged) {
|
||||
|
|
@ -190,19 +190,19 @@ ENTRY entry {
|
|||
EXPECT_THAT(diff_result->left_diff_codes,
|
||||
UnorderedElementsAre(
|
||||
Pair(Pointee(Property(&HloInstruction::name, "parameter.0")),
|
||||
DiffCode::kChanged),
|
||||
DiffType::kChanged),
|
||||
Pair(Pointee(Property(&HloInstruction::name, "parameter.1")),
|
||||
DiffCode::kChanged),
|
||||
DiffType::kChanged),
|
||||
Pair(Pointee(Property(&HloInstruction::name, "add.0")),
|
||||
DiffCode::kChanged)));
|
||||
DiffType::kChanged)));
|
||||
EXPECT_THAT(diff_result->right_diff_codes,
|
||||
UnorderedElementsAre(
|
||||
Pair(Pointee(Property(&HloInstruction::name, "parameter.0")),
|
||||
DiffCode::kChanged),
|
||||
DiffType::kChanged),
|
||||
Pair(Pointee(Property(&HloInstruction::name, "parameter.1")),
|
||||
DiffCode::kChanged),
|
||||
DiffType::kChanged),
|
||||
Pair(Pointee(Property(&HloInstruction::name, "add.0")),
|
||||
DiffCode::kChanged)));
|
||||
DiffType::kChanged)));
|
||||
}
|
||||
|
||||
TEST_F(HloDiffTest, UnmatchedInstructionsMarkAsUnmatched) {
|
||||
|
|
@ -264,19 +264,19 @@ ENTRY entry {
|
|||
EXPECT_THAT(diff_result->left_diff_codes,
|
||||
UnorderedElementsAre(
|
||||
Pair(Pointee(Property(&HloInstruction::name, "parameter.0")),
|
||||
DiffCode::kUnmatched),
|
||||
DiffType::kUnmatched),
|
||||
Pair(Pointee(Property(&HloInstruction::name, "parameter.1")),
|
||||
DiffCode::kUnmatched),
|
||||
DiffType::kUnmatched),
|
||||
Pair(Pointee(Property(&HloInstruction::name, "add.0")),
|
||||
DiffCode::kUnchanged)));
|
||||
DiffType::kUnchanged)));
|
||||
EXPECT_THAT(diff_result->right_diff_codes,
|
||||
UnorderedElementsAre(
|
||||
Pair(Pointee(Property(&HloInstruction::name, "parameter.0")),
|
||||
DiffCode::kUnmatched),
|
||||
DiffType::kUnmatched),
|
||||
Pair(Pointee(Property(&HloInstruction::name, "parameter.1")),
|
||||
DiffCode::kUnmatched),
|
||||
DiffType::kUnmatched),
|
||||
Pair(Pointee(Property(&HloInstruction::name, "add.0")),
|
||||
DiffCode::kUnchanged)));
|
||||
DiffType::kUnchanged)));
|
||||
}
|
||||
|
||||
TEST_F(HloDiffTest, ShortFormConstantsMatched) {
|
||||
|
|
@ -354,20 +354,20 @@ ENTRY entry {
|
|||
diff_result->left_diff_codes,
|
||||
UnorderedElementsAre(
|
||||
Pair(Pointee(Property(&HloInstruction::name, "constant.2958")),
|
||||
DiffCode::kUnchanged),
|
||||
DiffType::kUnchanged),
|
||||
Pair(Pointee(Property(&HloInstruction::name, "parameter.0")),
|
||||
DiffCode::kUnchanged),
|
||||
DiffType::kUnchanged),
|
||||
Pair(Pointee(Property(&HloInstruction::name, "add.0")),
|
||||
DiffCode::kUnchanged)));
|
||||
DiffType::kUnchanged)));
|
||||
EXPECT_THAT(
|
||||
diff_result->right_diff_codes,
|
||||
UnorderedElementsAre(
|
||||
Pair(Pointee(Property(&HloInstruction::name, "constant.2958")),
|
||||
DiffCode::kUnchanged),
|
||||
DiffType::kUnchanged),
|
||||
Pair(Pointee(Property(&HloInstruction::name, "parameter.0")),
|
||||
DiffCode::kUnchanged),
|
||||
DiffType::kUnchanged),
|
||||
Pair(Pointee(Property(&HloInstruction::name, "add.0")),
|
||||
DiffCode::kUnchanged)));
|
||||
DiffType::kUnchanged)));
|
||||
}
|
||||
|
||||
TEST_F(HloDiffTest, DiffResultToAndFromProtoWorks) {
|
||||
|
|
|
|||
|
|
@ -123,16 +123,16 @@ struct DiffFingerprint {
|
|||
|
||||
DiffFingerprint ComputationDiffFingerprint(
|
||||
const xla::HloComputation* computation,
|
||||
const absl::flat_hash_map<const HloInstruction*, DiffCode>& diff_codes) {
|
||||
const absl::flat_hash_map<const HloInstruction*, DiffType>& diff_codes) {
|
||||
absl::flat_hash_map<const HloInstruction*, uint64_t> subgraph_fingerprint;
|
||||
bool all_unchanged = true;
|
||||
for (auto* instruction : computation->MakeInstructionPostOrder()) {
|
||||
uint64_t fp = static_cast<uint64_t>(instruction->opcode());
|
||||
uint64_t diff_type_fp = DiffCode::kUnchanged;
|
||||
uint64_t diff_type_fp = DiffType::kUnchanged;
|
||||
if (auto it = diff_codes.find(instruction); it != diff_codes.end()) {
|
||||
diff_type_fp = it->second;
|
||||
}
|
||||
all_unchanged = all_unchanged && (diff_type_fp == DiffCode::kUnchanged);
|
||||
all_unchanged = all_unchanged && (diff_type_fp == DiffType::kUnchanged);
|
||||
fp = tsl::FingerprintCat64(fp, diff_type_fp);
|
||||
for (const HloInstruction* operand : instruction->operands()) {
|
||||
fp = tsl::FingerprintCat64(fp, subgraph_fingerprint.at(operand));
|
||||
|
|
@ -262,7 +262,7 @@ std::vector<ComputationDiffPattern> FindComputationDiffPatterns(
|
|||
// Summarizes all computations in the given graph.
|
||||
ComputationSummaryMap SummarizeAllComputationsInGraph(
|
||||
const HloModule& module, const InstructionBimap& mapping,
|
||||
const absl::flat_hash_map<const HloInstruction*, DiffCode>& diff_codes,
|
||||
const absl::flat_hash_map<const HloInstruction*, DiffType>& diff_codes,
|
||||
DiffSide side) {
|
||||
ComputationSummaryMap result;
|
||||
for (const HloComputation* computation : module.computations()) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user