parent 2abafe5c4d
commit c28d80ae66

third_party/xla/xla/hlo/ir/hlo_computation.cc (26 changes)

@@ -239,28 +239,6 @@ void HloComputation::ClearCalledComputations() {
   CHECK(callee_computations_.empty());
 }
 
-void HloComputation::SetInstruction(HloInstruction* instruction,
-                                    InstructionType type) {
-  static_assert(alignof(HloInstruction) == kInstructionTypeMask + 1,
-                "HloInstruction should be aligned as a QWORD");
-
-  DCHECK(type != InstructionType::kUnset)
-      << "Set instruction must be called with a valid type, not kUnset.";
-  DCHECK(instruction_type() == InstructionType::kUnset ||
-         instruction_type() == type)
-      << "Unexpected instruction type. Current type is "
-      << static_cast<int>(instruction_type()) << " and it cannot be reset to "
-      << static_cast<int>(type);
-
-  // If `instruction` is nullptr, we need to preserve the existing type.
-  if (instruction == nullptr) {
-    type = instruction_type();
-  }
-
-  instruction_and_type_ =
-      reinterpret_cast<uintptr_t>(instruction) | static_cast<uintptr_t>(type);
-}
-
 HloInstruction* HloComputation::AddInstruction(
     std::unique_ptr<HloInstruction> instruction, absl::string_view new_name) {
   CHECK(instruction->opcode() != HloOpcode::kParameter)
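
For context: the removed SetInstruction packed an HloInstruction* and a small type tag into the single uintptr_t field instruction_and_type_, relying on HloInstruction's 8-byte alignment to keep the low three bits free. A minimal self-contained sketch of that tagged-pointer technique (illustrative names, not XLA's):

#include <cassert>
#include <cstdint>

constexpr uintptr_t kTagMask = 0b111;  // the low 3 bits carry the tag

// alignas guarantees the low 3 pointer bits are zero, even on 32-bit ARM.
struct alignas(kTagMask + 1) Node {
  int payload;
};

uintptr_t Pack(Node* ptr, uint8_t tag) {
  assert(tag <= kTagMask);  // the tag must fit in the spare bits
  return reinterpret_cast<uintptr_t>(ptr) | tag;
}

Node* UnpackPtr(uintptr_t v) { return reinterpret_cast<Node*>(v & ~kTagMask); }
uint8_t UnpackTag(uintptr_t v) { return static_cast<uint8_t>(v & kTagMask); }
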
@@ -1409,10 +1387,6 @@ HloComputation::CreateFromProto(
   auto computation = absl::WrapUnique(new HloComputation(
       proto.name(), parameter_count, &instructions, root, /*from_proto=*/true));
   computation->SetUniqueIdHelper(proto.id());
-  if (proto.is_fusion_computation()) {
-    computation->instruction_and_type_ =
-        static_cast<uintptr_t>(InstructionType::kFusion);
-  }
   if (!proto.execution_thread().empty()) {
     computation->SetExecutionThread(proto.execution_thread());
   }

third_party/xla/xla/hlo/ir/hlo_computation.h (57 changes)

@@ -208,30 +208,6 @@ class HloComputation {
 
   ~HloComputation();
 
-  enum class InstructionType : uint8_t {
-    kUnset,
-    // This computation is a fusion computation. A fusion computation ordinarily
-    // also has a non-null instruction. However, if a fusion instruction
-    // is removed during compilation, the fusion computation becomes
-    // unreachable, and its instruction is set to null. We still need to regard
-    // such computations as fusion computations for HLO scheduling purposes.
-    kFusion,
-    // Last Value for range checking.
-    kLast = kFusion,
-  };
-  static_assert(static_cast<int>(InstructionType::kUnset) == 0,
-                "kUnset must be 0.");
-
-  InstructionType instruction_type() const {
-    return static_cast<InstructionType>(instruction_and_type_ &
-                                        kInstructionTypeMask);
-  }
-
-  HloInstruction* instruction() const {
-    DCHECK(instruction_type() <= InstructionType::kLast);
-    return reinterpret_cast<HloInstruction*>(instruction_and_type_ &
-                                             ~kInstructionTypeMask);
-  }
   // Add an instruction to the computation. The computation takes ownership of
   // the instruction.
   HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction,
@@ -805,23 +781,30 @@ class HloComputation {
   bool HasSideEffect() const;
 
   // Returns if this computation is a fusion computation.
-  // Do not use this method to determine if fusion_instruction_ != nullptr.
-  // Instead, directly do: FusionInstruction() != nullptr
   bool IsFusionComputation() const {
-    return instruction_type() == InstructionType::kFusion;
+    // TODO(b/418034360): There should be at most one fusion instruction calling
+    // a fusion computation. Assert this and fix all related tests.
+    return !caller_instructions(HloOpcode::kFusion).empty();
   }
 
   // Returns if this computation is the entry computation of the module.
   bool IsEntryComputation() const;
 
-  // Returns the owning fusion instruction, or nullptr if this is not a fusion
-  // computation.
-  HloInstruction* FusionInstruction() const {
-    return instruction_type() == InstructionType::kFusion ? instruction()
-                                                          : nullptr;
+  // Returns if this computation is dead. A computation is dead if it is not
+  // the entry computation and it is not called by any other computation.
+  bool IsDeadComputation() const {
+    return !IsEntryComputation() && caller_computations().empty();
   }
-  void SetFusionInstruction(HloInstruction* fusion_instruction) {
-    SetInstruction(fusion_instruction, InstructionType::kFusion);
+
+  // Returns the owning fusion instruction, or nullptr if this is not a fusion
+  // computation. Note that this is just one of the fusion instructions that
+  // call this computation; there may be more than one caller.
+  //
+  // TODO(b/418034360): There should be at most one fusion instruction calling
+  // a fusion computation. Assert this and fix all related tests.
+  HloInstruction* FusionInstruction() const {
+    auto callers = caller_instructions(HloOpcode::kFusion);
+    return callers.empty() ? nullptr : callers.front();
   }
 
   // Returns if this computation is an async computation.
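
The replacement accessors derive fusion status from the computation's recorded callers rather than a stored back-pointer, so the two can no longer fall out of sync. A hedged usage sketch (assuming a live HloComputation* computation; caller_instructions is the same accessor the new code uses):

// Derived, not stored: true iff at least one kFusion instruction calls us.
if (computation->IsFusionComputation()) {
  // One of possibly several fusion callers; see TODO(b/418034360).
  HloInstruction* fusion = computation->FusionInstruction();
  CHECK_NE(fusion, nullptr);  // non-empty callers implies a non-null result
}
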
@@ -1013,8 +996,6 @@ class HloComputation {
   absl::Status RemoveInstructionImpl(HloInstruction* instruction,
                                      bool ignore_safety_check);
 
-  void SetInstruction(HloInstruction* instruction, InstructionType type);
-
   // Private, because only HloModule should be able to set the parent.
   // We maintain the invariant that a computation has a parent() if and only if
   // the computation has been added to a module. Accordingly, the only way to
@@ -1049,10 +1030,6 @@ class HloComputation {
   // Module containing this computation.
   HloModule* parent_ = nullptr;
 
-  // Contains HloInstruction* and its type.
-  // The respective type in the least significant three bits.
-  uintptr_t instruction_and_type_ = 0;
-
   // Contains an HloInstruction* or an absl::flat_hash_map<HloInstruction*,
   // /*count=*/int> in the high bits and a CallersType in the least significant
   // bit.

third_party/xla/xla/hlo/ir/hlo_instruction.h (3 changes)

@@ -222,8 +222,7 @@ static constexpr uintptr_t kInstructionTypeMask = 0b111;
 // HLO is pure (mostly). It has no concept of mutable state. Instead, data
 // values are produced by one HLO and flow into consumers across dependency
 // edges.
-// Alignment must be explicitly specified due to ARM 32 platforms.
-class alignas(kInstructionTypeMask + 1) HloInstruction {
+class HloInstruction {
  public:
   // A fusion node computes the same value a call to its fusion computation
   // would compute. However, the choice of fusion kind dictates codegen
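
The alignas annotation existed only so instruction_and_type_ could borrow HloInstruction*'s low three bits: a 3-bit tag needs 8-byte alignment, which 64-bit ABIs provide naturally but 32-bit ARM does not. With the tagged field gone, natural alignment suffices. A small standalone illustration of the constraint (illustrative types, not XLA's):

#include <cstdint>

constexpr uintptr_t kTagMask = 0b111;

struct Plain { void* p; };                         // alignof == alignof(void*)
struct alignas(kTagMask + 1) Padded { void* p; };  // alignof == 8 everywhere

// Plain typically gives 8-byte alignment only on 64-bit targets (often 4 on
// 32-bit ARM, leaving two spare bits -- too few for a 3-bit tag). Padded
// guarantees enough spare bits on every platform.
static_assert(alignof(Padded) == kTagMask + 1, "need 8-byte alignment for tag");
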

third_party/xla/xla/hlo/ir/hlo_instructions.cc (35 changes)

@@ -1979,10 +1979,6 @@ HloCallableInstruction::CloneAndAppendInstructionIntoCalledComputation(
     auto* new_computation = CHECK_NOTNULL(instruction_to_append->GetModule())
                                 ->AddEmbeddedComputation(builder.Build());
     AppendComputation(new_computation);
-    if (opcode() == HloOpcode::kFusion) {
-      new_computation->SetFusionInstruction(this);
-    }
-
     clone = called_computation_root();
   } else {
     // When add_output is false, instruction_to_append is necessarily an
@@ -2213,31 +2209,6 @@ HloFusionInstruction::HloFusionInstruction(
     : HloCallableInstruction(HloOpcode::kFusion, shape, operands,
                              fusion_computation, prefix),
       fusion_kind_(fusion_kind) {
-  fusion_computation->SetFusionInstruction(this);
-}
-
-HloFusionInstruction::~HloFusionInstruction() {
-  ClearFusionComputationInstruction();
-}
-
-void HloFusionInstruction::ClearFusionComputationInstruction() {
-  // Each fusion calls a single computation, but we use called_computations()
-  // instead of fused_instructions_computation(), because the order in which
-  // things get destructed can vary; the fusion computation's back-pointer may
-  // already be null, which violates a check in
-  // fused_instructions_computation.
-  for (HloComputation* computation : called_computations()) {
-    // Some passes that rewrite fusions may reassign a fusion computation to a
-    // different fusion instruction as this instruction gets destructed.
-    if (computation->FusionInstruction() == this) {
-      computation->SetFusionInstruction(nullptr);
-    }
-  }
-}
-
-void HloFusionInstruction::ClearCalledComputations() {
-  ClearFusionComputationInstruction();
-  HloInstruction::ClearCalledComputations();
 }
 
 HloInstruction*
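
Since the computation no longer stores a back-pointer, there is nothing to register in the fusion constructor and nothing to null out during destruction; the destructor and both Clear* overrides existed solely to keep a soon-to-dangle pointer from being observed. A hedged sketch of fusion creation under the new scheme, using the existing HloInstruction::CreateFusion factory (shape, operands, and fusion_computation are assumed to be in scope):

// No SetFusionInstruction(this) call remains. Once the instruction is added
// to a computation, the usual caller bookkeeping records the edge, and the
// edge disappears when the instruction does.
std::unique_ptr<HloInstruction> fusion = HloInstruction::CreateFusion(
    shape, HloInstruction::FusionKind::kLoop, operands, fusion_computation);
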
@@ -2493,11 +2464,7 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
 
 HloComputation* HloFusionInstruction::fused_instructions_computation() const {
   CHECK_EQ(called_computations().size(), 1);
-  auto* fused_instructions_computation = called_computations().front();
-  CHECK(fused_instructions_computation->IsFusionComputation())
-      << "Computation " << fused_instructions_computation->name()
-      << " is not a fusion kind";
-  return fused_instructions_computation;
+  return called_computations().front();
 }
 
 HloInstruction* HloFusionInstruction::fused_expression_root() const {

third_party/xla/xla/hlo/ir/hlo_instructions.h

@@ -1495,14 +1495,6 @@ class HloFusionInstruction : public HloCallableInstruction {
                        HloComputation* fusion_computation,
                        absl::string_view prefix = "");
 
-  ~HloFusionInstruction() override;
-
-  void ClearCalledComputations() override;
-
-  // When a fusion instruction is being destructed, clear the back pointer of
-  // its fusion computation, to avoid referencing freed memory.
-  void ClearFusionComputationInstruction();
-
   // Clones the given instruction_to_append and inserts the clone into this
   // callable instruction.
   HloInstruction* CloneAndAppendInstructionIntoCalledComputation(

third_party/xla/xla/hlo/ir/hlo_schedule.cc (21 changes)

@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "xla/hlo/ir/hlo_schedule.h"
 
+#include <algorithm>
 #include <cstdint>
 #include <ostream>
 #include <queue>
@@ -323,7 +324,25 @@ absl::Status HloSchedule::Verify() const {
        sequence_num_by_execution_threads) {
     std::vector<HloComputation*> nonfusion_computations =
         module_->MakeNonfusionComputations({thread_name});
-    TF_RET_CHECK(nonfusion_computations.size() == sequence_size)
+
+    // TODO(dasenov): Replace with std::erase_if after XLA uses C++20.
+    auto remove_it = std::remove_if(nonfusion_computations.begin(),
+                                    nonfusion_computations.end(),
+                                    [](const HloComputation* computation) {
+                                      return computation->IsDeadComputation();
+                                    });
+    nonfusion_computations.erase(remove_it, nonfusion_computations.end());
+
+    // It's possible to have more sequences than non-fusion computations.
+    // This is because in some cases computations that have schedules are
+    // actually dead. The important thing to check is that each live non-fusion
+    // computation has a sequence.
+    //
+    // TODO(b/418034360): Consider strengthening this check to equality. That
+    // would require cleaning up dead computations and/or recomputing the
+    // schedule in a number of tests. In its present state (using less or equal)
+    // this check is subsumed by the next one.
+    TF_RET_CHECK(nonfusion_computations.size() <= sequence_size)
         << "For thread " << thread_name << ", schedule has " << sequence_size
         << " sequences, but module has " << nonfusion_computations.size()
         << " non-fusion computations for thread " << thread_name;
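
The filtering step above uses the C++17 erase-remove idiom that the TODO(dasenov) comment proposes replacing: std::remove_if only shifts the kept elements forward, and the following erase trims the tail. A self-contained equivalent of the pattern:

#include <algorithm>
#include <vector>

void DropNegatives(std::vector<int>& xs) {
  // Partition survivors to the front; remove_if does not shrink the vector.
  auto tail = std::remove_if(xs.begin(), xs.end(), [](int x) { return x < 0; });
  xs.erase(tail, xs.end());  // C++20: std::erase_if(xs, pred) does both steps
}
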

third_party/xla/xla/hlo/transforms/simplifiers/flatten_call_graph.cc

@@ -88,28 +88,6 @@ absl::StatusOr<bool> FlattenNode(const CallGraphNode& node) {
   return changed;
 }
 
-// Annotates flatten computations with callee instruction types.
-absl::Status AnnotateNode(const CallGraphNode& node) {
-  for (auto& callsite : node.callsites()) {
-    HloInstruction* instruction = callsite.instruction();
-
-    if (instruction->opcode() == HloOpcode::kFusion) {
-      for (HloComputation* computation : instruction->called_computations()) {
-        computation->SetFusionInstruction(instruction);
-      }
-    }
-  }
-
-  // Correctly handle dead code: if a fusion computation is no longer used, it
-  // should not have a fusion instruction set.
-  if (node.callers().empty() &&
-      node.computation()->FusionInstruction() != nullptr) {
-    node.computation()->SetFusionInstruction(nullptr);
-  }
-
-  return absl::OkStatus();
-}
-
 }  // namespace
 
 absl::StatusOr<bool> FlattenCallGraph::Run(
@@ -117,29 +95,14 @@ absl::StatusOr<bool> FlattenCallGraph::Run(
     const absl::flat_hash_set<absl::string_view>& execution_threads) {
   XLA_VLOG_LINES(3, "Before flatten call graph:\n" + module->ToString());
 
-  bool changed = false;
-  {  // Flatten original call graph.
+  // Flatten original call graph.
   std::unique_ptr<CallGraph> call_graph =
       CallGraph::Build(module, execution_threads);
-  TF_ASSIGN_OR_RETURN(bool flattened,
+  TF_ASSIGN_OR_RETURN(bool changed,
                       call_graph->VisitNodesWithReturn(FlattenNode));
-  changed |= flattened;
-  }
-
-  if (!changed) {
-    return false;
-  }
-
-  // TODO(b/418034360): Remove this step once the fusion instruction is
-  // automatically maintained.
-  {  // Annotate flattened computations with callee types.
-    std::unique_ptr<CallGraph> call_graph =
-        CallGraph::Build(module, execution_threads);
-    TF_RETURN_IF_ERROR(call_graph->VisitNodes(AnnotateNode));
-  }
-
+
   XLA_VLOG_LINES(3, "After flatten call graph:\n" + module->ToString());
-  return true;
+  return changed;
 }
 
 }  // namespace xla
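
With the annotation pass gone, Run builds the call graph once and reports the flattening result directly instead of unconditionally returning true. A hedged sketch of how a caller could observe the new return value (hypothetical driver code inside a function returning absl::Status; Run's signature comes from the standard HLO pass interface):

// The pass result now means "the module actually changed", so callers can
// skip downstream work when flattening was a no-op.
FlattenCallGraph pass;
TF_ASSIGN_OR_RETURN(bool changed, pass.Run(module, execution_threads));
if (!changed) {
  return absl::OkStatus();  // nothing to re-verify or re-schedule
}
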