Add an overload of SpmdPartitioningVisitor::SetPartitionedHlo that accepts an HloInstruction* directly, avoiding unnecessary lambda functions at call sites.

PiperOrigin-RevId: 825819367
Author: Zixuan Jiang, 2025-10-29 19:58:18 -07:00 (committed by TensorFlower Gardener)
Commit: ba10feaa24
Parent: 512f1e48cb
6 changed files with 89 additions and 114 deletions
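
In short, handlers that already hold the partitioned HloInstruction* no longer need to wrap it in a trivial lambda. A minimal before/after sketch of the call-site pattern this commit applies throughout (`partitioned` is a placeholder name, not an identifier from the commit):

    // Before: the FunctionRef overload required wrapping an existing value in a lambda.
    SetPartitionedHlo(hlo, [&] { return partitioned; });

    // After: the new overload accepts the HloInstruction* directly.
    SetPartitionedHlo(hlo, partitioned);

As the header change at the end of the diff shows, the new overload sets the instruction's sharding and wraps it in a PartitionedHlo, and the existing FunctionRef overload now simply forwards to it.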


@ -341,7 +341,7 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCallSPMDInternal_RotateRight(
};
HloInstruction* rotated0 = rotate_with_padding(amount);
if (right_padding == 0) {
SetPartitionedHlo(hlo, [&] { return rotated0; });
SetPartitionedHlo(hlo, rotated0);
return absl::OkStatus();
}
@ -374,10 +374,9 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCallSPMDInternal_RotateRight(
HloInstruction* pred = b_.AddInstruction(HloInstruction::CreateCompare(
ShapeUtil::ChangeElementType(iota->shape(), PRED), iota,
selection_boundary, Comparison::Direction::kLt));
SetPartitionedHlo(hlo, [&] {
return b_.AddInstruction(HloInstruction::CreateTernary(
rotated0->shape(), HloOpcode::kSelect, pred, rotated1, rotated0));
});
SetPartitionedHlo(hlo, b_.AddInstruction(HloInstruction::CreateTernary(
rotated0->shape(), HloOpcode::kSelect, pred,
rotated1, rotated0)));
return absl::OkStatus();
}
@ -405,7 +404,7 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
input->shape(), MakePartitionedShape(hlo->shape(), hlo->sharding())));
auto copy = b_.AddInstruction(
HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
SetPartitionedHlo(hlo, [&] { return copy; });
SetPartitionedHlo(hlo, copy);
return absl::OkStatus();
}
if (hlo->custom_call_target() == "SPMDShardToFullShape") {
@ -416,7 +415,7 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
CHECK(ShapeUtil::Compatible(
copy->shape(), MakePartitionedShape(hlo->shape(), hlo->sharding())));
SetPartitionedHlo(hlo, [&] { return copy; });
SetPartitionedHlo(hlo, copy);
return absl::OkStatus();
}


@ -4349,7 +4349,7 @@ absl::Status SpmdPartitioningVisitor::HandleDotHelper(
num_partitions_, create_sharded_dot, conv_window, module_,
hlo, options_, &b_, &windowed_dot_general_loops_, this));
}
SetPartitionedHlo(hlo, [partitioned_dot] { return partitioned_dot; });
SetPartitionedHlo(hlo, partitioned_dot);
return absl::OkStatus();
}


@ -426,10 +426,7 @@ absl::Status SpmdPartitioningVisitor::HandleFft(HloInstruction* hlo) {
partitioned_input.state().next_channel_id, module_,
partitioned_input.state().b);
result->set_sharding(hlo->sharding());
auto partitioned_fft =
PartitionedHlo(result, hlo->shape(), partitioned_input.state());
SetPartitionedHlo(hlo, std::move(partitioned_fft));
SetPartitionedHlo(hlo, result);
return absl::OkStatus();
}


@ -1009,8 +1009,7 @@ absl::Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
PartitionGather(gather, operand, indices, gather->shape(),
gather->sharding(), absl::MakeConstSpan(batch_dims),
gather->gather_slice_sizes(), this));
SetPartitionedHlo(gather, PartitionedHlo(pgather, gather->shape(),
MakePartitioningState()));
SetPartitionedHlo(gather, pgather);
return absl::OkStatus();
}
@ -1904,8 +1903,7 @@ absl::Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
if (!pscatter) {
return DefaultAction(hlo);
}
SetPartitionedHlo(scatter, PartitionedHlo(pscatter, scatter->shape(),
MakePartitioningState()));
SetPartitionedHlo(scatter, pscatter);
return absl::OkStatus();
}


@ -2782,10 +2782,10 @@ absl::Status SpmdPartitioningVisitor::HandleElementwise(HloInstruction* hlo) {
new_operands.push_back(
GetPartitionedHlo(operand).Reshard(hlo->sharding()).hlo());
}
SetPartitionedHlo(hlo, [&] {
return b_.AddInstruction(hlo->CloneWithNewOperands(
MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands));
});
SetPartitionedHlo(
hlo,
b_.AddInstruction(hlo->CloneWithNewOperands(
MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands)));
return absl::OkStatus();
}
@ -2908,7 +2908,7 @@ absl::Status SpmdPartitioningVisitor::HandleSlice(HloInstruction* hlo) {
return DefaultAction(hlo);
}
SetPartitionedHlo(hlo, [&] { return final_operand; });
SetPartitionedHlo(hlo, final_operand);
return absl::OkStatus();
}
@ -2933,9 +2933,7 @@ absl::Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) {
}
auto clone = b_.AddInstruction(
hlo->CloneWithNewOperands(hlo->shape(), new_operands));
clone->set_sharding(sharding);
SetPartitionedHlo(
hlo, PartitionedHlo(clone, hlo->shape(), MakePartitioningState()));
SetPartitionedHlo(hlo, clone);
return absl::OkStatus();
}
// Special handling for sort in TopK when first operand partitioned at
@ -3126,10 +3124,10 @@ absl::Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) {
for (HloInstruction* operand : hlo->operands()) {
new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
}
SetPartitionedHlo(hlo, [&] {
return b_.AddInstruction(hlo->CloneWithNewOperands(
MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands));
});
SetPartitionedHlo(
hlo,
b_.AddInstruction(hlo->CloneWithNewOperands(
MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands)));
return absl::OkStatus();
}
@ -3149,10 +3147,10 @@ absl::Status SpmdPartitioningVisitor::HandleTranspose(HloInstruction* hlo) {
auto operand = GetPartitionedHlo(hlo->operand(0))
.Reshard(desired_operand_sharding)
.hlo();
SetPartitionedHlo(hlo, [&] {
return b_.AddInstruction(hlo->CloneWithNewOperands(
MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand}));
});
SetPartitionedHlo(
hlo,
b_.AddInstruction(hlo->CloneWithNewOperands(
MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand})));
return absl::OkStatus();
}
@ -3199,7 +3197,7 @@ absl::Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
PartitionedHlo reshard_reshape =
PartitionedHlo(reshape, hlo->shape(), MakePartitioningState())
.Reshard(sharding);
SetPartitionedHlo(hlo, [&] { return reshard_reshape.hlo(); });
SetPartitionedHlo(hlo, reshard_reshape.hlo());
if (sharding_pairs.size() == 2 &&
sharding_pairs[1].first == operand.sharding() &&
@ -3456,7 +3454,7 @@ absl::Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
};
TF_ASSIGN_OR_RETURN(HloInstruction * partitioned,
recursive_shard(operand, sharding, hlo->shape()));
SetPartitionedHlo(hlo, [&] { return partitioned; });
SetPartitionedHlo(hlo, partitioned);
return absl::OkStatus();
}
@ -3545,11 +3543,9 @@ absl::Status SpmdPartitioningVisitor::HandleSingleDevice(
false_computation = module_->AddEmbeddedComputation(false_b.Build(root));
}
SetPartitionedHlo(hlo, [&]() {
return b_.AddInstruction(HloInstruction::CreateConditional(
hlo->shape(), pred, operand, true_computation, operand,
false_computation));
});
SetPartitionedHlo(hlo, b_.AddInstruction(HloInstruction::CreateConditional(
hlo->shape(), pred, operand, true_computation,
operand, false_computation)));
return absl::OkStatus();
}
@ -3665,10 +3661,8 @@ absl::Status SpmdPartitioningVisitor::HandleBroadcast(HloInstruction* hlo) {
new_dims);
auto input = operand.Reshard(desired_input_sharding).hlo();
auto output_shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
SetPartitionedHlo(hlo, [&] {
return b_.AddInstruction(
hlo->CloneWithNewOperands(output_shard_shape, {input}));
});
SetPartitionedHlo(hlo, b_.AddInstruction(hlo->CloneWithNewOperands(
output_shard_shape, {input})));
return absl::OkStatus();
}
@ -3866,15 +3860,14 @@ absl::Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(
// Create dynamic update slice.
auto dus = add_hlo(HloInstruction::CreateDynamicUpdateSlice(
partitioned_shape, partitioned_input, replicate_update, new_indices));
SetPartitionedHlo(hlo, [&]() {
// Select if update is needed.
return add_hlo(HloInstruction::CreateTernary(
dus->shape(), HloOpcode::kSelect,
add_hlo(HloInstruction::CreateBroadcast(
ShapeUtil::ChangeElementType(dus->shape(), PRED),
all_dims_within_partition, {})),
dus, partitioned_input));
});
// Select if update is needed.
HloInstruction* select = add_hlo(HloInstruction::CreateTernary(
dus->shape(), HloOpcode::kSelect,
add_hlo(HloInstruction::CreateBroadcast(
ShapeUtil::ChangeElementType(dus->shape(), PRED),
all_dims_within_partition, {})),
dus, partitioned_input));
SetPartitionedHlo(hlo, select);
return absl::OkStatus();
}
@ -3893,8 +3886,7 @@ absl::Status SpmdPartitioningVisitor::HandleGetTupleElement(
PartitionedHlo source_partitioned_gte(
gte, tuple.base_shape().tuple_shapes(hlo->tuple_index()),
MakePartitioningState());
source_partitioned_gte = source_partitioned_gte.Reshard(hlo->sharding());
SetPartitionedHlo(hlo, std::move(source_partitioned_gte));
SetPartitionedHlo(hlo, source_partitioned_gte.Reshard(hlo->sharding()));
return absl::OkStatus();
}
@ -3907,19 +3899,15 @@ absl::Status SpmdPartitioningVisitor::HandleInfeed(HloInstruction* hlo) {
// elements for non-empty tuple. So if it has a nested empty tuple, we
// cannot invoke GetSubSharding() since it expects a sharding for the empty
// tuple. This is a workaround for that case.
SetPartitionedHlo(hlo, [&]() {
return b_.AddInstruction(
HloInstruction::CreateInfeed(shape, token, hlo->infeed_config()));
});
SetPartitionedHlo(hlo, b_.AddInstruction(HloInstruction::CreateInfeed(
shape, token, hlo->infeed_config())));
return absl::OkStatus();
}
auto sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
auto shard_shape = MakePartitionedShape(shape, sharding);
if (EvenlyPartitions(shape, sharding)) {
SetPartitionedHlo(hlo, [&]() {
return b_.AddInstruction(HloInstruction::CreateInfeed(
shard_shape, token, hlo->infeed_config()));
});
SetPartitionedHlo(hlo, b_.AddInstruction(HloInstruction::CreateInfeed(
shard_shape, token, hlo->infeed_config())));
return absl::OkStatus();
}
@ -4023,11 +4011,11 @@ absl::Status SpmdPartitioningVisitor::HandleInfeed(HloInstruction* hlo) {
}
branches[i] = module_->AddEmbeddedComputation(branch_b.Build());
}
SetPartitionedHlo(hlo, [&]() {
return b_.AddInstruction(HloInstruction::CreateConditional(
ShapeUtil::MakeTupleShape({shard_shape, token->shape()}), branch_index,
branches, std::vector<HloInstruction*>(branches.size(), token)));
});
SetPartitionedHlo(
hlo, b_.AddInstruction(HloInstruction::CreateConditional(
ShapeUtil::MakeTupleShape({shard_shape, token->shape()}),
branch_index, branches,
std::vector<HloInstruction*>(branches.size(), token))));
return absl::OkStatus();
}
@ -4223,10 +4211,9 @@ absl::Status SpmdPartitioningVisitor::HandleReverse(HloInstruction* hlo) {
if (!left_padded_operand) {
return DefaultAction(hlo);
}
SetPartitionedHlo(hlo, [&] {
return b_.AddInstruction(hlo->CloneWithNewOperands(
left_padded_operand->shape(), {left_padded_operand}));
});
SetPartitionedHlo(hlo,
b_.AddInstruction(hlo->CloneWithNewOperands(
left_padded_operand->shape(), {left_padded_operand})));
return absl::OkStatus();
}
@ -4237,7 +4224,7 @@ absl::Status SpmdPartitioningVisitor::HandleWhile(HloInstruction* hlo) {
hlo->while_body(),
GetPartitionedHlo(hlo->operand(0)).Reshard(sharding).hlo()));
hlo->SetupDerivedInstruction(whileOp);
SetPartitionedHlo(hlo, [&] { return whileOp; });
SetPartitionedHlo(hlo, whileOp);
return absl::OkStatus();
}
@ -4282,21 +4269,15 @@ absl::Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) {
return HandleSingleDevice(hlo);
}
if (hlo->sharding().IsManual()) {
auto clone_from_original = [&](const HloSharding& shared_sharding) {
std::vector<HloInstruction*> new_operands;
new_operands.reserve(hlo->operand_count());
for (int64_t i = 0; i < hlo->operand_count(); ++i) {
new_operands.push_back(
GetPartitionedHlo(hlo->operand(i)).Reshard(shared_sharding).hlo());
}
auto clone = b_.AddInstruction(
hlo->CloneWithNewOperands(hlo->shape(), new_operands));
clone->set_sharding(shared_sharding);
return clone;
};
SetPartitionedHlo(hlo,
[&] { return clone_from_original(hlo->sharding()); });
std::vector<HloInstruction*> new_operands;
new_operands.reserve(hlo->operand_count());
for (int64_t i = 0; i < hlo->operand_count(); ++i) {
new_operands.push_back(
GetPartitionedHlo(hlo->operand(i)).Reshard(hlo->sharding()).hlo());
}
auto clone = b_.AddInstruction(
hlo->CloneWithNewOperands(hlo->shape(), new_operands));
SetPartitionedHlo(hlo, clone);
return absl::OkStatus();
}
@ -4331,10 +4312,9 @@ absl::Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) {
Shape outfeed_shape = operand->shape();
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(hlo->outfeed_shape(),
&outfeed_shape));
SetPartitionedHlo(hlo, [&]() {
return b_.AddInstruction(HloInstruction::CreateOutfeed(
outfeed_shape, operand, token, hlo->outfeed_config()));
});
SetPartitionedHlo(
hlo, b_.AddInstruction(HloInstruction::CreateOutfeed(
outfeed_shape, operand, token, hlo->outfeed_config())));
return absl::OkStatus();
}
@ -4453,13 +4433,13 @@ absl::Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) {
hlo->outfeed_config()));
branches[i] = module_->AddEmbeddedComputation(branch_b.Build());
}
SetPartitionedHlo(hlo, [&]() {
return b_.AddInstruction(HloInstruction::CreateConditional(
token->shape(), branch_index, branches,
std::vector<HloInstruction*>(
branches.size(),
b_.AddInstruction(HloInstruction::CreateTuple({operand, token})))));
});
SetPartitionedHlo(
hlo,
b_.AddInstruction(HloInstruction::CreateConditional(
token->shape(), branch_index, branches,
std::vector<HloInstruction*>(
branches.size(), b_.AddInstruction(HloInstruction::CreateTuple(
{operand, token}))))));
return absl::OkStatus();
}
@ -4481,8 +4461,7 @@ absl::Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) {
};
if (hlo->sharding().IsManual()) {
SetPartitionedHlo(hlo,
[&] { return clone_from_original(hlo->sharding()); });
SetPartitionedHlo(hlo, clone_from_original(hlo->sharding()));
return absl::OkStatus();
}
@ -4507,11 +4486,10 @@ absl::Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) {
}
if (!hlo->sharding().ReplicateOnLastTileDim()) {
SetPartitionedHlo(hlo, [&] {
return b_.AddInstruction(HloInstruction::CreateRng(
MakePartitionedShape(hlo->shape(), hlo->sharding()),
hlo->random_distribution(), new_operands));
});
SetPartitionedHlo(hlo,
b_.AddInstruction(HloInstruction::CreateRng(
MakePartitionedShape(hlo->shape(), hlo->sharding()),
hlo->random_distribution(), new_operands)));
} else {
std::vector<int64_t> group_dims(
hlo->sharding().tile_assignment().num_dimensions() - 1);
@ -4830,9 +4808,8 @@ absl::Status SpmdPartitioningVisitor::HandleTuple(HloInstruction* hlo) {
.Reshard(hlo->sharding().GetSubSharding(hlo->shape(), {i}))
.hlo());
}
SetPartitionedHlo(hlo, [&]() {
return b_.AddInstruction(HloInstruction::CreateTuple(new_operands));
});
SetPartitionedHlo(
hlo, b_.AddInstruction(HloInstruction::CreateTuple(new_operands)));
return absl::OkStatus();
}
@ -4921,7 +4898,7 @@ absl::Status SpmdPartitioningVisitor::HandleRaggedDot(HloInstruction* hlo) {
MakeBinaryAdd(phlo->shape().element_type(), lhs.state().module));
}
SetPartitionedHlo(hlo, [&]() { return phlo; });
SetPartitionedHlo(hlo, phlo);
return absl::OkStatus();
}


@ -795,17 +795,21 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
void SetPartitionedHlo(const HloInstruction* hlo,
PartitionedHlo&& partitioned_hlo);
// Convenient wrapper that creates PartitionedHlo from the result of the func
// and maps it to the given original hlo.
void SetPartitionedHlo(const HloInstruction* hlo,
absl::FunctionRef<HloInstruction*()> func) {
HloInstruction* new_hlo = func();
// Convenient wrapper that creates PartitionedHlo from `new_hlo`.
void SetPartitionedHlo(const HloInstruction* hlo, HloInstruction* new_hlo) {
new_hlo->set_sharding(hlo->sharding());
SetPartitionedHlo(
hlo, PartitionedHlo(new_hlo, hlo->shape(), MakePartitioningState()));
changed_ = true;
}
// Convenient wrapper that creates PartitionedHlo from the result of the func
// and maps it to the given original hlo.
void SetPartitionedHlo(const HloInstruction* hlo,
absl::FunctionRef<HloInstruction*()> func) {
return SetPartitionedHlo(hlo, func());
}
int64_t NewChannel() { return (*next_channel_id_)++; }
PartitionedHlo::PartitioningState MakePartitioningState();
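
For illustration only, and not part of this commit, a hypothetical handler sketch (`HandleFoo` and its locals are invented names) showing how a visitor method reads with the new overload; it follows the same pattern as HandleTranspose and HandleSlice above:

    absl::Status SpmdPartitioningVisitor::HandleFoo(HloInstruction* hlo) {
      // Reshard the operand to the output sharding, then clone the op onto the
      // per-partition shape.
      HloInstruction* operand =
          GetPartitionedHlo(hlo->operand(0)).Reshard(hlo->sharding()).hlo();
      HloInstruction* clone = b_.AddInstruction(hlo->CloneWithNewOperands(
          MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand}));
      // The new overload sets the sharding and wraps the result in a
      // PartitionedHlo internally, so no lambda wrapper is needed.
      SetPartitionedHlo(hlo, clone);
      return absl::OkStatus();
    }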