mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
If device placement annotations are found inside host computations (as a result of nested host computations), hoist them up the call stack. If any unsupported cases or inconsistencies are detected, an error will be returned to the user.
This enables JAX to migrate from its previous `compute_on` API to the new API (currently named `compute_on2`). PiperOrigin-RevId: 825177029
This commit is contained in:
parent
768e653c9c
commit
63a9d0d1f8
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
|
|
@ -151,6 +152,304 @@ absl::StatusOr<bool> OffloadHostInstructions(
|
|||
|
||||
return modified;
|
||||
}
|
||||
|
||||
std::string GetDevicePlacement(const HloInstruction* instr) {
|
||||
CHECK(instr->IsCustomCall(memory_annotations::kDevicePlacement))
|
||||
<< "Input " << instr->name() << " must be a device placement annotation";
|
||||
CHECK(instr->has_frontend_attributes())
|
||||
<< "Input " << instr->name() << " must have frontend attributes";
|
||||
const auto& frontend_attribute_map = instr->frontend_attributes().map();
|
||||
auto buffer_placement_it =
|
||||
frontend_attribute_map.find(kXlaBufferPlacementAttr);
|
||||
CHECK(buffer_placement_it != frontend_attribute_map.end())
|
||||
<< "Input " << instr->name()
|
||||
<< " must have a buffer placement frontend attribute";
|
||||
return buffer_placement_it->second;
|
||||
}
|
||||
|
||||
// Gathers the device placement annotations that are currently expected inside
// a host computation. An annotation is allowed only in one of two positions:
//   1. As the ROOT instruction, when the root is not a tuple.
//   2. As a direct operand of the ROOT instruction, when the root is a tuple.
// Annotations found anywhere else are not part of the returned set.
absl::flat_hash_set<HloInstruction*> CollectAllowedDevicePlacementAnnotations(
    const HloComputation* computation) {
  absl::flat_hash_set<HloInstruction*> allowed;

  // Records `candidate` when it is a device placement annotation.
  const auto insert_if_annotation = [&allowed](HloInstruction* candidate) {
    if (candidate->IsCustomCall(memory_annotations::kDevicePlacement)) {
      allowed.insert(candidate);
    }
  };

  HloInstruction* const root = computation->root_instruction();
  if (root->opcode() != HloOpcode::kTuple) {
    // Single-value result: only the root itself can be an allowed annotation.
    insert_if_annotation(root);
    return allowed;
  }

  // Tuple result: inspect every value feeding the root tuple.
  const int64_t num_operands = root->operand_count();
  for (int64_t index = 0; index < num_operands; ++index) {
    insert_if_annotation(root->mutable_operand(index));
  }
  return allowed;
}
|
||||
|
||||
// Examines every device placement annotation in `computation` that is not in
// `allowed_device_placement_annotations`.
//
// Returns the annotations whose placement is pinned or unpinned host memory;
// these are redundant inside a host computation and the caller is expected to
// remove them. Returns InvalidArgumentError as soon as an annotation with any
// other placement is encountered.
absl::StatusOr<std::vector<HloInstruction*>>
CheckRemainingDevicePlacementAnnotations(
    const HloComputation* computation,
    const absl::flat_hash_set<HloInstruction*>&
        allowed_device_placement_annotations) {
  std::vector<HloInstruction*> removable;
  for (HloInstruction* candidate : computation->instructions()) {
    // Skip anything that is not an annotation, or that sits in an allowed
    // position.
    if (!candidate->IsCustomCall(memory_annotations::kDevicePlacement) ||
        allowed_device_placement_annotations.contains(candidate)) {
      continue;
    }

    const std::string placement = GetDevicePlacement(candidate);
    const bool targets_host =
        placement == memory_annotations::kMemoryTargetPinnedHost ||
        placement == memory_annotations::kMemoryTargetUnpinnedHost;
    if (!targets_host) {
      // Placing a buffer anywhere but the host from within a host computation
      // is not allowed.
      return absl::InvalidArgumentError(
          absl::StrFormat("Host computation %s contains a device placement "
                          "annotation %s that is not allowed.",
                          computation->name(), candidate->ToString()));
    }

    // A host placement inside a host computation adds no information.
    removable.push_back(candidate);
  }
  return removable;
}
|
||||
|
||||
// Returns true if any redundant annotations were removed.
|
||||
absl::StatusOr<bool> CleanUpHostComputationDevicePlacementAnnotations(
|
||||
const HloComputation* computation) {
|
||||
const absl::flat_hash_set<HloInstruction*>
|
||||
allowed_device_placement_annotations =
|
||||
CollectAllowedDevicePlacementAnnotations(computation);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
const std::vector<HloInstruction*> redundant_device_placement_annotations,
|
||||
CheckRemainingDevicePlacementAnnotations(
|
||||
computation, allowed_device_placement_annotations));
|
||||
|
||||
// Remove redundant annotations
|
||||
for (HloInstruction* redundant_annotation :
|
||||
redundant_device_placement_annotations) {
|
||||
VLOG(1) << "Removing redundant annotation: "
|
||||
<< redundant_annotation->ToString();
|
||||
CHECK_EQ(redundant_annotation->operand_count(), 1)
|
||||
<< "A device placement annotation must have exactly one operand.";
|
||||
for (HloInstruction* user : redundant_annotation->users()) {
|
||||
for (int64_t operand_index :
|
||||
user->operand_indices(redundant_annotation)) {
|
||||
TF_RETURN_IF_ERROR(user->ReplaceOperandWith(
|
||||
operand_index, redundant_annotation->mutable_operand(0)));
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(redundant_annotation->parent()->RemoveInstruction(
|
||||
redundant_annotation));
|
||||
}
|
||||
|
||||
return !redundant_device_placement_annotations.empty();
|
||||
}
|
||||
|
||||
bool DevicePlacementMemorySpaceIsSame(const HloInstruction* a,
|
||||
const HloInstruction* b) {
|
||||
CHECK(a->IsCustomCall(memory_annotations::kDevicePlacement))
|
||||
<< "Input a: " << a->name() << " must be a device placement annotation";
|
||||
CHECK(b->IsCustomCall(memory_annotations::kDevicePlacement))
|
||||
<< "Input b: " << b->name() << " must be a device placement annotation";
|
||||
return GetDevicePlacement(a) == GetDevicePlacement(b);
|
||||
}
|
||||
|
||||
// Clones the device placement annotation `original_annotation` into
// `destination_computation`, inserting the clone between
// `destination_instruction` and its users.
//
// Every user of `destination_instruction` that is not itself a device
// placement annotation is rewired to consume the clone. Users that already are
// device placement annotations are left untouched as long as they request the
// same placement as `original_annotation`. If no user ends up consuming the
// clone, the clone is removed again.
//
// `destination_computation_caller_instruction` is only used to name the call
// in the conflict error message.
//
// Returns InvalidArgumentError if an existing annotation on
// `destination_instruction` requests a different placement, and propagates any
// error from rewiring operands or removing the unused clone. On error the
// clone, and any rewiring already done, are left in place.
absl::Status CloneAnnotationToDestination(
    HloComputation* destination_computation,
    HloInstruction* destination_computation_caller_instruction,
    const HloInstruction* original_annotation,
    HloInstruction* destination_instruction) {
  HloInstruction* moved_annotation = destination_computation->AddInstruction(
      original_annotation->CloneWithNewOperands(original_annotation->shape(),
                                                {destination_instruction},
                                                "move_to_caller"));

  // Iterate over a snapshot of the users: rewiring a user onto the clone
  // updates `destination_instruction`'s own user list, so iterating that list
  // directly while replacing operands would invalidate the iteration.
  const std::vector<HloInstruction*> users_snapshot(
      destination_instruction->users().begin(),
      destination_instruction->users().end());

  bool used_new_annotation = false;
  for (HloInstruction* destination_user : users_snapshot) {
    if (destination_user == moved_annotation) {
      // Do not replace the annotation with itself.
      continue;
    }
    if (destination_user->IsCustomCall(memory_annotations::kDevicePlacement)) {
      // The destination already has an annotation.
      if (!DevicePlacementMemorySpaceIsSame(original_annotation,
                                            destination_user)) {
        return absl::InvalidArgumentError(
            absl::StrFormat("Found conflicting host computation output memory "
                            "space. Call %s wants output memory space %s but "
                            "call %s wants output memory space %s",
                            original_annotation->operand(0)->name(),
                            GetDevicePlacement(original_annotation),
                            destination_computation_caller_instruction->name(),
                            GetDevicePlacement(destination_user)));
      }
      // Annotation already exists, nothing to do.
      continue;
    }
    // A user may consume the destination through several operand slots.
    for (int64_t operand_index :
         destination_user->operand_indices(destination_instruction)) {
      TF_RETURN_IF_ERROR(destination_user->ReplaceOperandWith(
          operand_index, moved_annotation));
    }
    used_new_annotation = true;
  }

  // All the places where this annotation would be placed already have this
  // exact annotation.
  if (!used_new_annotation) {
    TF_RETURN_IF_ERROR(
        destination_computation->RemoveInstruction(moved_annotation));
  }

  return absl::OkStatus();
}
|
||||
|
||||
// Hoists device placement annotations out of a host computation whose root is
// a tuple.
//
// For every root-tuple operand that is a device placement annotation, the
// annotation is cloned after the matching get-tuple-element of each caller
// (see CloneAnnotationToDestination), the root tuple is rewired to consume the
// annotation's operand directly, and the annotation is removed from
// `host_computation`.
//
// Returns true if the computation or any of its callers was modified. Returns
// UnimplementedError if a caller's result is consumed by anything other than a
// get-tuple-element, and propagates errors from cloning, rewiring and removal.
absl::StatusOr<bool> MoveAnnotationsToCallerTuple(
    HloComputation* host_computation) {
  // Only operands of the root are replaced below, never the root itself, so it
  // can be looked up once.
  HloInstruction* root_tuple = host_computation->root_instruction();

  bool changed = false;
  for (int64_t operand_index = 0; operand_index < root_tuple->operand_count();
       ++operand_index) {
    HloInstruction* root_operand = root_tuple->mutable_operand(operand_index);
    if (!root_operand->IsCustomCall(memory_annotations::kDevicePlacement)) {
      // Instruction is not a device placement annotation; nothing to do.
      continue;
    }
    // Root operand is a device placement annotation.
    CHECK_EQ(root_operand->operand_count(), 1)
        << "A device placement annotation must have exactly one operand.";

    // Clone the annotation to each of the callers.
    for (HloInstruction* caller_instruction :
         host_computation->caller_instructions()) {
      HloComputation* caller_computation = caller_instruction->parent();
      for (HloInstruction* caller_user_gte : caller_instruction->users()) {
        if (caller_user_gte->opcode() != HloOpcode::kGetTupleElement) {
          return absl::UnimplementedError(
              "When moving device placement annotations out of a host "
              "computation, the tuple is used by something other than a "
              "get-tuple-element. This is currently not supported.");
        }
        if (caller_user_gte->tuple_index() != operand_index) {
          // This get-tuple-element is getting a different index than the one we
          // are currently looking at.
          continue;
        }
        TF_RETURN_IF_ERROR(
            CloneAnnotationToDestination(caller_computation, caller_instruction,
                                         root_operand, caller_user_gte));
      }
    }

    // Bypass the annotation for this tuple slot.
    TF_RETURN_IF_ERROR(root_tuple->ReplaceOperandWith(
        operand_index, root_operand->mutable_operand(0)));

    // The same annotation may feed several slots of the root tuple. It can
    // only be removed once the root no longer references it; the remaining
    // slots are handled by later iterations of this loop, the last of which
    // performs the removal.
    if (root_tuple->operand_indices(root_operand).empty()) {
      TF_RETURN_IF_ERROR(host_computation->RemoveInstruction(root_operand));
    }
    changed = true;
  }
  return changed;
}
|
||||
|
||||
// Hoists a device placement annotation out of a host computation whose root is
// not a tuple.
//
// When the root instruction is a device placement annotation, it is cloned
// after every caller instruction (see CloneAnnotationToDestination), the
// annotation's operand becomes the new root, and the annotation is removed.
//
// Returns false if the root is not an annotation, true after a successful
// move, and propagates any error from cloning or removal.
absl::StatusOr<bool> MoveAnnotationToCallerNonTuple(
    HloComputation* host_computation) {
  HloInstruction* root_instr = host_computation->root_instruction();
  const bool root_is_annotation =
      root_instr->IsCustomCall(memory_annotations::kDevicePlacement);
  if (!root_is_annotation) {
    // Nothing to hoist.
    return false;
  }
  CHECK_EQ(root_instr->operand_count(), 1)
      << "A device placement annotation must have exactly one operand.";

  // Re-create the annotation on the result of every call site.
  for (HloInstruction* call_site : host_computation->caller_instructions()) {
    TF_RETURN_IF_ERROR(CloneAnnotationToDestination(
        call_site->parent(), call_site, root_instr, call_site));
  }

  // Drop the annotation from this computation: promote its operand to root
  // first so the annotation is no longer the root when it is removed.
  HloInstruction* annotated_value = root_instr->mutable_operand(0);
  host_computation->set_root_instruction(annotated_value);
  TF_RETURN_IF_ERROR(host_computation->RemoveInstruction(root_instr));
  return true;
}
|
||||
|
||||
// Move host device placement annotations out of this computation to the calling
|
||||
// computation.
|
||||
absl::StatusOr<bool> MoveAnnotationsToCaller(HloComputation* computation) {
|
||||
bool changed = false;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
bool cleaned_up,
|
||||
CleanUpHostComputationDevicePlacementAnnotations(computation));
|
||||
changed = changed || cleaned_up;
|
||||
// All annotations at this point are valid.
|
||||
if (computation->root_instruction()->opcode() == HloOpcode::kTuple) {
|
||||
// When the computation returns a tuple, the annotation is on the operands
|
||||
// of the root tuple.
|
||||
TF_ASSIGN_OR_RETURN(bool moved, MoveAnnotationsToCallerTuple(computation));
|
||||
changed = changed || moved;
|
||||
} else {
|
||||
// When the computation returns a single value, the annotation is the root
|
||||
// instruction.
|
||||
TF_ASSIGN_OR_RETURN(bool moved,
|
||||
MoveAnnotationToCallerNonTuple(computation));
|
||||
changed = changed || moved;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
// Moves device placement annotations out of every host computation in
// `module` and into the calling computations.
//
// A computation is treated as a host computation when at least one of its
// caller instructions has the `kXlaComputeTypeAttr` frontend attribute set to
// `kXlaComputeTypeHost`.
//
// Returns true if any computation was modified, or the first error reported by
// MoveAnnotationsToCaller.
absl::StatusOr<bool> RemoveDevicePlacementAnnotationsFromHostComputations(
    HloModule* module) {
  // The only time we currently find device placement annotations in host
  // computations are when the host computation calls another host computation
  // and that called host computation has an output memory space annotated. That
  // output memory space annotation is usually on the users of the host call (or
  // users of the get-tuple-elements if the call returns a tuple).
  //
  // Visit host computations in post-order. We will push annotations out of host
  // computations into their callers.

  // Returns true if `caller` is marked as a host compute call. The attribute
  // map is read through a reference and probed with a single lookup, rather
  // than copying the attributes and looking the key up twice.
  const auto is_host_call = [](const HloInstruction* caller) {
    if (!caller->has_frontend_attributes()) {
      return false;
    }
    const auto& attributes = caller->frontend_attributes().map();
    const auto compute_type_it = attributes.find(kXlaComputeTypeAttr);
    return compute_type_it != attributes.end() &&
           compute_type_it->second == kXlaComputeTypeHost;
  };

  std::vector<HloComputation*> host_computations;
  for (HloComputation* computation : module->MakeComputationPostOrder()) {
    // Check if this computation is a host computation.
    for (const HloInstruction* caller_instruction :
         computation->caller_instructions()) {
      if (is_host_call(caller_instruction)) {
        // One host caller is enough; record the computation once.
        host_computations.push_back(computation);
        break;
      }
    }
  }

  bool changed = false;
  for (HloComputation* computation : host_computations) {
    TF_ASSIGN_OR_RETURN(bool moved, MoveAnnotationsToCaller(computation));
    changed = changed || moved;
  }
  return changed;
}
|
||||
} // namespace
|
||||
|
||||
/*static*/ absl::StatusOr<HloCallInstruction*>
|
||||
|
|
@ -321,6 +620,16 @@ absl::StatusOr<bool> HloHostDeviceTypeCallWrapper::Run(
|
|||
return false;
|
||||
}
|
||||
|
||||
// Before any other passes run, move device placement annotations out of host
|
||||
// computations.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
bool modified,
|
||||
RemoveDevicePlacementAnnotationsFromHostComputations(module));
|
||||
// At this point, this pass will always modify the module. The return value of
|
||||
// this function, which indicates whether the module was modified, is not
|
||||
// useful.
|
||||
(void)modified;
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
AnnotateHostComputeOffload().Run(module, execution_threads).status());
|
||||
TF_RETURN_IF_ERROR(CallInliner().Run(module, execution_threads).status());
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user