diff --git a/third_party/xla/xla/array.h b/third_party/xla/xla/array.h index 7d1274de103..17d4ce9a2f2 100644 --- a/third_party/xla/xla/array.h +++ b/third_party/xla/xla/array.h @@ -22,7 +22,6 @@ limitations under the License. #include #include #include -#include #include #include #include diff --git a/third_party/xla/xla/hlo/ir/BUILD b/third_party/xla/xla/hlo/ir/BUILD index a72a14aab8c..c8052df4f09 100644 --- a/third_party/xla/xla/hlo/ir/BUILD +++ b/third_party/xla/xla/hlo/ir/BUILD @@ -60,6 +60,7 @@ cc_library( deps = [ ":backend_config", ":hlo_sharding", + ":mesh_and_axis", ":named_sharding", ":ptrvec", ":tile_assignment", @@ -123,6 +124,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf_lite", "@highwayhash", "@highwayhash//:arch_specific", "@highwayhash//:hh_types", @@ -188,6 +190,7 @@ cc_library( "//xla:xla_data_proto_cc", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", ], @@ -415,9 +418,13 @@ xla_cc_test( srcs = ["replica_group_test.cc"], deps = [ ":hlo", + ":mesh_and_axis", + ":tile_assignment", + "//xla:array", "//xla:xla_data_proto_cc", "//xla/service:hlo_proto_cc", "//xla/tsl/platform:test_main", + "@com_google_absl//absl/base:log_severity", "@com_google_googletest//:gtest", ], ) @@ -433,8 +440,6 @@ xla_cc_test( "//xla/tsl/util/proto:proto_matchers", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", - "@local_tsl//tsl/platform:status_matchers", - "@local_tsl//tsl/platform:test", ], ) diff --git a/third_party/xla/xla/hlo/ir/mesh_and_axis.h b/third_party/xla/xla/hlo/ir/mesh_and_axis.h index f60f6fd0e43..feb585b9f54 100644 --- a/third_party/xla/xla/hlo/ir/mesh_and_axis.h +++ b/third_party/xla/xla/hlo/ir/mesh_and_axis.h @@ -23,6 +23,8 @@ limitations under the License. #include #include "absl/log/check.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" #include "absl/types/span.h" #include "xla/hlo/ir/tile_assignment.h" #include "xla/xla_data.pb.h" @@ -55,6 +57,28 @@ class Mesh { axes_names_ == other.axes_names_; } + std::string ToString() const { + std::string mesh_str = "@mesh"; + // Add the mesh axes names and sizes. + std::vector formatted_axes_names; + formatted_axes_names.reserve(axes_names_.size()); + for (int64_t i = 0; i < axes_names_.size(); ++i) { + formatted_axes_names.push_back( + absl::StrCat(axes_names_[i], "=", device_assignment_.dim(i))); + } + + // Add the device assignment if it is not an iota case. + std::optional iota = device_assignment_.iota(); + std::string device_assignment_str = ""; + if (!(iota.has_value() && iota->reshape_dims().size() == 1)) { + device_assignment_str = + absl::StrCat("(", device_assignment_.ArrayToString(), ")"); + } + absl::StrAppend(&mesh_str, "<", absl::StrJoin(formatted_axes_names, ","), + ">", device_assignment_str); + return mesh_str; + } + bool operator!=(const Mesh& other) const { return !(*this == other); } MeshProto ToProto() const; @@ -62,6 +86,7 @@ class Mesh { static Mesh FromProto(const MeshProto& proto); TileAssignment device_assignment() const { return device_assignment_; } + std::vector axis_names() const { return axes_names_; } private: // Dimensions of the `device_assignment_` array correspond to the axes of the @@ -113,6 +138,17 @@ class AxisRef { return true; } + std::string ToString(const Mesh& mesh) const { + CHECK_GE(mesh_axis_index_, 0); + CHECK_LT(mesh_axis_index_, mesh.axis_names().size()); + std::string axis_str = mesh.axis_names()[mesh_axis_index()]; + if (sub_axis_info_.has_value()) { + absl::StrAppend(&axis_str, ":(", sub_axis_info_->pre_size, ")", + sub_axis_info_->size); + } + return axis_str; + } + bool operator!=(const xla::AxisRef& other) const { return !(*this == other); } AxisRefProto ToProto() const; diff --git a/third_party/xla/xla/hlo/ir/mesh_and_axis_test.cc b/third_party/xla/xla/hlo/ir/mesh_and_axis_test.cc index 57f8d26b941..cb193f083a7 100644 --- a/third_party/xla/xla/hlo/ir/mesh_and_axis_test.cc +++ b/third_party/xla/xla/hlo/ir/mesh_and_axis_test.cc @@ -176,4 +176,23 @@ TEST(MeshAndAxisTest, MeshRoundtripProto) { EXPECT_THAT(mesh_non_iota, Mesh::FromProto(mesh_non_iota.ToProto())); } +TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) { + Mesh mesh_uvw(TileAssignment(IotaTileAssignment::Create( + /*dims=*/{10, 12, 15})), + /*axes_names=*/{"u", "v", "w"}); + EXPECT_EQ(mesh_uvw.ToString(), "@mesh"); + + Mesh mesh_abcd(TileAssignment(IotaTileAssignment::Create( + /*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16}, + /*transpose_perm=*/{2, 3, 0, 1})), + /*axes_names=*/{"a", "b", "c", "d"}); + EXPECT_EQ(mesh_abcd.ToString(), "@mesh([4,16]T(1,0))"); + + Array array({{8, 3, 7, 5, 4, 2, 6, 0, 1, 9}}); + array.Reshape({10}); + TileAssignment tile_assignment(std::make_shared>(array)); + Mesh mesh_ooo(tile_assignment, /*axes_names=*/{"ooo"}); + EXPECT_EQ(mesh_ooo.ToString(), "@mesh(8,3,7,5,4,2,6,0,1,9)"); +} + } // namespace xla diff --git a/third_party/xla/xla/hlo/ir/replica_group.cc b/third_party/xla/xla/hlo/ir/replica_group.cc index f693c50db3b..eba41f94d30 100644 --- a/third_party/xla/xla/hlo/ir/replica_group.cc +++ b/third_party/xla/xla/hlo/ir/replica_group.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include +#include #include #include @@ -26,6 +27,8 @@ limitations under the License. #include "absl/strings/str_join.h" #include "absl/types/span.h" #include "xla/array.h" +#include "xla/hlo/ir/mesh_and_axis.h" +#include "xla/hlo/ir/tile_assignment.h" #include "xla/printer.h" #include "xla/service/hlo.pb.h" #include "xla/tsl/platform/logging.h" // IWYU pragma: keep @@ -45,6 +48,63 @@ std::string ReplicaGroupsToString( return absl::StrCat("{", absl::StrJoin(replica_group_str, ","), "}"); } +/************** MeshAxesReplicaGroupList implementation ***********************/ +int64_t MeshAxesReplicaGroupList::num_replica_groups() const { + return mesh_.device_assignment().num_elements() / num_devices_per_group(); +} + +int64_t MeshAxesReplicaGroupList::num_devices_per_group() const { + // Number of devices per replica group is equal to the product of the sizes of + // all axes. + int64_t devices_per_group = 1; + for (const AxisRef& axis : axes_) { + int64_t axis_size = + axis.sub_axis_info().has_value() + ? axis.sub_axis_info()->size + : mesh_.device_assignment().dim(axis.mesh_axis_index()); + devices_per_group *= axis_size; + } + return devices_per_group; +} + +void MeshAxesReplicaGroupList::Print(Printer* printer) const { + printer->Append(ToString()); +} + +std::string MeshAxesReplicaGroupList::ToString() const { + std::string rg_str = ""; + // Add the axes defining the replica group, using names from the mesh. + std::vector group_axes_str; + group_axes_str.reserve(axes_.size()); + for (const AxisRef& axis : axes_) { + std::string axis_str = axis.ToString(mesh_); + group_axes_str.push_back(axis_str); + } + absl::StrAppend(&rg_str, mesh_.ToString(), " {", + absl::StrJoin(group_axes_str, ","), "}"); + return rg_str; +} + +MeshAxesReplicaGroupListProto MeshAxesReplicaGroupList::ToProto() const { + MeshAxesReplicaGroupListProto proto; + *proto.mutable_mesh() = mesh_.ToProto(); + for (const AxisRef& axis : axes_) { + *proto.add_axes() = axis.ToProto(); + } + return proto; +} + +MeshAxesReplicaGroupList MeshAxesReplicaGroupList::FromProto( + const MeshAxesReplicaGroupListProto& proto) { + Mesh mesh = Mesh::FromProto(proto.mesh()); + std::vector axes; + for (const AxisRefProto& axis_proto : proto.axes()) { + axes.push_back(AxisRef::FromProto(axis_proto)); + } + return MeshAxesReplicaGroupList(mesh, axes); +} + +/************** IotaReplicaGroupList implementation ***************************/ int64_t IotaReplicaGroupList::num_replica_groups() const { DCHECK_GE(num_replica_groups_, 0); return num_replica_groups_; @@ -121,6 +181,7 @@ std::shared_ptr> ExpandIota( } } // namespace +/************** CollectiveDeviceList implementation ***************************/ const std::vector& CollectiveDeviceList::replica_groups() const { if (replica_groups_ == nullptr) { CHECK(iota_replica_group_list_.has_value()); diff --git a/third_party/xla/xla/hlo/ir/replica_group.h b/third_party/xla/xla/hlo/ir/replica_group.h index f1b612fe8c7..de8412df22a 100644 --- a/third_party/xla/xla/hlo/ir/replica_group.h +++ b/third_party/xla/xla/hlo/ir/replica_group.h @@ -24,8 +24,11 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "absl/types/span.h" +#include "google/protobuf/repeated_ptr_field.h" #include "xla/array.h" +#include "xla/hlo/ir/mesh_and_axis.h" #include "xla/hlo/ir/tile_assignment.h" #include "xla/printer.h" #include "xla/service/hlo.pb.h" @@ -34,6 +37,42 @@ limitations under the License. namespace xla { +class MeshAxesReplicaGroupList { + public: + explicit MeshAxesReplicaGroupList(Mesh mesh, std::vector axes) + : mesh_(std::move(mesh)), axes_(std::move(axes)) { + if (num_devices_per_group() == 1) { + LOG(ERROR) << "MeshAxesReplicaGroupList: " << ToString() + << " has only one device per replica group."; + } + } + + bool operator==(const MeshAxesReplicaGroupList& other) const { + return mesh_ == other.mesh_ && axes_ == other.axes_; + } + + template + friend H AbslHashValue(H h, const MeshAxesReplicaGroupList& c) { + return H::combine(std::move(h), c.mesh_, c.axes_); + } + + int64_t num_replica_groups() const; + int64_t num_devices_per_group() const; + + void Print(Printer* printer) const; + + std::string ToString() const; + + MeshAxesReplicaGroupListProto ToProto() const; + + static MeshAxesReplicaGroupList FromProto( + const MeshAxesReplicaGroupListProto& proto); + + private: + Mesh mesh_; + std::vector axes_; +}; + std::string ReplicaGroupsToString( absl::Span replica_groups); diff --git a/third_party/xla/xla/hlo/ir/replica_group_test.cc b/third_party/xla/xla/hlo/ir/replica_group_test.cc index 74fb67f876f..32271e07ed1 100644 --- a/third_party/xla/xla/hlo/ir/replica_group_test.cc +++ b/third_party/xla/xla/hlo/ir/replica_group_test.cc @@ -16,10 +16,16 @@ limitations under the License. #include "xla/hlo/ir/replica_group.h" #include +#include #include #include #include +#include "testing/base/public/mock-log.h" +#include "absl/base/log_severity.h" +#include "xla/array.h" +#include "xla/hlo/ir/mesh_and_axis.h" +#include "xla/hlo/ir/tile_assignment.h" #include "xla/service/hlo.pb.h" #include "xla/xla_data.pb.h" @@ -37,6 +43,116 @@ CollectiveDeviceListProto CreateDeviceListProto( return proto; } +TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSize) { + Mesh all_axes(TileAssignment(IotaTileAssignment::Create( + /*dims=*/{4, 4})), + /*axes_names=*/{"x", "y"}); + MeshAxesReplicaGroupList replica_group_across_all_axes( + all_axes, + /*axes=*/{AxisRef(0), AxisRef(1)}); + EXPECT_EQ(replica_group_across_all_axes.num_replica_groups(), 1); + EXPECT_EQ(replica_group_across_all_axes.num_devices_per_group(), 16); + + Mesh one_axes(TileAssignment(IotaTileAssignment::Create( + /*dims=*/{3, 5})), + /*axes_names=*/{"a", "b"}); + MeshAxesReplicaGroupList replica_group_across_a(one_axes, + /*axes=*/{AxisRef(0)}); + MeshAxesReplicaGroupList replica_group_across_b(one_axes, + /*axes=*/{AxisRef(1)}); + EXPECT_EQ(replica_group_across_a.num_replica_groups(), 5); + EXPECT_EQ(replica_group_across_a.num_devices_per_group(), 3); + EXPECT_EQ(replica_group_across_b.num_replica_groups(), 3); + EXPECT_EQ(replica_group_across_b.num_devices_per_group(), 5); + + Mesh no_axes(TileAssignment(IotaTileAssignment::Create( + /*dims=*/{2, 3, 5})), + /*axes_names=*/{"p1", "p2", "p3"}); + testing::ScopedMockLog log(testing::kDoNotCaptureLogsYet); + EXPECT_CALL(log, + Log(base_logging::ERROR, testing::_, + testing::HasSubstr("has only one device per replica group"))) + .Times(1); + log.StartCapturingLogs(); + MeshAxesReplicaGroupList replica_group_across_no_axes(no_axes, + /*axes=*/{}); + EXPECT_EQ(replica_group_across_no_axes.num_replica_groups(), 2 * 3 * 5); + EXPECT_EQ(replica_group_across_no_axes.num_devices_per_group(), 1); +} + +TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSizeForSubaxes) { + Mesh mesh_one_subaxis(TileAssignment(IotaTileAssignment::Create( + /*dims=*/{2, 6, 10})), + /*axes_names=*/{"axis1", "axis2", "axis3"}); + MeshAxesReplicaGroupList replica_group_across_axis1_subaxis( + mesh_one_subaxis, + /*axes=*/{AxisRef(0, {1, 2})}); + MeshAxesReplicaGroupList replica_group_across_axis2_subaxis( + mesh_one_subaxis, + /*axes=*/{AxisRef(1, {2, 3})}); + EXPECT_EQ(replica_group_across_axis1_subaxis.num_replica_groups(), 60); + EXPECT_EQ(replica_group_across_axis1_subaxis.num_devices_per_group(), 2); + EXPECT_EQ(replica_group_across_axis2_subaxis.num_replica_groups(), 40); + EXPECT_EQ(replica_group_across_axis2_subaxis.num_devices_per_group(), 3); + + Mesh mesh_multiple_subaxis(TileAssignment(IotaTileAssignment::Create( + /*dims=*/{2 * 3, 5 * 7, 11 * 13})), + /*axes_names=*/{"alpha", "beta", "gamma"}); + MeshAxesReplicaGroupList replica_group_across_multiple_subaxis1( + mesh_multiple_subaxis, + /*axes=*/{AxisRef(0, {1, 2}), AxisRef(1, {1, 5}), AxisRef(2, {1, 11})}); + MeshAxesReplicaGroupList replica_group_across_multiple_subaxis2( + mesh_multiple_subaxis, + /*axes=*/{AxisRef(0, {2, 3}), AxisRef(1, {5, 7}), AxisRef(2, {11, 13})}); + EXPECT_EQ(replica_group_across_multiple_subaxis1.num_replica_groups(), + 3 * 7 * 13); + EXPECT_EQ(replica_group_across_multiple_subaxis1.num_devices_per_group(), + 2 * 5 * 11); + EXPECT_EQ(replica_group_across_multiple_subaxis2.num_replica_groups(), + 2 * 5 * 11); + EXPECT_EQ(replica_group_across_multiple_subaxis2.num_devices_per_group(), + 3 * 7 * 13); +} + +TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) { + // No subaxes and iota device assignment. + Mesh mesh_uvw(TileAssignment(IotaTileAssignment::Create( + /*dims=*/{10, 12, 15})), + /*axes_names=*/{"u", "v", "w"}); + MeshAxesReplicaGroupList replica_group_across_none(mesh_uvw, /*axes=*/{}); + EXPECT_EQ(replica_group_across_none.ToString(), "@mesh {}"); + MeshAxesReplicaGroupList replica_group_across_uv( + mesh_uvw, + /*axes=*/{AxisRef(0), AxisRef(1)}); + EXPECT_EQ(replica_group_across_uv.ToString(), "@mesh {u,v}"); + + // Subaxes and replica group v2 iota style device assignment. + Mesh mesh_abcd(TileAssignment(IotaTileAssignment::Create( + /*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16}, + /*transpose_perm=*/{2, 3, 0, 1})), + /*axes_names=*/{"a", "b", "c", "d"}); + MeshAxesReplicaGroupList rg_abcd_across_none(mesh_abcd, /*axes=*/{}); + EXPECT_EQ(rg_abcd_across_none.ToString(), + "@mesh([4,16]T(1,0)) {}"); + MeshAxesReplicaGroupList rg_abcd_across_multiple_axes_and_subaxes( + mesh_abcd, /*axes=*/{AxisRef(0), AxisRef(1, {1, 2}), AxisRef(3)}); + EXPECT_EQ(rg_abcd_across_multiple_axes_and_subaxes.ToString(), + "@mesh([4,16]T(1,0)) {a,b:(1)2,d}"); + + // Subaxes and random device assignment. + Array array({{8, 3, 7, 5, 4, 2, 6, 0, 1, 9}}); + array.Reshape({10}); + TileAssignment tile_assignment(std::make_shared>(array)); + Mesh mesh_ooo(tile_assignment, /*axes_names=*/{"ooo"}); + MeshAxesReplicaGroupList rg_ooo_across_none(mesh_ooo, /*axes=*/{}); + EXPECT_EQ(rg_ooo_across_none.ToString(), + "@mesh(8,3,7,5,4,2,6,0,1,9) {}"); + MeshAxesReplicaGroupList rg_ooo_across_ooo_5_2(mesh_ooo, + /*axes=*/{AxisRef(0, {5, 2})}); + EXPECT_EQ(rg_ooo_across_ooo_5_2.ToString(), + "@mesh(8,3,7,5,4,2,6,0,1,9) {ooo:(5)2}"); +} + TEST(CollectiveDeviceListTest, DefaultListToString) { EXPECT_EQ(CollectiveDeviceList().ToString(true), "{}"); EXPECT_EQ(CollectiveDeviceList().ToString(false), "{}"); diff --git a/third_party/xla/xla/xla_data.proto b/third_party/xla/xla/xla_data.proto index 42fba88b23e..9557c952620 100644 --- a/third_party/xla/xla/xla_data.proto +++ b/third_party/xla/xla/xla_data.proto @@ -1192,6 +1192,19 @@ message ReplicaGroup { repeated int64 replica_ids = 1; } +// Represents a list of replica groups (a list of list of devices) via a mesh +// and list of axes. The replica groups correspond to the partitions of the +// device ids which would arise if a collective operation was performed over the +// specified axes. +message MeshAxesReplicaGroupListProto { + // The mesh used to define the full set of axes and devices ids. + MeshProto mesh = 1; + // The axes defining the replica groups. These groups are implicitly defined + // by the device ids which would communicate together if a collective + // operation is performed over these axes. + repeated AxisRefProto axes = 2; +} + // Represents a list of replica groups (a list of list of devices) with // reshaping and transposing an iota array (iota tile assignment). Can be used // to represent certain common patterns of device lists in a compact, scalable