mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[ReplicaGroupV3][MeshAxesReplicaGroupList][1/2] Add initial class definition for V3 replica group.
PiperOrigin-RevId: 826334561
This commit is contained in:
parent
d9c76aafeb
commit
cef240807a
1
third_party/xla/xla/array.h
vendored
1
third_party/xla/xla/array.h
vendored
|
|
@ -22,7 +22,6 @@ limitations under the License.
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
#include <initializer_list>
|
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
|
||||||
9
third_party/xla/xla/hlo/ir/BUILD
vendored
9
third_party/xla/xla/hlo/ir/BUILD
vendored
|
|
@ -60,6 +60,7 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":backend_config",
|
":backend_config",
|
||||||
":hlo_sharding",
|
":hlo_sharding",
|
||||||
|
":mesh_and_axis",
|
||||||
":named_sharding",
|
":named_sharding",
|
||||||
":ptrvec",
|
":ptrvec",
|
||||||
":tile_assignment",
|
":tile_assignment",
|
||||||
|
|
@ -123,6 +124,7 @@ cc_library(
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@com_google_absl//absl/strings:str_format",
|
||||||
"@com_google_absl//absl/synchronization",
|
"@com_google_absl//absl/synchronization",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
|
"@com_google_protobuf//:protobuf_lite",
|
||||||
"@highwayhash",
|
"@highwayhash",
|
||||||
"@highwayhash//:arch_specific",
|
"@highwayhash//:arch_specific",
|
||||||
"@highwayhash//:hh_types",
|
"@highwayhash//:hh_types",
|
||||||
|
|
@ -188,6 +190,7 @@ cc_library(
|
||||||
"//xla:xla_data_proto_cc",
|
"//xla:xla_data_proto_cc",
|
||||||
"@com_google_absl//absl/algorithm:container",
|
"@com_google_absl//absl/algorithm:container",
|
||||||
"@com_google_absl//absl/log:check",
|
"@com_google_absl//absl/log:check",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
"@llvm-project//llvm:Support",
|
"@llvm-project//llvm:Support",
|
||||||
],
|
],
|
||||||
|
|
@ -415,9 +418,13 @@ xla_cc_test(
|
||||||
srcs = ["replica_group_test.cc"],
|
srcs = ["replica_group_test.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":hlo",
|
":hlo",
|
||||||
|
":mesh_and_axis",
|
||||||
|
":tile_assignment",
|
||||||
|
"//xla:array",
|
||||||
"//xla:xla_data_proto_cc",
|
"//xla:xla_data_proto_cc",
|
||||||
"//xla/service:hlo_proto_cc",
|
"//xla/service:hlo_proto_cc",
|
||||||
"//xla/tsl/platform:test_main",
|
"//xla/tsl/platform:test_main",
|
||||||
|
"@com_google_absl//absl/base:log_severity",
|
||||||
"@com_google_googletest//:gtest",
|
"@com_google_googletest//:gtest",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
@ -433,8 +440,6 @@ xla_cc_test(
|
||||||
"//xla/tsl/util/proto:proto_matchers",
|
"//xla/tsl/util/proto:proto_matchers",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
"@com_google_googletest//:gtest_main",
|
"@com_google_googletest//:gtest_main",
|
||||||
"@local_tsl//tsl/platform:status_matchers",
|
|
||||||
"@local_tsl//tsl/platform:test",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
36
third_party/xla/xla/hlo/ir/mesh_and_axis.h
vendored
36
third_party/xla/xla/hlo/ir/mesh_and_axis.h
vendored
|
|
@ -23,6 +23,8 @@ limitations under the License.
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/log/check.h"
|
#include "absl/log/check.h"
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
|
#include "absl/strings/str_join.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "xla/hlo/ir/tile_assignment.h"
|
#include "xla/hlo/ir/tile_assignment.h"
|
||||||
#include "xla/xla_data.pb.h"
|
#include "xla/xla_data.pb.h"
|
||||||
|
|
@ -55,6 +57,28 @@ class Mesh {
|
||||||
axes_names_ == other.axes_names_;
|
axes_names_ == other.axes_names_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string ToString() const {
|
||||||
|
std::string mesh_str = "@mesh";
|
||||||
|
// Add the mesh axes names and sizes.
|
||||||
|
std::vector<std::string> formatted_axes_names;
|
||||||
|
formatted_axes_names.reserve(axes_names_.size());
|
||||||
|
for (int64_t i = 0; i < axes_names_.size(); ++i) {
|
||||||
|
formatted_axes_names.push_back(
|
||||||
|
absl::StrCat(axes_names_[i], "=", device_assignment_.dim(i)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the device assignment if it is not an iota case.
|
||||||
|
std::optional<IotaTileAssignment> iota = device_assignment_.iota();
|
||||||
|
std::string device_assignment_str = "";
|
||||||
|
if (!(iota.has_value() && iota->reshape_dims().size() == 1)) {
|
||||||
|
device_assignment_str =
|
||||||
|
absl::StrCat("(", device_assignment_.ArrayToString(), ")");
|
||||||
|
}
|
||||||
|
absl::StrAppend(&mesh_str, "<", absl::StrJoin(formatted_axes_names, ","),
|
||||||
|
">", device_assignment_str);
|
||||||
|
return mesh_str;
|
||||||
|
}
|
||||||
|
|
||||||
bool operator!=(const Mesh& other) const { return !(*this == other); }
|
bool operator!=(const Mesh& other) const { return !(*this == other); }
|
||||||
|
|
||||||
MeshProto ToProto() const;
|
MeshProto ToProto() const;
|
||||||
|
|
@ -62,6 +86,7 @@ class Mesh {
|
||||||
static Mesh FromProto(const MeshProto& proto);
|
static Mesh FromProto(const MeshProto& proto);
|
||||||
|
|
||||||
TileAssignment device_assignment() const { return device_assignment_; }
|
TileAssignment device_assignment() const { return device_assignment_; }
|
||||||
|
// Returns the names of the mesh axes. The vector is returned by value, so
// every call produces a fresh copy of `axes_names_`.
std::vector<std::string> axis_names() const {
  return axes_names_;
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Dimensions of the `device_assignment_` array correspond to the axes of the
|
// Dimensions of the `device_assignment_` array correspond to the axes of the
|
||||||
|
|
@ -113,6 +138,17 @@ class AxisRef {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string ToString(const Mesh& mesh) const {
|
||||||
|
CHECK_GE(mesh_axis_index_, 0);
|
||||||
|
CHECK_LT(mesh_axis_index_, mesh.axis_names().size());
|
||||||
|
std::string axis_str = mesh.axis_names()[mesh_axis_index()];
|
||||||
|
if (sub_axis_info_.has_value()) {
|
||||||
|
absl::StrAppend(&axis_str, ":(", sub_axis_info_->pre_size, ")",
|
||||||
|
sub_axis_info_->size);
|
||||||
|
}
|
||||||
|
return axis_str;
|
||||||
|
}
|
||||||
|
|
||||||
bool operator!=(const xla::AxisRef& other) const { return !(*this == other); }
|
bool operator!=(const xla::AxisRef& other) const { return !(*this == other); }
|
||||||
|
|
||||||
AxisRefProto ToProto() const;
|
AxisRefProto ToProto() const;
|
||||||
|
|
|
||||||
19
third_party/xla/xla/hlo/ir/mesh_and_axis_test.cc
vendored
19
third_party/xla/xla/hlo/ir/mesh_and_axis_test.cc
vendored
|
|
@ -176,4 +176,23 @@ TEST(MeshAndAxisTest, MeshRoundtripProto) {
|
||||||
EXPECT_THAT(mesh_non_iota, Mesh::FromProto(mesh_non_iota.ToProto()));
|
EXPECT_THAT(mesh_non_iota, Mesh::FromProto(mesh_non_iota.ToProto()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Verifies the textual form produced by Mesh::ToString(). The test lives in
// the MeshAndAxisTest suite like the other tests of this file; it does not
// involve MeshAxesReplicaGroupList.
TEST(MeshAndAxisTest, MeshToString) {
  // Plain iota device assignment: only the axes are printed.
  Mesh mesh_uvw(TileAssignment(IotaTileAssignment::Create(
                    /*dims=*/{10, 12, 15})),
                /*axes_names=*/{"u", "v", "w"});
  EXPECT_EQ(mesh_uvw.ToString(), "@mesh<u=10,v=12,w=15>");

  // Iota with reshape and transpose: the assignment is appended in its
  // compact array form.
  Mesh mesh_abcd(TileAssignment(IotaTileAssignment::Create(
                     /*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16},
                     /*transpose_perm=*/{2, 3, 0, 1})),
                 /*axes_names=*/{"a", "b", "c", "d"});
  EXPECT_EQ(mesh_abcd.ToString(), "@mesh<a=2,b=4,c=4,d=2>([4,16]T(1,0))");

  // Explicit (non-iota) device list: every device id is printed.
  Array<int64_t> array({{8, 3, 7, 5, 4, 2, 6, 0, 1, 9}});
  array.Reshape({10});
  TileAssignment tile_assignment(std::make_shared<Array<int64_t>>(array));
  Mesh mesh_ooo(tile_assignment, /*axes_names=*/{"ooo"});
  EXPECT_EQ(mesh_ooo.ToString(), "@mesh<ooo=10>(8,3,7,5,4,2,6,0,1,9)");
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
|
||||||
61
third_party/xla/xla/hlo/ir/replica_group.cc
vendored
61
third_party/xla/xla/hlo/ir/replica_group.cc
vendored
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <optional>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|
@ -26,6 +27,8 @@ limitations under the License.
|
||||||
#include "absl/strings/str_join.h"
|
#include "absl/strings/str_join.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "xla/array.h"
|
#include "xla/array.h"
|
||||||
|
#include "xla/hlo/ir/mesh_and_axis.h"
|
||||||
|
#include "xla/hlo/ir/tile_assignment.h"
|
||||||
#include "xla/printer.h"
|
#include "xla/printer.h"
|
||||||
#include "xla/service/hlo.pb.h"
|
#include "xla/service/hlo.pb.h"
|
||||||
#include "xla/tsl/platform/logging.h" // IWYU pragma: keep
|
#include "xla/tsl/platform/logging.h" // IWYU pragma: keep
|
||||||
|
|
@ -45,6 +48,63 @@ std::string ReplicaGroupsToString(
|
||||||
return absl::StrCat("{", absl::StrJoin(replica_group_str, ","), "}");
|
return absl::StrCat("{", absl::StrJoin(replica_group_str, ","), "}");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/************** MeshAxesReplicaGroupList implementation ***********************/
|
||||||
|
int64_t MeshAxesReplicaGroupList::num_replica_groups() const {
|
||||||
|
return mesh_.device_assignment().num_elements() / num_devices_per_group();
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t MeshAxesReplicaGroupList::num_devices_per_group() const {
|
||||||
|
// Number of devices per replica group is equal to the product of the sizes of
|
||||||
|
// all axes.
|
||||||
|
int64_t devices_per_group = 1;
|
||||||
|
for (const AxisRef& axis : axes_) {
|
||||||
|
int64_t axis_size =
|
||||||
|
axis.sub_axis_info().has_value()
|
||||||
|
? axis.sub_axis_info()->size
|
||||||
|
: mesh_.device_assignment().dim(axis.mesh_axis_index());
|
||||||
|
devices_per_group *= axis_size;
|
||||||
|
}
|
||||||
|
return devices_per_group;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Appends the textual form returned by ToString() to `printer`.
void MeshAxesReplicaGroupList::Print(Printer* printer) const {
  const std::string text = ToString();
  printer->Append(text);
}
|
||||||
|
|
||||||
|
std::string MeshAxesReplicaGroupList::ToString() const {
|
||||||
|
std::string rg_str = "";
|
||||||
|
// Add the axes defining the replica group, using names from the mesh.
|
||||||
|
std::vector<std::string> group_axes_str;
|
||||||
|
group_axes_str.reserve(axes_.size());
|
||||||
|
for (const AxisRef& axis : axes_) {
|
||||||
|
std::string axis_str = axis.ToString(mesh_);
|
||||||
|
group_axes_str.push_back(axis_str);
|
||||||
|
}
|
||||||
|
absl::StrAppend(&rg_str, mesh_.ToString(), " {",
|
||||||
|
absl::StrJoin(group_axes_str, ","), "}");
|
||||||
|
return rg_str;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serializes this object: the mesh goes into the `mesh` field and every axis
// is appended, in order, to the repeated `axes` field.
MeshAxesReplicaGroupListProto MeshAxesReplicaGroupList::ToProto() const {
  MeshAxesReplicaGroupListProto result;
  for (const AxisRef& axis_ref : axes_) {
    *result.add_axes() = axis_ref.ToProto();
  }
  *result.mutable_mesh() = mesh_.ToProto();
  return result;
}
|
||||||
|
|
||||||
|
// Builds a MeshAxesReplicaGroupList from its proto form: the mesh is taken
// from `proto.mesh()` and the axes from `proto.axes()`, preserving order.
MeshAxesReplicaGroupList MeshAxesReplicaGroupList::FromProto(
    const MeshAxesReplicaGroupListProto& proto) {
  std::vector<AxisRef> axes;
  // The number of axes is known up front; allocate once.
  axes.reserve(proto.axes_size());
  for (const AxisRefProto& axis_proto : proto.axes()) {
    axes.push_back(AxisRef::FromProto(axis_proto));
  }
  // The constructor takes both arguments by value, so hand over the mesh
  // temporary and move the axes instead of copying them.
  return MeshAxesReplicaGroupList(Mesh::FromProto(proto.mesh()),
                                  std::move(axes));
}
|
||||||
|
|
||||||
|
/************** IotaReplicaGroupList implementation ***************************/
|
||||||
int64_t IotaReplicaGroupList::num_replica_groups() const {
|
int64_t IotaReplicaGroupList::num_replica_groups() const {
|
||||||
DCHECK_GE(num_replica_groups_, 0);
|
DCHECK_GE(num_replica_groups_, 0);
|
||||||
return num_replica_groups_;
|
return num_replica_groups_;
|
||||||
|
|
@ -121,6 +181,7 @@ std::shared_ptr<std::vector<ReplicaGroup>> ExpandIota(
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
/************** CollectiveDeviceList implementation ***************************/
|
||||||
const std::vector<ReplicaGroup>& CollectiveDeviceList::replica_groups() const {
|
const std::vector<ReplicaGroup>& CollectiveDeviceList::replica_groups() const {
|
||||||
if (replica_groups_ == nullptr) {
|
if (replica_groups_ == nullptr) {
|
||||||
CHECK(iota_replica_group_list_.has_value());
|
CHECK(iota_replica_group_list_.has_value());
|
||||||
|
|
|
||||||
39
third_party/xla/xla/hlo/ir/replica_group.h
vendored
39
third_party/xla/xla/hlo/ir/replica_group.h
vendored
|
|
@ -24,8 +24,11 @@ limitations under the License.
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/log/log.h"
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
|
#include "google/protobuf/repeated_ptr_field.h"
|
||||||
#include "xla/array.h"
|
#include "xla/array.h"
|
||||||
|
#include "xla/hlo/ir/mesh_and_axis.h"
|
||||||
#include "xla/hlo/ir/tile_assignment.h"
|
#include "xla/hlo/ir/tile_assignment.h"
|
||||||
#include "xla/printer.h"
|
#include "xla/printer.h"
|
||||||
#include "xla/service/hlo.pb.h"
|
#include "xla/service/hlo.pb.h"
|
||||||
|
|
@ -34,6 +37,42 @@ limitations under the License.
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
|
class MeshAxesReplicaGroupList {
|
||||||
|
public:
|
||||||
|
explicit MeshAxesReplicaGroupList(Mesh mesh, std::vector<AxisRef> axes)
|
||||||
|
: mesh_(std::move(mesh)), axes_(std::move(axes)) {
|
||||||
|
if (num_devices_per_group() == 1) {
|
||||||
|
LOG(ERROR) << "MeshAxesReplicaGroupList: " << ToString()
|
||||||
|
<< " has only one device per replica group.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator==(const MeshAxesReplicaGroupList& other) const {
|
||||||
|
return mesh_ == other.mesh_ && axes_ == other.axes_;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename H>
|
||||||
|
friend H AbslHashValue(H h, const MeshAxesReplicaGroupList& c) {
|
||||||
|
return H::combine(std::move(h), c.mesh_, c.axes_);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t num_replica_groups() const;
|
||||||
|
int64_t num_devices_per_group() const;
|
||||||
|
|
||||||
|
void Print(Printer* printer) const;
|
||||||
|
|
||||||
|
std::string ToString() const;
|
||||||
|
|
||||||
|
MeshAxesReplicaGroupListProto ToProto() const;
|
||||||
|
|
||||||
|
static MeshAxesReplicaGroupList FromProto(
|
||||||
|
const MeshAxesReplicaGroupListProto& proto);
|
||||||
|
|
||||||
|
private:
|
||||||
|
Mesh mesh_;
|
||||||
|
std::vector<AxisRef> axes_;
|
||||||
|
};
|
||||||
|
|
||||||
std::string ReplicaGroupsToString(
|
std::string ReplicaGroupsToString(
|
||||||
absl::Span<const ReplicaGroup> replica_groups);
|
absl::Span<const ReplicaGroup> replica_groups);
|
||||||
|
|
||||||
|
|
|
||||||
116
third_party/xla/xla/hlo/ir/replica_group_test.cc
vendored
116
third_party/xla/xla/hlo/ir/replica_group_test.cc
vendored
|
|
@ -16,10 +16,16 @@ limitations under the License.
|
||||||
#include "xla/hlo/ir/replica_group.h"
|
#include "xla/hlo/ir/replica_group.h"
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include <gmock/gmock.h>
|
#include <gmock/gmock.h>
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
#include "testing/base/public/mock-log.h"
|
||||||
|
#include "absl/base/log_severity.h"
|
||||||
|
#include "xla/array.h"
|
||||||
|
#include "xla/hlo/ir/mesh_and_axis.h"
|
||||||
|
#include "xla/hlo/ir/tile_assignment.h"
|
||||||
#include "xla/service/hlo.pb.h"
|
#include "xla/service/hlo.pb.h"
|
||||||
#include "xla/xla_data.pb.h"
|
#include "xla/xla_data.pb.h"
|
||||||
|
|
||||||
|
|
@ -37,6 +43,116 @@ CollectiveDeviceListProto CreateDeviceListProto(
|
||||||
return proto;
|
return proto;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks group count and group size when grouping over all, one, or none of
// the axes of a mesh (full axes only, no sub-axes).
TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSize) {
  // Grouping over every axis of a 4x4 mesh: one group holding all devices.
  Mesh mesh_xy(TileAssignment(IotaTileAssignment::Create(/*dims=*/{4, 4})),
               /*axes_names=*/{"x", "y"});
  MeshAxesReplicaGroupList groups_over_xy(mesh_xy,
                                          /*axes=*/{AxisRef(0), AxisRef(1)});
  EXPECT_EQ(groups_over_xy.num_replica_groups(), 1);
  EXPECT_EQ(groups_over_xy.num_devices_per_group(), 16);

  // Grouping over a single axis of a 3x5 mesh.
  Mesh mesh_ab(TileAssignment(IotaTileAssignment::Create(/*dims=*/{3, 5})),
               /*axes_names=*/{"a", "b"});
  MeshAxesReplicaGroupList groups_over_a(mesh_ab, /*axes=*/{AxisRef(0)});
  EXPECT_EQ(groups_over_a.num_replica_groups(), 5);
  EXPECT_EQ(groups_over_a.num_devices_per_group(), 3);
  MeshAxesReplicaGroupList groups_over_b(mesh_ab, /*axes=*/{AxisRef(1)});
  EXPECT_EQ(groups_over_b.num_replica_groups(), 3);
  EXPECT_EQ(groups_over_b.num_devices_per_group(), 5);

  // Grouping over no axis: every device is its own group and the constructor
  // is expected to log exactly one error. Log capturing must start before the
  // object is constructed.
  Mesh mesh_p(TileAssignment(IotaTileAssignment::Create(/*dims=*/{2, 3, 5})),
              /*axes_names=*/{"p1", "p2", "p3"});
  testing::ScopedMockLog mock_log(testing::kDoNotCaptureLogsYet);
  EXPECT_CALL(mock_log,
              Log(base_logging::ERROR, testing::_,
                  testing::HasSubstr("has only one device per replica group")))
      .Times(1);
  mock_log.StartCapturingLogs();
  MeshAxesReplicaGroupList groups_over_nothing(mesh_p, /*axes=*/{});
  EXPECT_EQ(groups_over_nothing.num_replica_groups(), 2 * 3 * 5);
  EXPECT_EQ(groups_over_nothing.num_devices_per_group(), 1);
}
|
||||||
|
|
||||||
|
// Checks group count and group size when the groups are defined by sub-axes.
// A sub-axis contributes its own size, not the full axis size, to the number
// of devices per group.
TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSizeForSubaxes) {
  // One sub-axis at a time on a 2x6x10 mesh (120 devices).
  Mesh mesh_2_6_10(
      TileAssignment(IotaTileAssignment::Create(/*dims=*/{2, 6, 10})),
      /*axes_names=*/{"axis1", "axis2", "axis3"});
  MeshAxesReplicaGroupList groups_axis1_sub(mesh_2_6_10,
                                            /*axes=*/{AxisRef(0, {1, 2})});
  MeshAxesReplicaGroupList groups_axis2_sub(mesh_2_6_10,
                                            /*axes=*/{AxisRef(1, {2, 3})});
  EXPECT_EQ(groups_axis1_sub.num_replica_groups(), 60);
  EXPECT_EQ(groups_axis1_sub.num_devices_per_group(), 2);
  EXPECT_EQ(groups_axis2_sub.num_replica_groups(), 40);
  EXPECT_EQ(groups_axis2_sub.num_devices_per_group(), 3);

  // Several sub-axes at once; every mesh dimension is a product of two
  // primes so that each axis can be split into two sub-axes.
  Mesh mesh_primes(TileAssignment(IotaTileAssignment::Create(
                       /*dims=*/{2 * 3, 5 * 7, 11 * 13})),
                   /*axes_names=*/{"alpha", "beta", "gamma"});
  MeshAxesReplicaGroupList groups_first_factors(
      mesh_primes,
      /*axes=*/{AxisRef(0, {1, 2}), AxisRef(1, {1, 5}), AxisRef(2, {1, 11})});
  MeshAxesReplicaGroupList groups_second_factors(
      mesh_primes,
      /*axes=*/{AxisRef(0, {2, 3}), AxisRef(1, {5, 7}), AxisRef(2, {11, 13})});
  EXPECT_EQ(groups_first_factors.num_replica_groups(), 3 * 7 * 13);
  EXPECT_EQ(groups_first_factors.num_devices_per_group(), 2 * 5 * 11);
  EXPECT_EQ(groups_second_factors.num_replica_groups(), 2 * 5 * 11);
  EXPECT_EQ(groups_second_factors.num_devices_per_group(), 3 * 7 * 13);
}
|
||||||
|
|
||||||
|
// Checks MeshAxesReplicaGroupList::ToString() for the three kinds of device
// assignment (plain iota, reshaped/transposed iota, explicit list), with and
// without sub-axes.
TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) {
  // Plain iota device assignment, full axes only.
  Mesh iota_mesh(
      TileAssignment(IotaTileAssignment::Create(/*dims=*/{10, 12, 15})),
      /*axes_names=*/{"u", "v", "w"});
  MeshAxesReplicaGroupList iota_no_axes(iota_mesh, /*axes=*/{});
  EXPECT_EQ(iota_no_axes.ToString(), "@mesh<u=10,v=12,w=15> {}");
  MeshAxesReplicaGroupList iota_uv(iota_mesh,
                                   /*axes=*/{AxisRef(0), AxisRef(1)});
  EXPECT_EQ(iota_uv.ToString(), "@mesh<u=10,v=12,w=15> {u,v}");

  // Replica group v2 style iota (reshape + transpose), with a sub-axis.
  Mesh transposed_mesh(
      TileAssignment(IotaTileAssignment::Create(
          /*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16},
          /*transpose_perm=*/{2, 3, 0, 1})),
      /*axes_names=*/{"a", "b", "c", "d"});
  MeshAxesReplicaGroupList transposed_no_axes(transposed_mesh, /*axes=*/{});
  EXPECT_EQ(transposed_no_axes.ToString(),
            "@mesh<a=2,b=4,c=4,d=2>([4,16]T(1,0)) {}");
  MeshAxesReplicaGroupList transposed_mixed_axes(
      transposed_mesh,
      /*axes=*/{AxisRef(0), AxisRef(1, {1, 2}), AxisRef(3)});
  EXPECT_EQ(transposed_mixed_axes.ToString(),
            "@mesh<a=2,b=4,c=4,d=2>([4,16]T(1,0)) {a,b:(1)2,d}");

  // Explicit, unordered device list, with a sub-axis.
  Array<int64_t> device_ids({{8, 3, 7, 5, 4, 2, 6, 0, 1, 9}});
  device_ids.Reshape({10});
  TileAssignment explicit_assignment(
      std::make_shared<Array<int64_t>>(device_ids));
  Mesh explicit_mesh(explicit_assignment, /*axes_names=*/{"ooo"});
  MeshAxesReplicaGroupList explicit_no_axes(explicit_mesh, /*axes=*/{});
  EXPECT_EQ(explicit_no_axes.ToString(),
            "@mesh<ooo=10>(8,3,7,5,4,2,6,0,1,9) {}");
  MeshAxesReplicaGroupList explicit_sub_axis(explicit_mesh,
                                             /*axes=*/{AxisRef(0, {5, 2})});
  EXPECT_EQ(explicit_sub_axis.ToString(),
            "@mesh<ooo=10>(8,3,7,5,4,2,6,0,1,9) {ooo:(5)2}");
}
|
||||||
|
|
||||||
TEST(CollectiveDeviceListTest, DefaultListToString) {
|
TEST(CollectiveDeviceListTest, DefaultListToString) {
|
||||||
EXPECT_EQ(CollectiveDeviceList().ToString(true), "{}");
|
EXPECT_EQ(CollectiveDeviceList().ToString(true), "{}");
|
||||||
EXPECT_EQ(CollectiveDeviceList().ToString(false), "{}");
|
EXPECT_EQ(CollectiveDeviceList().ToString(false), "{}");
|
||||||
|
|
|
||||||
13
third_party/xla/xla/xla_data.proto
vendored
13
third_party/xla/xla/xla_data.proto
vendored
|
|
@ -1192,6 +1192,19 @@ message ReplicaGroup {
|
||||||
repeated int64 replica_ids = 1;
|
repeated int64 replica_ids = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Represents a list of replica groups (a list of lists of devices) via a mesh
|
||||||
|
// and list of axes. The replica groups correspond to the partitions of the
|
||||||
|
// device ids which would arise if a collective operation was performed over the
|
||||||
|
// specified axes.
|
||||||
|
message MeshAxesReplicaGroupListProto {
|
||||||
|
// The mesh used to define the full set of axes and device ids.
|
||||||
|
MeshProto mesh = 1;
|
||||||
|
// The axes defining the replica groups. These groups are implicitly defined
|
||||||
|
// by the device ids which would communicate together if a collective
|
||||||
|
// operation is performed over these axes.
|
||||||
|
repeated AxisRefProto axes = 2;
|
||||||
|
}
|
||||||
|
|
||||||
// Represents a list of replica groups (a list of list of devices) with
|
// Represents a list of replica groups (a list of list of devices) with
|
||||||
// reshaping and transposing an iota array (iota tile assignment). Can be used
|
// reshaping and transposing an iota array (iota tile assignment). Can be used
|
||||||
// to represent certain common patterns of device lists in a compact, scalable
|
// to represent certain common patterns of device lists in a compact, scalable
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user