mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Add an overload for SpmdPartitioner::SetPartitionedHlo to avoid unnecessary lambda functions.
PiperOrigin-RevId: 825819367
This commit is contained in:
parent
512f1e48cb
commit
ba10feaa24
|
|
@ -341,7 +341,7 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCallSPMDInternal_RotateRight(
|
|||
};
|
||||
HloInstruction* rotated0 = rotate_with_padding(amount);
|
||||
if (right_padding == 0) {
|
||||
SetPartitionedHlo(hlo, [&] { return rotated0; });
|
||||
SetPartitionedHlo(hlo, rotated0);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -374,10 +374,9 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCallSPMDInternal_RotateRight(
|
|||
HloInstruction* pred = b_.AddInstruction(HloInstruction::CreateCompare(
|
||||
ShapeUtil::ChangeElementType(iota->shape(), PRED), iota,
|
||||
selection_boundary, Comparison::Direction::kLt));
|
||||
SetPartitionedHlo(hlo, [&] {
|
||||
return b_.AddInstruction(HloInstruction::CreateTernary(
|
||||
rotated0->shape(), HloOpcode::kSelect, pred, rotated1, rotated0));
|
||||
});
|
||||
SetPartitionedHlo(hlo, b_.AddInstruction(HloInstruction::CreateTernary(
|
||||
rotated0->shape(), HloOpcode::kSelect, pred,
|
||||
rotated1, rotated0)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -405,7 +404,7 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
|
|||
input->shape(), MakePartitionedShape(hlo->shape(), hlo->sharding())));
|
||||
auto copy = b_.AddInstruction(
|
||||
HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
|
||||
SetPartitionedHlo(hlo, [&] { return copy; });
|
||||
SetPartitionedHlo(hlo, copy);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
if (hlo->custom_call_target() == "SPMDShardToFullShape") {
|
||||
|
|
@ -416,7 +415,7 @@ absl::Status SpmdPartitioningVisitor::HandleCustomCall(HloInstruction* hlo) {
|
|||
HloInstruction::CreateUnary(input->shape(), HloOpcode::kCopy, input));
|
||||
CHECK(ShapeUtil::Compatible(
|
||||
copy->shape(), MakePartitionedShape(hlo->shape(), hlo->sharding())));
|
||||
SetPartitionedHlo(hlo, [&] { return copy; });
|
||||
SetPartitionedHlo(hlo, copy);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4349,7 +4349,7 @@ absl::Status SpmdPartitioningVisitor::HandleDotHelper(
|
|||
num_partitions_, create_sharded_dot, conv_window, module_,
|
||||
hlo, options_, &b_, &windowed_dot_general_loops_, this));
|
||||
}
|
||||
SetPartitionedHlo(hlo, [partitioned_dot] { return partitioned_dot; });
|
||||
SetPartitionedHlo(hlo, partitioned_dot);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -426,10 +426,7 @@ absl::Status SpmdPartitioningVisitor::HandleFft(HloInstruction* hlo) {
|
|||
partitioned_input.state().next_channel_id, module_,
|
||||
partitioned_input.state().b);
|
||||
|
||||
result->set_sharding(hlo->sharding());
|
||||
auto partitioned_fft =
|
||||
PartitionedHlo(result, hlo->shape(), partitioned_input.state());
|
||||
SetPartitionedHlo(hlo, std::move(partitioned_fft));
|
||||
SetPartitionedHlo(hlo, result);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1009,8 +1009,7 @@ absl::Status SpmdPartitioningVisitor::HandleGather(HloInstruction* hlo) {
|
|||
PartitionGather(gather, operand, indices, gather->shape(),
|
||||
gather->sharding(), absl::MakeConstSpan(batch_dims),
|
||||
gather->gather_slice_sizes(), this));
|
||||
SetPartitionedHlo(gather, PartitionedHlo(pgather, gather->shape(),
|
||||
MakePartitioningState()));
|
||||
SetPartitionedHlo(gather, pgather);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -1904,8 +1903,7 @@ absl::Status SpmdPartitioningVisitor::HandleScatter(HloInstruction* hlo) {
|
|||
if (!pscatter) {
|
||||
return DefaultAction(hlo);
|
||||
}
|
||||
SetPartitionedHlo(scatter, PartitionedHlo(pscatter, scatter->shape(),
|
||||
MakePartitioningState()));
|
||||
SetPartitionedHlo(scatter, pscatter);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
|||
163
third_party/xla/xla/service/spmd/spmd_partitioner.cc
vendored
163
third_party/xla/xla/service/spmd/spmd_partitioner.cc
vendored
|
|
@ -2782,10 +2782,10 @@ absl::Status SpmdPartitioningVisitor::HandleElementwise(HloInstruction* hlo) {
|
|||
new_operands.push_back(
|
||||
GetPartitionedHlo(operand).Reshard(hlo->sharding()).hlo());
|
||||
}
|
||||
SetPartitionedHlo(hlo, [&] {
|
||||
return b_.AddInstruction(hlo->CloneWithNewOperands(
|
||||
MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands));
|
||||
});
|
||||
SetPartitionedHlo(
|
||||
hlo,
|
||||
b_.AddInstruction(hlo->CloneWithNewOperands(
|
||||
MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -2908,7 +2908,7 @@ absl::Status SpmdPartitioningVisitor::HandleSlice(HloInstruction* hlo) {
|
|||
return DefaultAction(hlo);
|
||||
}
|
||||
|
||||
SetPartitionedHlo(hlo, [&] { return final_operand; });
|
||||
SetPartitionedHlo(hlo, final_operand);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -2933,9 +2933,7 @@ absl::Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) {
|
|||
}
|
||||
auto clone = b_.AddInstruction(
|
||||
hlo->CloneWithNewOperands(hlo->shape(), new_operands));
|
||||
clone->set_sharding(sharding);
|
||||
SetPartitionedHlo(
|
||||
hlo, PartitionedHlo(clone, hlo->shape(), MakePartitioningState()));
|
||||
SetPartitionedHlo(hlo, clone);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
// Special handling for sort in TopK when first operand partitioined at
|
||||
|
|
@ -3126,10 +3124,10 @@ absl::Status SpmdPartitioningVisitor::HandleSort(HloInstruction* hlo) {
|
|||
for (HloInstruction* operand : hlo->operands()) {
|
||||
new_operands.push_back(GetPartitionedHlo(operand).Reshard(sharding).hlo());
|
||||
}
|
||||
SetPartitionedHlo(hlo, [&] {
|
||||
return b_.AddInstruction(hlo->CloneWithNewOperands(
|
||||
MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands));
|
||||
});
|
||||
SetPartitionedHlo(
|
||||
hlo,
|
||||
b_.AddInstruction(hlo->CloneWithNewOperands(
|
||||
MakePartitionedShape(hlo->shape(), hlo->sharding()), new_operands)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -3149,10 +3147,10 @@ absl::Status SpmdPartitioningVisitor::HandleTranspose(HloInstruction* hlo) {
|
|||
auto operand = GetPartitionedHlo(hlo->operand(0))
|
||||
.Reshard(desired_operand_sharding)
|
||||
.hlo();
|
||||
SetPartitionedHlo(hlo, [&] {
|
||||
return b_.AddInstruction(hlo->CloneWithNewOperands(
|
||||
MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand}));
|
||||
});
|
||||
SetPartitionedHlo(
|
||||
hlo,
|
||||
b_.AddInstruction(hlo->CloneWithNewOperands(
|
||||
MakePartitionedShape(hlo->shape(), hlo->sharding()), {operand})));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -3199,7 +3197,7 @@ absl::Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
|
|||
PartitionedHlo reshard_reshape =
|
||||
PartitionedHlo(reshape, hlo->shape(), MakePartitioningState())
|
||||
.Reshard(sharding);
|
||||
SetPartitionedHlo(hlo, [&] { return reshard_reshape.hlo(); });
|
||||
SetPartitionedHlo(hlo, reshard_reshape.hlo());
|
||||
|
||||
if (sharding_pairs.size() == 2 &&
|
||||
sharding_pairs[1].first == operand.sharding() &&
|
||||
|
|
@ -3456,7 +3454,7 @@ absl::Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
|
|||
};
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * partitioned,
|
||||
recursive_shard(operand, sharding, hlo->shape()));
|
||||
SetPartitionedHlo(hlo, [&] { return partitioned; });
|
||||
SetPartitionedHlo(hlo, partitioned);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -3545,11 +3543,9 @@ absl::Status SpmdPartitioningVisitor::HandleSingleDevice(
|
|||
false_computation = module_->AddEmbeddedComputation(false_b.Build(root));
|
||||
}
|
||||
|
||||
SetPartitionedHlo(hlo, [&]() {
|
||||
return b_.AddInstruction(HloInstruction::CreateConditional(
|
||||
hlo->shape(), pred, operand, true_computation, operand,
|
||||
false_computation));
|
||||
});
|
||||
SetPartitionedHlo(hlo, b_.AddInstruction(HloInstruction::CreateConditional(
|
||||
hlo->shape(), pred, operand, true_computation,
|
||||
operand, false_computation)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -3665,10 +3661,8 @@ absl::Status SpmdPartitioningVisitor::HandleBroadcast(HloInstruction* hlo) {
|
|||
new_dims);
|
||||
auto input = operand.Reshard(desired_input_sharding).hlo();
|
||||
auto output_shard_shape = MakePartitionedShape(hlo->shape(), hlo->sharding());
|
||||
SetPartitionedHlo(hlo, [&] {
|
||||
return b_.AddInstruction(
|
||||
hlo->CloneWithNewOperands(output_shard_shape, {input}));
|
||||
});
|
||||
SetPartitionedHlo(hlo, b_.AddInstruction(hlo->CloneWithNewOperands(
|
||||
output_shard_shape, {input})));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -3866,15 +3860,14 @@ absl::Status SpmdPartitioningVisitor::HandleDynamicUpdateSlice(
|
|||
// Create dynamic update slice.
|
||||
auto dus = add_hlo(HloInstruction::CreateDynamicUpdateSlice(
|
||||
partitioned_shape, partitioned_input, replicate_update, new_indices));
|
||||
SetPartitionedHlo(hlo, [&]() {
|
||||
// Select if update is needed.
|
||||
return add_hlo(HloInstruction::CreateTernary(
|
||||
dus->shape(), HloOpcode::kSelect,
|
||||
add_hlo(HloInstruction::CreateBroadcast(
|
||||
ShapeUtil::ChangeElementType(dus->shape(), PRED),
|
||||
all_dims_within_partition, {})),
|
||||
dus, partitioned_input));
|
||||
});
|
||||
// Select if update is needed.
|
||||
HloInstruction* select = add_hlo(HloInstruction::CreateTernary(
|
||||
dus->shape(), HloOpcode::kSelect,
|
||||
add_hlo(HloInstruction::CreateBroadcast(
|
||||
ShapeUtil::ChangeElementType(dus->shape(), PRED),
|
||||
all_dims_within_partition, {})),
|
||||
dus, partitioned_input));
|
||||
SetPartitionedHlo(hlo, select);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -3893,8 +3886,7 @@ absl::Status SpmdPartitioningVisitor::HandleGetTupleElement(
|
|||
PartitionedHlo source_partitioned_gte(
|
||||
gte, tuple.base_shape().tuple_shapes(hlo->tuple_index()),
|
||||
MakePartitioningState());
|
||||
source_partitioned_gte = source_partitioned_gte.Reshard(hlo->sharding());
|
||||
SetPartitionedHlo(hlo, std::move(source_partitioned_gte));
|
||||
SetPartitionedHlo(hlo, source_partitioned_gte.Reshard(hlo->sharding()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -3907,19 +3899,15 @@ absl::Status SpmdPartitioningVisitor::HandleInfeed(HloInstruction* hlo) {
|
|||
// elements for non-empty tuple. So if it has a nested empty tuple, we
|
||||
// cannot invoke GetSubSharding() since it expects a sharding for the empty
|
||||
// tuple. This is a workaround for that case.
|
||||
SetPartitionedHlo(hlo, [&]() {
|
||||
return b_.AddInstruction(
|
||||
HloInstruction::CreateInfeed(shape, token, hlo->infeed_config()));
|
||||
});
|
||||
SetPartitionedHlo(hlo, b_.AddInstruction(HloInstruction::CreateInfeed(
|
||||
shape, token, hlo->infeed_config())));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
auto sharding = hlo->sharding().GetSubSharding(hlo->shape(), {0});
|
||||
auto shard_shape = MakePartitionedShape(shape, sharding);
|
||||
if (EvenlyPartitions(shape, sharding)) {
|
||||
SetPartitionedHlo(hlo, [&]() {
|
||||
return b_.AddInstruction(HloInstruction::CreateInfeed(
|
||||
shard_shape, token, hlo->infeed_config()));
|
||||
});
|
||||
SetPartitionedHlo(hlo, b_.AddInstruction(HloInstruction::CreateInfeed(
|
||||
shard_shape, token, hlo->infeed_config())));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -4023,11 +4011,11 @@ absl::Status SpmdPartitioningVisitor::HandleInfeed(HloInstruction* hlo) {
|
|||
}
|
||||
branches[i] = module_->AddEmbeddedComputation(branch_b.Build());
|
||||
}
|
||||
SetPartitionedHlo(hlo, [&]() {
|
||||
return b_.AddInstruction(HloInstruction::CreateConditional(
|
||||
ShapeUtil::MakeTupleShape({shard_shape, token->shape()}), branch_index,
|
||||
branches, std::vector<HloInstruction*>(branches.size(), token)));
|
||||
});
|
||||
SetPartitionedHlo(
|
||||
hlo, b_.AddInstruction(HloInstruction::CreateConditional(
|
||||
ShapeUtil::MakeTupleShape({shard_shape, token->shape()}),
|
||||
branch_index, branches,
|
||||
std::vector<HloInstruction*>(branches.size(), token))));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -4223,10 +4211,9 @@ absl::Status SpmdPartitioningVisitor::HandleReverse(HloInstruction* hlo) {
|
|||
if (!left_padded_operand) {
|
||||
return DefaultAction(hlo);
|
||||
}
|
||||
SetPartitionedHlo(hlo, [&] {
|
||||
return b_.AddInstruction(hlo->CloneWithNewOperands(
|
||||
left_padded_operand->shape(), {left_padded_operand}));
|
||||
});
|
||||
SetPartitionedHlo(hlo,
|
||||
b_.AddInstruction(hlo->CloneWithNewOperands(
|
||||
left_padded_operand->shape(), {left_padded_operand})));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -4237,7 +4224,7 @@ absl::Status SpmdPartitioningVisitor::HandleWhile(HloInstruction* hlo) {
|
|||
hlo->while_body(),
|
||||
GetPartitionedHlo(hlo->operand(0)).Reshard(sharding).hlo()));
|
||||
hlo->SetupDerivedInstruction(whileOp);
|
||||
SetPartitionedHlo(hlo, [&] { return whileOp; });
|
||||
SetPartitionedHlo(hlo, whileOp);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -4282,21 +4269,15 @@ absl::Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) {
|
|||
return HandleSingleDevice(hlo);
|
||||
}
|
||||
if (hlo->sharding().IsManual()) {
|
||||
auto clone_from_original = [&](const HloSharding& shared_sharding) {
|
||||
std::vector<HloInstruction*> new_operands;
|
||||
new_operands.reserve(hlo->operand_count());
|
||||
for (int64_t i = 0; i < hlo->operand_count(); ++i) {
|
||||
new_operands.push_back(
|
||||
GetPartitionedHlo(hlo->operand(i)).Reshard(shared_sharding).hlo());
|
||||
}
|
||||
auto clone = b_.AddInstruction(
|
||||
hlo->CloneWithNewOperands(hlo->shape(), new_operands));
|
||||
clone->set_sharding(shared_sharding);
|
||||
return clone;
|
||||
};
|
||||
|
||||
SetPartitionedHlo(hlo,
|
||||
[&] { return clone_from_original(hlo->sharding()); });
|
||||
std::vector<HloInstruction*> new_operands;
|
||||
new_operands.reserve(hlo->operand_count());
|
||||
for (int64_t i = 0; i < hlo->operand_count(); ++i) {
|
||||
new_operands.push_back(
|
||||
GetPartitionedHlo(hlo->operand(i)).Reshard(hlo->sharding()).hlo());
|
||||
}
|
||||
auto clone = b_.AddInstruction(
|
||||
hlo->CloneWithNewOperands(hlo->shape(), new_operands));
|
||||
SetPartitionedHlo(hlo, clone);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -4331,10 +4312,9 @@ absl::Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) {
|
|||
Shape outfeed_shape = operand->shape();
|
||||
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(hlo->outfeed_shape(),
|
||||
&outfeed_shape));
|
||||
SetPartitionedHlo(hlo, [&]() {
|
||||
return b_.AddInstruction(HloInstruction::CreateOutfeed(
|
||||
outfeed_shape, operand, token, hlo->outfeed_config()));
|
||||
});
|
||||
SetPartitionedHlo(
|
||||
hlo, b_.AddInstruction(HloInstruction::CreateOutfeed(
|
||||
outfeed_shape, operand, token, hlo->outfeed_config())));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -4453,13 +4433,13 @@ absl::Status SpmdPartitioningVisitor::HandleOutfeed(HloInstruction* hlo) {
|
|||
hlo->outfeed_config()));
|
||||
branches[i] = module_->AddEmbeddedComputation(branch_b.Build());
|
||||
}
|
||||
SetPartitionedHlo(hlo, [&]() {
|
||||
return b_.AddInstruction(HloInstruction::CreateConditional(
|
||||
token->shape(), branch_index, branches,
|
||||
std::vector<HloInstruction*>(
|
||||
branches.size(),
|
||||
b_.AddInstruction(HloInstruction::CreateTuple({operand, token})))));
|
||||
});
|
||||
SetPartitionedHlo(
|
||||
hlo,
|
||||
b_.AddInstruction(HloInstruction::CreateConditional(
|
||||
token->shape(), branch_index, branches,
|
||||
std::vector<HloInstruction*>(
|
||||
branches.size(), b_.AddInstruction(HloInstruction::CreateTuple(
|
||||
{operand, token}))))));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -4481,8 +4461,7 @@ absl::Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) {
|
|||
};
|
||||
|
||||
if (hlo->sharding().IsManual()) {
|
||||
SetPartitionedHlo(hlo,
|
||||
[&] { return clone_from_original(hlo->sharding()); });
|
||||
SetPartitionedHlo(hlo, clone_from_original(hlo->sharding()));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -4507,11 +4486,10 @@ absl::Status SpmdPartitioningVisitor::HandleRng(HloInstruction* hlo) {
|
|||
}
|
||||
|
||||
if (!hlo->sharding().ReplicateOnLastTileDim()) {
|
||||
SetPartitionedHlo(hlo, [&] {
|
||||
return b_.AddInstruction(HloInstruction::CreateRng(
|
||||
MakePartitionedShape(hlo->shape(), hlo->sharding()),
|
||||
hlo->random_distribution(), new_operands));
|
||||
});
|
||||
SetPartitionedHlo(hlo,
|
||||
b_.AddInstruction(HloInstruction::CreateRng(
|
||||
MakePartitionedShape(hlo->shape(), hlo->sharding()),
|
||||
hlo->random_distribution(), new_operands)));
|
||||
} else {
|
||||
std::vector<int64_t> group_dims(
|
||||
hlo->sharding().tile_assignment().num_dimensions() - 1);
|
||||
|
|
@ -4830,9 +4808,8 @@ absl::Status SpmdPartitioningVisitor::HandleTuple(HloInstruction* hlo) {
|
|||
.Reshard(hlo->sharding().GetSubSharding(hlo->shape(), {i}))
|
||||
.hlo());
|
||||
}
|
||||
SetPartitionedHlo(hlo, [&]() {
|
||||
return b_.AddInstruction(HloInstruction::CreateTuple(new_operands));
|
||||
});
|
||||
SetPartitionedHlo(
|
||||
hlo, b_.AddInstruction(HloInstruction::CreateTuple(new_operands)));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
@ -4921,7 +4898,7 @@ absl::Status SpmdPartitioningVisitor::HandleRaggedDot(HloInstruction* hlo) {
|
|||
MakeBinaryAdd(phlo->shape().element_type(), lhs.state().module));
|
||||
}
|
||||
|
||||
SetPartitionedHlo(hlo, [&]() { return phlo; });
|
||||
SetPartitionedHlo(hlo, phlo);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -795,17 +795,21 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
|
|||
void SetPartitionedHlo(const HloInstruction* hlo,
|
||||
PartitionedHlo&& partitioned_hlo);
|
||||
|
||||
// Convenient wrapper that creates PartitionedHlo from the result of the func
|
||||
// and maps it to the given original hlo.
|
||||
void SetPartitionedHlo(const HloInstruction* hlo,
|
||||
absl::FunctionRef<HloInstruction*()> func) {
|
||||
HloInstruction* new_hlo = func();
|
||||
// Convenient wrapper that creates PartitionedHlo from `new_hlo`.
|
||||
void SetPartitionedHlo(const HloInstruction* hlo, HloInstruction* new_hlo) {
|
||||
new_hlo->set_sharding(hlo->sharding());
|
||||
SetPartitionedHlo(
|
||||
hlo, PartitionedHlo(new_hlo, hlo->shape(), MakePartitioningState()));
|
||||
changed_ = true;
|
||||
}
|
||||
|
||||
// Convenient wrapper that creates PartitionedHlo from the result of the func
|
||||
// and maps it to the given original hlo.
|
||||
void SetPartitionedHlo(const HloInstruction* hlo,
|
||||
absl::FunctionRef<HloInstruction*()> func) {
|
||||
return SetPartitionedHlo(hlo, func());
|
||||
}
|
||||
|
||||
int64_t NewChannel() { return (*next_channel_id_)++; }
|
||||
|
||||
PartitionedHlo::PartitioningState MakePartitioningState();
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user