[XLA:SPMD] Use subgroup AllToAll for resharding
Resharding from tile [2,2,1] to [1,2,2] can be done by a subgroup all-to-all between dimensions 0 and 2.

PiperOrigin-RevId: 320720720
Change-Id: I1b63ba731b830610596c77697c5577fa9e2e0f79
This commit is contained in:
parent
63e31d9508
commit
b597319553
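
For intuition, here is a minimal, self-contained sketch (not from the commit; all names are illustrative) tracing the per-device shapes for the commit-message example: resharding an assumed f32[8,8,8] from tile [2,2,1] to [1,2,2] via a subgroup all-to-all between dimensions 0 and 2.

#include <cstdint>
#include <iostream>
#include <vector>

void Print(const char* label, const std::vector<int64_t>& d) {
  std::cout << label;
  for (int64_t x : d) std::cout << " " << x;
  std::cout << "\n";
}

int main() {
  const int64_t source_dim = 0;  // tiled 2-ways in source, 1-way in target
  const int64_t target_dim = 2;  // tiled 1-way in source, 2-ways in target
  const int64_t group_size = 2;  // source tile size along source_dim

  std::vector<int64_t> local = {4, 4, 8};  // local shape under tile [2,2,1]
  Print("source local:", local);

  // 1) Reshape: split target_dim into [group_size, local / group_size].
  std::vector<int64_t> split;
  for (int64_t i = 0; i < (int64_t)local.size(); ++i) {
    if (i == target_dim) {
      split.push_back(group_size);
      split.push_back(local[i] / group_size);
    } else {
      split.push_back(local[i]);
    }
  }
  Print("after reshape:", split);  // 4 4 2 4

  // 2) Subgroup all-to-all over the split dim: shape is unchanged, but each
  //    device now holds a slice of every group peer's data along target_dim.

  // 3) Transpose: move the split dim directly in front of source_dim.
  std::vector<int64_t> perm = {2, 0, 1, 3};
  std::vector<int64_t> transposed;
  for (int64_t p : perm) transposed.push_back(split[p]);
  Print("after transpose:", transposed);  // 2 4 4 4

  // 4) Reshape: merge the split dim into source_dim -> 8 4 4, which is the
  //    local shape under the target tile [1,2,2].
  std::vector<int64_t> merged = {transposed[0] * transposed[1], transposed[2],
                                 transposed[3]};
  Print("target local:", merged);  // 8 4 4
  return 0;
}
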
@@ -176,16 +176,45 @@ std::vector<ReplicaGroup> CreateReplicaGroups(int64 num_replicas) {
   return groups;
 }
 
-bool CanReshardWithAllToAll(const HloSharding& source,
-                            const HloSharding& target) {
-  return UniqueTiledDim(source) && UniqueTiledDim(target) &&
-         UniqueTiledDim(source) != UniqueTiledDim(target);
+absl::optional<std::pair<int64, int64>> GetReshardAllToAllSourceTargetDims(
+    const HloSharding& source, const HloSharding& target) {
+  if (source.IsTileMaximal() || target.IsTileMaximal() ||
+      source.tile_assignment().num_dimensions() !=
+          target.tile_assignment().num_dimensions()) {
+    return absl::nullopt;
+  }
+  int64 source_dim = -1;
+  int64 target_dim = -1;
+  for (int64 i = 0; i < source.tile_assignment().num_dimensions(); ++i) {
+    if (source.tile_assignment().dim(i) > 1 &&
+        target.tile_assignment().dim(i) == 1) {
+      if (source_dim != -1) {
+        return absl::nullopt;
+      }
+      source_dim = i;
+    } else if (source.tile_assignment().dim(i) == 1 &&
+               target.tile_assignment().dim(i) > 1) {
+      if (target_dim != -1) {
+        return absl::nullopt;
+      }
+      target_dim = i;
+    } else if (source.tile_assignment().dim(i) !=
+               target.tile_assignment().dim(i)) {
+      return absl::nullopt;
+    }
+  }
+  if (source_dim == -1 || target_dim == -1 || source_dim == target_dim) {
+    return absl::nullopt;
+  }
+  return std::pair(source_dim, target_dim);
 }
 
 bool CanReshardWithCollectivePermute(const HloSharding& source,
                                      const HloSharding& target) {
-  return UniqueTiledDim(source) && UniqueTiledDim(target) &&
-         UniqueTiledDim(source) == UniqueTiledDim(target) && source != target;
+  return !source.IsTileMaximal() && !target.IsTileMaximal() &&
+         source.tile_assignment().dimensions() ==
+             target.tile_assignment().dimensions() &&
+         source.tile_assignment() != target.tile_assignment();
 }
 
 // Clears all sharding attributes from instructions in the module. This must be
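
The matching rule above can be exercised in isolation. Below is a minimal standalone sketch, assuming plain std::vector<int64_t> tile shapes in place of the real HloSharding API (MatchAllToAllDims and main are illustrative names, not from the commit): the two shardings must agree on every dimension except exactly one that is tiled only in the source and one that is tiled only in the target.

#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

// Returns {source_dim, target_dim} if resharding can be done with a single
// subgroup all-to-all, mirroring the logic of the new function above.
std::optional<std::pair<int64_t, int64_t>> MatchAllToAllDims(
    const std::vector<int64_t>& source, const std::vector<int64_t>& target) {
  if (source.size() != target.size()) return std::nullopt;
  int64_t source_dim = -1, target_dim = -1;
  for (int64_t i = 0; i < (int64_t)source.size(); ++i) {
    if (source[i] > 1 && target[i] == 1) {
      if (source_dim != -1) return std::nullopt;  // more than one candidate
      source_dim = i;
    } else if (source[i] == 1 && target[i] > 1) {
      if (target_dim != -1) return std::nullopt;
      target_dim = i;
    } else if (source[i] != target[i]) {
      return std::nullopt;  // any other mismatch disqualifies
    }
  }
  if (source_dim == -1 || target_dim == -1 || source_dim == target_dim) {
    return std::nullopt;
  }
  return std::make_pair(source_dim, target_dim);
}

int main() {
  // The commit-message example: [2,2,1] -> [1,2,2] matches dims 0 and 2.
  auto dims = MatchAllToAllDims({2, 2, 1}, {1, 2, 2});
  return dims && dims->first == 0 && dims->second == 2 ? 0 : 1;
}
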
@@ -278,8 +307,10 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
     return ReshardWithCollectivePermute(target);
   }
 
-  if (CanReshardWithAllToAll(sharding(), target)) {
-    return ReshardWithAllToAll(target);
+  if (auto src_tgt_dims =
+          GetReshardAllToAllSourceTargetDims(sharding(), target)) {
+    return ReshardWithAllToAll(target, src_tgt_dims->first,
+                               src_tgt_dims->second);
   }
 
   // If not replicated yet, first replicate and then reshard to use one of the
@@ -745,45 +776,53 @@ PartitionedHlo PartitionedHlo::Broadcast() const {
   return PartitionedHlo(result, base_shape_, state_);
 }
 
-PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
-    const HloSharding& target) const {
-  int64 partition_count = sharding().tile_assignment().num_elements();
-  absl::optional<int64> input_partition_dim = UniqueTiledDim(sharding());
-  absl::optional<int64> output_partition_dim = UniqueTiledDim(target);
-  CHECK(input_partition_dim.has_value());
-  CHECK(output_partition_dim.has_value());
+PartitionedHlo PartitionedHlo::ReshardWithAllToAll(const HloSharding& target,
+                                                   int64 source_dim,
+                                                   int64 target_dim) const {
+  const int64 group_size = sharding().tile_assignment().dim(source_dim);
 
   // If the device order is different in the target, fix the order with
   // ReshardWithCollectivePermute.
-  auto input_tile_fixed_device_order = target.tile_assignment();
-  input_tile_fixed_device_order.Reshape(
-      sharding().tile_assignment().dimensions());
+  std::vector<int64> xpose_dims(target.tile_assignment().num_dimensions());
+  std::iota(xpose_dims.begin(), xpose_dims.end(), 0);
+  xpose_dims[source_dim] = target_dim;
+  xpose_dims[target_dim] = source_dim;
   auto input_sharding_fixed_device_order =
-      HloSharding::Tile(input_tile_fixed_device_order);
+      hlo_sharding_util::TransposeSharding(target, xpose_dims);
   if (input_sharding_fixed_device_order != sharding()) {
     auto fixed_order =
         ReshardWithCollectivePermute(input_sharding_fixed_device_order);
-    return fixed_order.ReshardWithAllToAll(target);
+    return fixed_order.ReshardWithAllToAll(target, source_dim, target_dim);
   }
 
   auto padded_hlo =
       PadBaseShapeBeforeUnevenTiledSharding(hlo_, target, state_.b);
 
   // The order of ids in the group must follow the target sharding.
-  std::vector<ReplicaGroup> groups(1);
-  for (int64 device : target.tile_assignment()) {
-    groups[0].add_replica_ids(device);
-  }
+  std::vector<ReplicaGroup> groups(target.tile_assignment().num_elements() /
+                                   group_size);
+  target.tile_assignment().Each(
+      [&](absl::Span<const int64> indices, int64 device) {
+        int64 group_id = 0;
+        for (int64 dim = 0; dim < indices.size(); ++dim) {
+          if (dim == target_dim) {
+            continue;
+          }
+          group_id *= target.tile_assignment().dim(dim);
+          group_id += indices[dim];
+        }
+        groups[group_id].add_replica_ids(device);
+      });
 
   HloInstruction* result = nullptr;
 
-  // Split along the split dimension (output_partition_dim) of the all-to-all
+  // Split along the split dimension (target_dim) of the all-to-all
   // output.
   std::vector<int64> dimensions;
   for (int64 i = 0; i < base_shape_.rank(); ++i) {
-    if (i == *output_partition_dim) {
-      dimensions.push_back(partition_count);
-      dimensions.push_back(padded_hlo->shape().dimensions(i) / partition_count);
+    if (i == target_dim) {
+      dimensions.push_back(group_size);
+      dimensions.push_back(padded_hlo->shape().dimensions(i) / group_size);
     } else {
       dimensions.push_back(padded_hlo->shape().dimensions(i));
     }
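
To see how the replica groups come out, here is a small standalone sketch (illustrative only; it decodes indices in the row-major order that Array::Each uses) of the group-id computation from the lambda above, applied to the target tile [1,2,2,2] and device order used in the new test. Indices are flattened over every dimension except target_dim, so devices that differ only along target_dim land in the same group.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Target tile assignment [1,2,2,2] with the test's device order
  // {devices=[1,2,2,2]0,1,4,5,2,3,6,7}, and target_dim = 2.
  const std::vector<int64_t> tile = {1, 2, 2, 2};
  const std::vector<int64_t> devices = {0, 1, 4, 5, 2, 3, 6, 7};
  const int64_t target_dim = 2;

  // 8 devices / group_size 2 = 4 groups.
  std::vector<std::vector<int64_t>> groups(devices.size() / tile[target_dim]);
  std::vector<int64_t> idx(tile.size(), 0);
  for (int64_t flat = 0; flat < (int64_t)devices.size(); ++flat) {
    // Decode this position's multi-dim index (row-major, last dim fastest).
    int64_t rem = flat;
    for (int64_t d = (int64_t)tile.size() - 1; d >= 0; --d) {
      idx[d] = rem % tile[d];
      rem /= tile[d];
    }
    // Flatten all dimensions except target_dim into a group id.
    int64_t group_id = 0;
    for (int64_t d = 0; d < (int64_t)tile.size(); ++d) {
      if (d == target_dim) continue;
      group_id = group_id * tile[d] + idx[d];
    }
    groups[group_id].push_back(devices[flat]);
  }
  // Prints the groups {0,4}, {1,5}, {2,6}, {3,7}: each pairs the two devices
  // that hold the same slice in all dimensions other than target_dim.
  for (const auto& g : groups) {
    for (int64_t d : g) std::cout << d << " ";
    std::cout << "\n";
  }
  return 0;
}
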
@@ -794,21 +833,19 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
   // After the reshape, it is guaranteed to have at least 3 dimensions.
   auto all_to_all =
       state_.collective_ops_creator.create_cross_partition_all_to_all(
-          state_.b, {reshape}, groups, (*state_.next_channel_id)++,
-          output_partition_dim);
+          state_.b, {reshape}, groups, (*state_.next_channel_id)++, target_dim);
 
   // Reorder the split dimension of the reshape to be located in front of the
   // input partition dimension, so the two dimensions can be combined.
-  int64 new_input_partition_dim = (*output_partition_dim < *input_partition_dim)
-                                      ? *input_partition_dim + 1
-                                      : *input_partition_dim;
+  int64 new_source_dim =
+      (target_dim < source_dim) ? source_dim + 1 : source_dim;
   std::vector<int64> permutation;
   for (int64 i = 0; i < all_to_all->shape().rank(); ++i) {
-    if (i == *output_partition_dim) {
+    if (i == target_dim) {
      continue;
    }
-    if (i == new_input_partition_dim) {
-      permutation.push_back(*output_partition_dim);
+    if (i == new_source_dim) {
+      permutation.push_back(target_dim);
    }
    permutation.push_back(i);
  }
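
As a sanity check on the transpose step, a tiny sketch (illustrative, not from the commit) builds the same permutation for the shapes in the new test: rank 5 after the reshape, source_dim = 0, target_dim = 2, so the split dimension is moved directly in front of the (possibly shifted) source dimension.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t rank = 5;        // rank after the reshape step
  const int64_t source_dim = 0;  // dim tiled in the source sharding
  const int64_t target_dim = 2;  // dim holding the all-to-all split

  // If the split dim is inserted before source_dim, source_dim shifts by 1.
  const int64_t new_source_dim =
      (target_dim < source_dim) ? source_dim + 1 : source_dim;

  std::vector<int64_t> permutation;
  for (int64_t i = 0; i < rank; ++i) {
    if (i == target_dim) continue;  // placed next to source_dim instead
    if (i == new_source_dim) permutation.push_back(target_dim);
    permutation.push_back(i);
  }
  // Transposing f32[4,4,2,4,4] by {2,0,1,3,4} yields f32[2,4,4,4,4],
  // matching the shapes asserted in the SubgroupAllToAllReshard test below.
  const std::vector<int64_t> expected = {2, 0, 1, 3, 4};
  assert(permutation == expected);
  return 0;
}
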
@@ -819,8 +856,7 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
 
   // Combine the split dimension and the input partition dimension.
   auto new_shape = ShapeInference::InferAllToAllShape(
-                       padded_hlo->shape(), *output_partition_dim,
-                       *input_partition_dim, partition_count)
+                       padded_hlo->shape(), target_dim, source_dim, group_size)
                        .ValueOrDie();
   result = state_.b->AddInstruction(
       HloInstruction::CreateReshape(new_shape, transpose));
@@ -837,7 +873,8 @@ PartitionedHlo PartitionedHlo::ReshardWithAllToAll(
 
 PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute(
     const HloSharding& target) const {
-  CHECK(CanReshardWithCollectivePermute(sharding(), target));
+  CHECK(CanReshardWithCollectivePermute(sharding(), target))
+      << sharding().ToString() << " to " << target.ToString();
   std::vector<std::pair<int64, int64>> src_dst_pairs;
   sharding().tile_assignment().Each(
       [&](absl::Span<const int64> indices, int64 src_device) {
@@ -3653,8 +3690,8 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
         output_batch_partitions == num_partitions_ &&
         lhs_sharding_transposed_to_match_output == hlo->sharding()) {
       if (!may_reshard_with_allreduce &&
-          !CanReshardWithAllToAll(rhs.sharding(),
-                                  *lhs_sharding_transposed_to_match_rhs)) {
+          !GetReshardAllToAllSourceTargetDims(
+              rhs.sharding(), *lhs_sharding_transposed_to_match_rhs)) {
        return false;
      }
      auto resharded_rhs = rhs.Reshard(*lhs_sharding_transposed_to_match_rhs);
@@ -3668,8 +3705,8 @@ Status SpmdPartitioningVisitor::HandleDotHelper(
         output_batch_partitions == num_partitions_ &&
         rhs_sharding_transposed_to_match_output == hlo->sharding()) {
       if (!may_reshard_with_allreduce &&
-          !CanReshardWithAllToAll(lhs.sharding(),
-                                  *rhs_sharding_transposed_to_match_lhs)) {
+          !GetReshardAllToAllSourceTargetDims(
+              lhs.sharding(), *rhs_sharding_transposed_to_match_lhs)) {
        return false;
      }
      auto resharded_lhs = lhs.Reshard(*rhs_sharding_transposed_to_match_lhs);
@@ -284,7 +284,8 @@ class PartitionedHlo {
 
   // Helper function to reshard the tensor using AllToAll (instead of the
   // default of Replicate followed by Slice).
-  PartitionedHlo ReshardWithAllToAll(const HloSharding& target) const;
+  PartitionedHlo ReshardWithAllToAll(const HloSharding& target,
+                                     int64 source_dim, int64 target_dim) const;
 
   // Helper function to reshard the tensor using CollectivePermute.
   PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const;
@@ -3766,6 +3766,32 @@ ENTRY entry {
                             op::Parameter(0))));
 }
 
+TEST_F(SpmdPartitioningTest, SubgroupAllToAllReshard) {
+  const char* const hlo_string = R"(
+HloModule module
+
+ENTRY entry {
+  %param0 = f32[8,8,8,8] parameter(0),
+    sharding={devices=[2,2,1,2]0,1,2,3,4,5,6,7}
+  ROOT %copy = f32[8,8,8,8] copy(%param0),
+    sharding={devices=[1,2,2,2]0,1,4,5,2,3,6,7}
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          PartitionComputation(hlo_string, /*num_devices=*/8));
+  VLOG(1) << module->ToString();
+
+  auto root = module->entry_computation()->root_instruction();
+  auto reshape =
+      AllOf(op::Shape("f32[4,4,2,4,4]"), op::Reshape(op::Parameter(0)));
+  auto all_to_all = AllOf(op::Shape("f32[4,4,2,4,4]"), op::AllToAll(reshape));
+  auto xpose = AllOf(op::Shape("f32[2,4,4,4,4]"), op::Transpose(all_to_all));
+  EXPECT_THAT(root,
+              op::Copy(AllOf(op::Reshape(xpose), op::Shape("f32[8,4,4,4]"))));
+  EXPECT_EQ(root->operand(0)->operand(0)->operand(0)->replica_groups().size(),
+            4);
+}
+
 }  // namespace
 }  // namespace spmd
 }  // namespace xla
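
For reference, the test's numbers follow directly from the new code paths: resharding [2,2,1,2] to [1,2,2,2] gives source_dim = 0 (tiled 2-ways only in the source), target_dim = 2 (tiled 2-ways only in the target), and group_size = 2, so the 8 devices form 8 / 2 = 4 all-to-all groups, and the local shape moves [4,4,8,4] -> reshape [4,4,2,4,4] -> all-to-all [4,4,2,4,4] -> transpose [2,4,4,4,4] -> reshape [8,4,4,4], exactly the shapes the matchers assert.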