PiperOrigin-RevId: 822586242
Dimitar (Mitko) Asenov 2025-10-22 07:46:05 -07:00 committed by TensorFlower Gardener
parent 94d00be0e6
commit bbea04967a
7 changed files with 153 additions and 44 deletions

View File

@@ -239,6 +239,28 @@ void HloComputation::ClearCalledComputations() {
CHECK(callee_computations_.empty());
}
void HloComputation::SetInstruction(HloInstruction* instruction,
InstructionType type) {
static_assert(alignof(HloInstruction) == kInstructionTypeMask + 1,
"HloInstruction should be aligned as a QWORD");
DCHECK(type != InstructionType::kUnset)
<< "Set instruction must be called with a valid type, not kUnset.";
DCHECK(instruction_type() == InstructionType::kUnset ||
instruction_type() == type)
<< "Unexpected instruction type. Current type is "
<< static_cast<int>(instruction_type()) << " and it cannot be reset to "
<< static_cast<int>(type);
// If `instruction` is nullptr, we need to preserve the existing type.
if (instruction == nullptr) {
type = instruction_type();
}
instruction_and_type_ =
reinterpret_cast<uintptr_t>(instruction) | static_cast<uintptr_t>(type);
}
HloInstruction* HloComputation::AddInstruction(
std::unique_ptr<HloInstruction> instruction, absl::string_view new_name) {
CHECK(instruction->opcode() != HloOpcode::kParameter)
@@ -1422,6 +1444,10 @@ HloComputation::CreateFromProto(
new HloComputation(proto.name(), parameter_count, &instructions, root,
/*preserve_instruction_ids=*/true));
computation->SetUniqueIdHelper(proto.id());
if (proto.is_fusion_computation()) {
computation->instruction_and_type_ =
static_cast<uintptr_t>(InstructionType::kFusion);
}
if (!proto.execution_thread().empty()) {
computation->SetExecutionThread(proto.execution_thread());
}
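For reference, the packing used by SetInstruction above can be reproduced in isolation. The following is a minimal, self-contained sketch (Node, NodeType, and TaggedNode are illustrative names, not XLA types): because the pointee is aligned to kTypeMask + 1 bytes, the low three bits of any valid pointer are zero and can carry a small enum next to the pointer, and a null pointer can still keep its previously stored type.

#include <cassert>
#include <cstdint>

// Illustrative stand-ins for InstructionType and HloInstruction.
enum class NodeType : uint8_t { kUnset = 0, kFusion = 1, kLast = kFusion };

constexpr uintptr_t kTypeMask = 0b111;  // the low three bits hold the type

// Aligning the pointee to kTypeMask + 1 (= 8) bytes guarantees that every
// valid Node* has three zero low bits, which the tag can reuse.
struct alignas(kTypeMask + 1) Node {
  int id = 0;
};

class TaggedNode {
 public:
  void Set(Node* node, NodeType type) {
    assert(type != NodeType::kUnset);
    // A null pointer preserves the previously stored type, mirroring
    // HloComputation::SetInstruction above.
    if (node == nullptr) type = this->type();
    bits_ = reinterpret_cast<uintptr_t>(node) | static_cast<uintptr_t>(type);
  }
  NodeType type() const { return static_cast<NodeType>(bits_ & kTypeMask); }
  Node* node() const { return reinterpret_cast<Node*>(bits_ & ~kTypeMask); }

 private:
  uintptr_t bits_ = 0;  // pointer in the high bits, type in the low bits
};

int main() {
  Node fusion_node{42};
  TaggedNode tagged;
  tagged.Set(&fusion_node, NodeType::kFusion);
  assert(tagged.type() == NodeType::kFusion);
  assert(tagged.node()->id == 42);
  // Clearing the pointer keeps the type, so the tag survives removal of the
  // pointee.
  tagged.Set(nullptr, NodeType::kFusion);
  assert(tagged.node() == nullptr && tagged.type() == NodeType::kFusion);
  return 0;
}

Packing the type into the same word as the pointer avoids an extra field and, as the nullptr branch shows, lets a computation remain marked as a fusion computation even after its fusion instruction is gone.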

View File

@@ -208,6 +208,30 @@ class HloComputation {
~HloComputation();
enum class InstructionType : uint8_t {
kUnset,
// This computation is a fusion computation. A fusion computation ordinarily
// also has a non-null instruction. However, if a fusion instruction
// is removed during compilation, the fusion computation becomes
// unreachable, and its instruction is set to null. We still need to regard
// such computations as fusion computations for HLO scheduling purposes.
kFusion,
// Last value, for range checking.
kLast = kFusion,
};
static_assert(static_cast<int>(InstructionType::kUnset) == 0,
"kUnset must be 0.");
InstructionType instruction_type() const {
return static_cast<InstructionType>(instruction_and_type_ &
kInstructionTypeMask);
}
HloInstruction* instruction() const {
DCHECK(instruction_type() <= InstructionType::kLast);
return reinterpret_cast<HloInstruction*>(instruction_and_type_ &
~kInstructionTypeMask);
}
// Add an instruction to the computation. The computation takes ownership of
// the instruction.
HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction,
@@ -789,30 +813,23 @@ class HloComputation {
bool HasSideEffect() const;
// Returns if this computation is a fusion computation.
// Do not use this method to determine if fusion_instruction_ != nullptr.
// Instead, directly do: FusionInstruction() != nullptr
bool IsFusionComputation() const {
// TODO(b/418034360): There should be at most one fusion instruction calling
// a fusion computation. Assert this and fix all related tests.
return !caller_instructions(HloOpcode::kFusion).empty();
return instruction_type() == InstructionType::kFusion;
}
// Returns if this computation is the entry computation of the module.
bool IsEntryComputation() const;
// Returns if this computation is dead. A computation is dead if it is not
// the entry computation and it is not called by any other computation.
bool IsDeadComputation() const {
return !IsEntryComputation() && caller_computations().empty();
}
// Returns the owning fusion instruction, or nullptr if this is not a fusion
// computation. Note that this is just one of the fusion instructions that
// calls this computation; there may be more than one caller.
//
// TODO(b/418034360): There should be at most one fusion instruction calling
// a fusion computation. Assert this and fix all related tests.
// computation.
HloInstruction* FusionInstruction() const {
auto callers = caller_instructions(HloOpcode::kFusion);
return callers.empty() ? nullptr : callers.front();
return instruction_type() == InstructionType::kFusion ? instruction()
: nullptr;
}
void SetFusionInstruction(HloInstruction* fusion_instruction) {
SetInstruction(fusion_instruction, InstructionType::kFusion);
}
// Returns if this computation is an async computation.
@@ -1005,6 +1022,8 @@ class HloComputation {
absl::Status RemoveInstructionImpl(HloInstruction* instruction,
bool ignore_safety_check);
void SetInstruction(HloInstruction* instruction, InstructionType type);
// Private, because only HloModule should be able to set the parent.
// We maintain the invariant that a computation has a parent() if and only if
// the computation has been added to a module. Accordingly, the only way to
@@ -1039,6 +1058,10 @@ class HloComputation {
// Module containing this computation.
HloModule* parent_ = nullptr;
// Contains an HloInstruction* and its type, with the type packed into the
// least significant three bits.
uintptr_t instruction_and_type_ = 0;
// Contains an HloInstruction* or an absl::flat_hash_map<HloInstruction*,
// /*count=*/int> in the high bits and a CallersType in the least significant
// bit.
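A hedged illustration of how the accessors above compose (IllustrateFusionBackPointer is a hypothetical helper, assuming XLA headers and an already-built module in which `fusion` is a kFusion instruction calling `computation`):

// Hypothetical helper; not part of the XLA API.
void IllustrateFusionBackPointer(xla::HloComputation* computation,
                                 xla::HloInstruction* fusion) {
  computation->SetFusionInstruction(fusion);
  CHECK(computation->IsFusionComputation());
  CHECK_EQ(computation->FusionInstruction(), fusion);

  // Clearing the pointer keeps the kFusion type, so a fusion computation that
  // becomes unreachable is still reported as a fusion computation.
  computation->SetFusionInstruction(nullptr);
  CHECK(computation->IsFusionComputation());
  CHECK(computation->FusionInstruction() == nullptr);
}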

View File

@@ -222,7 +222,8 @@ static constexpr uintptr_t kInstructionTypeMask = 0b111;
// HLO is pure (mostly). It has no concept of mutable state. Instead, data
// values are produced by one HLO and flow into consumers across dependency
// edges.
class HloInstruction {
// Alignment must be specified explicitly because the default alignment on
// 32-bit ARM platforms is too small to leave the tag bits free.
class alignas(kInstructionTypeMask + 1) HloInstruction {
public:
// A fusion node computes the same value a call to its fusion computation
// would compute. However, the choice of fusion kind dictates codegen
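The explicit alignas matters because a class's default alignment is only that of its most-aligned member; on a 32-bit target that can be 4 bytes, leaving just two zero low bits instead of the three the tag needs. A minimal sketch of the guarantee, using illustrative names:

#include <cstdint>

constexpr uintptr_t kMask = 0b111;

// Without the explicit alignas, alignof(Tagged) could be 4 on an ILP32
// target, which is not enough room for a three-bit tag.
class alignas(kMask + 1) Tagged {
  void* ptr_ = nullptr;
  uint32_t flags_ = 0;
};

// Mirrors the static_assert added in HloComputation::SetInstruction above.
static_assert(alignof(Tagged) == kMask + 1,
              "Tagged must be 8-byte aligned for pointer tagging.");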

View File

@@ -1979,6 +1979,10 @@ HloCallableInstruction::CloneAndAppendInstructionIntoCalledComputation(
auto* new_computation = CHECK_NOTNULL(instruction_to_append->GetModule())
->AddEmbeddedComputation(builder.Build());
AppendComputation(new_computation);
if (opcode() == HloOpcode::kFusion) {
new_computation->SetFusionInstruction(this);
}
clone = called_computation_root();
} else {
// When add_output is false, instruction_to_append is necessarily an
@@ -2209,6 +2213,31 @@ HloFusionInstruction::HloFusionInstruction(
: HloCallableInstruction(HloOpcode::kFusion, shape, operands,
fusion_computation, prefix),
fusion_kind_(fusion_kind) {
fusion_computation->SetFusionInstruction(this);
}
HloFusionInstruction::~HloFusionInstruction() {
ClearFusionComputationInstruction();
}
void HloFusionInstruction::ClearFusionComputationInstruction() {
// Each fusion calls a single computation, but we use called_computations()
// instead of fused_instructions_computation(), because the order in which
// things get destructed can vary; the fusion computation's back-pointer may
// already be null, which violates a check in
// fused_instructions_computation.
for (HloComputation* computation : called_computations()) {
// Some passes that rewrite fusions may reassign a fusion computation to a
// different fusion instruction as this instruction gets destructed.
if (computation->FusionInstruction() == this) {
computation->SetFusionInstruction(nullptr);
}
}
}
void HloFusionInstruction::ClearCalledComputations() {
ClearFusionComputationInstruction();
HloInstruction::ClearCalledComputations();
}
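The constructor/destructor pair above maintains a non-owning back pointer from the fusion computation to its fusion instruction. The same pattern in isolation (Caller and Callee are illustrative stand-ins, not XLA types): the owner registers itself on construction and clears the back pointer on destruction only if it still points at this instance, so a computation that a rewrite pass has already reassigned is left untouched.

#include <cassert>

// Illustrative stand-ins for HloFusionInstruction / HloComputation.
class Caller;

class Callee {
 public:
  void set_caller(Caller* caller) { caller_ = caller; }
  Caller* caller() const { return caller_; }

 private:
  Caller* caller_ = nullptr;  // non-owning back pointer
};

class Caller {
 public:
  explicit Caller(Callee* callee) : callee_(callee) {
    callee_->set_caller(this);  // register the back pointer
  }
  ~Caller() {
    // Only clear the back pointer if it still refers to this caller; a
    // rewrite pass may have reassigned the callee to a different caller.
    if (callee_ != nullptr && callee_->caller() == this) {
      callee_->set_caller(nullptr);
    }
  }

 private:
  Callee* callee_;
};

int main() {
  Callee computation;
  {
    Caller fusion(&computation);
    assert(computation.caller() == &fusion);
  }  // destructor clears the back pointer
  assert(computation.caller() == nullptr);
  return 0;
}

The identity check in the destructor is what allows passes that rewrite fusions to hand the computation to a new fusion instruction before the old one is destroyed.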
HloInstruction*
@@ -2464,7 +2493,11 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
HloComputation* HloFusionInstruction::fused_instructions_computation() const {
CHECK_EQ(called_computations().size(), 1);
return called_computations().front();
auto* fused_instructions_computation = called_computations().front();
CHECK(fused_instructions_computation->IsFusionComputation())
<< "Computation " << fused_instructions_computation->name()
<< " is not a fusion kind";
return fused_instructions_computation;
}
HloInstruction* HloFusionInstruction::fused_expression_root() const {

View File

@@ -1495,6 +1495,14 @@ class HloFusionInstruction : public HloCallableInstruction {
HloComputation* fusion_computation,
absl::string_view prefix = "");
~HloFusionInstruction() override;
void ClearCalledComputations() override;
// When a fusion instruction is being destructed, clear the back pointer of
// its fusion computation, to avoid referencing freed memory.
void ClearFusionComputationInstruction();
// Clones the given instruction_to_append and inserts the clone into this
// callable instruction.
HloInstruction* CloneAndAppendInstructionIntoCalledComputation(

View File

@@ -15,7 +15,6 @@ limitations under the License.
#include "xla/hlo/ir/hlo_schedule.h"
#include <algorithm>
#include <cstdint>
#include <ostream>
#include <queue>
@@ -343,25 +342,7 @@ absl::Status HloSchedule::Verify() const {
sequence_num_by_execution_threads) {
std::vector<HloComputation*> nonfusion_computations =
module_->MakeNonfusionComputations({thread_name});
// TODO(dasenov): Replace with std::erase_if after XLA uses C++20.
auto remove_it = std::remove_if(nonfusion_computations.begin(),
nonfusion_computations.end(),
[](const HloComputation* computation) {
return computation->IsDeadComputation();
});
nonfusion_computations.erase(remove_it, nonfusion_computations.end());
// It's possible to have more sequences than non_fusion_computations.
// This is because in some cases computations that have schedules are
// actually dead. The important thing to check is that each live non-fusion
// computation has a sequence.
//
// TODO(b/418034360): Consider strengthening this check to equality. That
// would require cleaning up dead computations and/or recomputing the
// schedule in a number of tests. In its present state (using less or equal)
// this check is subsumed by the next one.
TF_RET_CHECK(nonfusion_computations.size() <= sequence_size)
TF_RET_CHECK(nonfusion_computations.size() == sequence_size)
<< "For thread " << thread_name << ", schedule has " << sequence_size
<< " sequences, but module has " << nonfusion_computations.size()
<< " non-fusion computations for thread " << thread_name;

View File

@@ -88,6 +88,28 @@ absl::StatusOr<bool> FlattenNode(const CallGraphNode& node) {
return changed;
}
// Annotates flattened computations with their calling fusion instructions.
absl::Status AnnotateNode(const CallGraphNode& node) {
for (auto& callsite : node.callsites()) {
HloInstruction* instruction = callsite.instruction();
if (instruction->opcode() == HloOpcode::kFusion) {
for (HloComputation* computation : instruction->called_computations()) {
computation->SetFusionInstruction(instruction);
}
}
}
// Correctly handle dead code: if a fusion computation is no longer used, it
// should not have a fusion instruction set.
if (node.callers().empty() &&
node.computation()->FusionInstruction() != nullptr) {
node.computation()->SetFusionInstruction(nullptr);
}
return absl::OkStatus();
}
} // namespace
absl::StatusOr<bool> FlattenCallGraph::Run(
@@ -95,14 +117,29 @@ absl::StatusOr<bool> FlattenCallGraph::Run(
const absl::flat_hash_set<absl::string_view>& execution_threads) {
XLA_VLOG_LINES(3, "Before flatten call graph:\n" + module->ToString());
// Flatten original call graph.
std::unique_ptr<CallGraph> call_graph =
CallGraph::Build(module, execution_threads);
TF_ASSIGN_OR_RETURN(bool changed,
call_graph->VisitNodesWithReturn(FlattenNode));
bool changed = false;
{ // Flatten original call graph.
std::unique_ptr<CallGraph> call_graph =
CallGraph::Build(module, execution_threads);
TF_ASSIGN_OR_RETURN(bool flattened,
call_graph->VisitNodesWithReturn(FlattenNode));
changed |= flattened;
}
if (!changed) {
return false;
}
// TODO(b/418034360): Remove this step once the fusion instruction is
// automatically maintained.
{ // Annotate flattened computations with callee types.
std::unique_ptr<CallGraph> call_graph =
CallGraph::Build(module, execution_threads);
TF_RETURN_IF_ERROR(call_graph->VisitNodes(AnnotateNode));
}
XLA_VLOG_LINES(3, "After flatten call graph:\n" + module->ToString());
return changed;
return true;
}
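A hedged usage sketch, assuming the pass is invoked directly rather than through an HloPassPipeline (RunFlattenPass is a hypothetical helper; includes and build dependencies are elided because header paths vary across XLA versions):

// Hypothetical call site; relies only on the Run signature shown above.
absl::Status RunFlattenPass(xla::HloModule* module) {
  xla::FlattenCallGraph flatten;
  // An empty execution-thread set follows the usual "all threads" convention.
  TF_ASSIGN_OR_RETURN(bool changed,
                      flatten.Run(module, /*execution_threads=*/{}));
  // After flattening, each fusion computation has a single fusion caller and
  // FusionInstruction() returns it (or nullptr if the computation is dead).
  (void)changed;
  return absl::OkStatus();
}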
} // namespace xla