[ReplicaGroupV3][MeshAxesReplicaGroupList][1/2] Add initial class definition for V3 replica group.

PiperOrigin-RevId: 826334561
This commit is contained in:
Bill Varcho 2025-10-30 22:59:12 -07:00 committed by TensorFlower Gardener
parent d9c76aafeb
commit cef240807a
8 changed files with 291 additions and 3 deletions

View File

@ -22,7 +22,6 @@ limitations under the License.
#include <cstdint> #include <cstdint>
#include <cstring> #include <cstring>
#include <functional> #include <functional>
#include <initializer_list>
#include <iterator> #include <iterator>
#include <limits> #include <limits>
#include <memory> #include <memory>

View File

@ -60,6 +60,7 @@ cc_library(
deps = [ deps = [
":backend_config", ":backend_config",
":hlo_sharding", ":hlo_sharding",
":mesh_and_axis",
":named_sharding", ":named_sharding",
":ptrvec", ":ptrvec",
":tile_assignment", ":tile_assignment",
@ -123,6 +124,7 @@ cc_library(
"@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization", "@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
"@com_google_protobuf//:protobuf_lite",
"@highwayhash", "@highwayhash",
"@highwayhash//:arch_specific", "@highwayhash//:arch_specific",
"@highwayhash//:hh_types", "@highwayhash//:hh_types",
@ -188,6 +190,7 @@ cc_library(
"//xla:xla_data_proto_cc", "//xla:xla_data_proto_cc",
"@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log:check", "@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support", "@llvm-project//llvm:Support",
], ],
@ -415,9 +418,13 @@ xla_cc_test(
srcs = ["replica_group_test.cc"], srcs = ["replica_group_test.cc"],
deps = [ deps = [
":hlo", ":hlo",
":mesh_and_axis",
":tile_assignment",
"//xla:array",
"//xla:xla_data_proto_cc", "//xla:xla_data_proto_cc",
"//xla/service:hlo_proto_cc", "//xla/service:hlo_proto_cc",
"//xla/tsl/platform:test_main", "//xla/tsl/platform:test_main",
"@com_google_absl//absl/base:log_severity",
"@com_google_googletest//:gtest", "@com_google_googletest//:gtest",
], ],
) )
@ -433,8 +440,6 @@ xla_cc_test(
"//xla/tsl/util/proto:proto_matchers", "//xla/tsl/util/proto:proto_matchers",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main", "@com_google_googletest//:gtest_main",
"@local_tsl//tsl/platform:status_matchers",
"@local_tsl//tsl/platform:test",
], ],
) )

View File

@ -23,6 +23,8 @@ limitations under the License.
#include <vector> #include <vector>
#include "absl/log/check.h" #include "absl/log/check.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "xla/hlo/ir/tile_assignment.h" #include "xla/hlo/ir/tile_assignment.h"
#include "xla/xla_data.pb.h" #include "xla/xla_data.pb.h"
@ -55,6 +57,28 @@ class Mesh {
axes_names_ == other.axes_names_; axes_names_ == other.axes_names_;
} }
std::string ToString() const {
std::string mesh_str = "@mesh";
// Add the mesh axes names and sizes.
std::vector<std::string> formatted_axes_names;
formatted_axes_names.reserve(axes_names_.size());
for (int64_t i = 0; i < axes_names_.size(); ++i) {
formatted_axes_names.push_back(
absl::StrCat(axes_names_[i], "=", device_assignment_.dim(i)));
}
// Add the device assignment if it is not an iota case.
std::optional<IotaTileAssignment> iota = device_assignment_.iota();
std::string device_assignment_str = "";
if (!(iota.has_value() && iota->reshape_dims().size() == 1)) {
device_assignment_str =
absl::StrCat("(", device_assignment_.ArrayToString(), ")");
}
absl::StrAppend(&mesh_str, "<", absl::StrJoin(formatted_axes_names, ","),
">", device_assignment_str);
return mesh_str;
}
bool operator!=(const Mesh& other) const { return !(*this == other); } bool operator!=(const Mesh& other) const { return !(*this == other); }
MeshProto ToProto() const; MeshProto ToProto() const;
@ -62,6 +86,7 @@ class Mesh {
static Mesh FromProto(const MeshProto& proto); static Mesh FromProto(const MeshProto& proto);
TileAssignment device_assignment() const { return device_assignment_; } TileAssignment device_assignment() const { return device_assignment_; }
std::vector<std::string> axis_names() const { return axes_names_; }
private: private:
// Dimensions of the `device_assignment_` array correspond to the axes of the // Dimensions of the `device_assignment_` array correspond to the axes of the
@ -113,6 +138,17 @@ class AxisRef {
return true; return true;
} }
std::string ToString(const Mesh& mesh) const {
CHECK_GE(mesh_axis_index_, 0);
CHECK_LT(mesh_axis_index_, mesh.axis_names().size());
std::string axis_str = mesh.axis_names()[mesh_axis_index()];
if (sub_axis_info_.has_value()) {
absl::StrAppend(&axis_str, ":(", sub_axis_info_->pre_size, ")",
sub_axis_info_->size);
}
return axis_str;
}
bool operator!=(const xla::AxisRef& other) const { return !(*this == other); } bool operator!=(const xla::AxisRef& other) const { return !(*this == other); }
AxisRefProto ToProto() const; AxisRefProto ToProto() const;

View File

@ -176,4 +176,23 @@ TEST(MeshAndAxisTest, MeshRoundtripProto) {
EXPECT_THAT(mesh_non_iota, Mesh::FromProto(mesh_non_iota.ToProto())); EXPECT_THAT(mesh_non_iota, Mesh::FromProto(mesh_non_iota.ToProto()));
} }
TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) {
Mesh mesh_uvw(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{10, 12, 15})),
/*axes_names=*/{"u", "v", "w"});
EXPECT_EQ(mesh_uvw.ToString(), "@mesh<u=10,v=12,w=15>");
Mesh mesh_abcd(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16},
/*transpose_perm=*/{2, 3, 0, 1})),
/*axes_names=*/{"a", "b", "c", "d"});
EXPECT_EQ(mesh_abcd.ToString(), "@mesh<a=2,b=4,c=4,d=2>([4,16]T(1,0))");
Array<int64_t> array({{8, 3, 7, 5, 4, 2, 6, 0, 1, 9}});
array.Reshape({10});
TileAssignment tile_assignment(std::make_shared<Array<int64_t>>(array));
Mesh mesh_ooo(tile_assignment, /*axes_names=*/{"ooo"});
EXPECT_EQ(mesh_ooo.ToString(), "@mesh<ooo=10>(8,3,7,5,4,2,6,0,1,9)");
}
} // namespace xla } // namespace xla

View File

@ -17,6 +17,7 @@ limitations under the License.
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
#include <optional>
#include <string> #include <string>
#include <vector> #include <vector>
@ -26,6 +27,8 @@ limitations under the License.
#include "absl/strings/str_join.h" #include "absl/strings/str_join.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "xla/array.h" #include "xla/array.h"
#include "xla/hlo/ir/mesh_and_axis.h"
#include "xla/hlo/ir/tile_assignment.h"
#include "xla/printer.h" #include "xla/printer.h"
#include "xla/service/hlo.pb.h" #include "xla/service/hlo.pb.h"
#include "xla/tsl/platform/logging.h" // IWYU pragma: keep #include "xla/tsl/platform/logging.h" // IWYU pragma: keep
@ -45,6 +48,63 @@ std::string ReplicaGroupsToString(
return absl::StrCat("{", absl::StrJoin(replica_group_str, ","), "}"); return absl::StrCat("{", absl::StrJoin(replica_group_str, ","), "}");
} }
/************** MeshAxesReplicaGroupList implementation ***********************/
int64_t MeshAxesReplicaGroupList::num_replica_groups() const {
return mesh_.device_assignment().num_elements() / num_devices_per_group();
}
int64_t MeshAxesReplicaGroupList::num_devices_per_group() const {
// Number of devices per replica group is equal to the product of the sizes of
// all axes.
int64_t devices_per_group = 1;
for (const AxisRef& axis : axes_) {
int64_t axis_size =
axis.sub_axis_info().has_value()
? axis.sub_axis_info()->size
: mesh_.device_assignment().dim(axis.mesh_axis_index());
devices_per_group *= axis_size;
}
return devices_per_group;
}
void MeshAxesReplicaGroupList::Print(Printer* printer) const {
printer->Append(ToString());
}
std::string MeshAxesReplicaGroupList::ToString() const {
std::string rg_str = "";
// Add the axes defining the replica group, using names from the mesh.
std::vector<std::string> group_axes_str;
group_axes_str.reserve(axes_.size());
for (const AxisRef& axis : axes_) {
std::string axis_str = axis.ToString(mesh_);
group_axes_str.push_back(axis_str);
}
absl::StrAppend(&rg_str, mesh_.ToString(), " {",
absl::StrJoin(group_axes_str, ","), "}");
return rg_str;
}
MeshAxesReplicaGroupListProto MeshAxesReplicaGroupList::ToProto() const {
MeshAxesReplicaGroupListProto proto;
*proto.mutable_mesh() = mesh_.ToProto();
for (const AxisRef& axis : axes_) {
*proto.add_axes() = axis.ToProto();
}
return proto;
}
MeshAxesReplicaGroupList MeshAxesReplicaGroupList::FromProto(
const MeshAxesReplicaGroupListProto& proto) {
Mesh mesh = Mesh::FromProto(proto.mesh());
std::vector<AxisRef> axes;
for (const AxisRefProto& axis_proto : proto.axes()) {
axes.push_back(AxisRef::FromProto(axis_proto));
}
return MeshAxesReplicaGroupList(mesh, axes);
}
/************** IotaReplicaGroupList implementation ***************************/
int64_t IotaReplicaGroupList::num_replica_groups() const { int64_t IotaReplicaGroupList::num_replica_groups() const {
DCHECK_GE(num_replica_groups_, 0); DCHECK_GE(num_replica_groups_, 0);
return num_replica_groups_; return num_replica_groups_;
@ -121,6 +181,7 @@ std::shared_ptr<std::vector<ReplicaGroup>> ExpandIota(
} }
} // namespace } // namespace
/************** CollectiveDeviceList implementation ***************************/
const std::vector<ReplicaGroup>& CollectiveDeviceList::replica_groups() const { const std::vector<ReplicaGroup>& CollectiveDeviceList::replica_groups() const {
if (replica_groups_ == nullptr) { if (replica_groups_ == nullptr) {
CHECK(iota_replica_group_list_.has_value()); CHECK(iota_replica_group_list_.has_value());

View File

@ -24,8 +24,11 @@ limitations under the License.
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "absl/log/log.h"
#include "absl/types/span.h" #include "absl/types/span.h"
#include "google/protobuf/repeated_ptr_field.h"
#include "xla/array.h" #include "xla/array.h"
#include "xla/hlo/ir/mesh_and_axis.h"
#include "xla/hlo/ir/tile_assignment.h" #include "xla/hlo/ir/tile_assignment.h"
#include "xla/printer.h" #include "xla/printer.h"
#include "xla/service/hlo.pb.h" #include "xla/service/hlo.pb.h"
@ -34,6 +37,42 @@ limitations under the License.
namespace xla { namespace xla {
class MeshAxesReplicaGroupList {
public:
explicit MeshAxesReplicaGroupList(Mesh mesh, std::vector<AxisRef> axes)
: mesh_(std::move(mesh)), axes_(std::move(axes)) {
if (num_devices_per_group() == 1) {
LOG(ERROR) << "MeshAxesReplicaGroupList: " << ToString()
<< " has only one device per replica group.";
}
}
bool operator==(const MeshAxesReplicaGroupList& other) const {
return mesh_ == other.mesh_ && axes_ == other.axes_;
}
template <typename H>
friend H AbslHashValue(H h, const MeshAxesReplicaGroupList& c) {
return H::combine(std::move(h), c.mesh_, c.axes_);
}
int64_t num_replica_groups() const;
int64_t num_devices_per_group() const;
void Print(Printer* printer) const;
std::string ToString() const;
MeshAxesReplicaGroupListProto ToProto() const;
static MeshAxesReplicaGroupList FromProto(
const MeshAxesReplicaGroupListProto& proto);
private:
Mesh mesh_;
std::vector<AxisRef> axes_;
};
std::string ReplicaGroupsToString( std::string ReplicaGroupsToString(
absl::Span<const ReplicaGroup> replica_groups); absl::Span<const ReplicaGroup> replica_groups);

View File

@ -16,10 +16,16 @@ limitations under the License.
#include "xla/hlo/ir/replica_group.h" #include "xla/hlo/ir/replica_group.h"
#include <cstdint> #include <cstdint>
#include <memory>
#include <vector> #include <vector>
#include <gmock/gmock.h> #include <gmock/gmock.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "testing/base/public/mock-log.h"
#include "absl/base/log_severity.h"
#include "xla/array.h"
#include "xla/hlo/ir/mesh_and_axis.h"
#include "xla/hlo/ir/tile_assignment.h"
#include "xla/service/hlo.pb.h" #include "xla/service/hlo.pb.h"
#include "xla/xla_data.pb.h" #include "xla/xla_data.pb.h"
@ -37,6 +43,116 @@ CollectiveDeviceListProto CreateDeviceListProto(
return proto; return proto;
} }
TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSize) {
Mesh all_axes(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{4, 4})),
/*axes_names=*/{"x", "y"});
MeshAxesReplicaGroupList replica_group_across_all_axes(
all_axes,
/*axes=*/{AxisRef(0), AxisRef(1)});
EXPECT_EQ(replica_group_across_all_axes.num_replica_groups(), 1);
EXPECT_EQ(replica_group_across_all_axes.num_devices_per_group(), 16);
Mesh one_axes(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{3, 5})),
/*axes_names=*/{"a", "b"});
MeshAxesReplicaGroupList replica_group_across_a(one_axes,
/*axes=*/{AxisRef(0)});
MeshAxesReplicaGroupList replica_group_across_b(one_axes,
/*axes=*/{AxisRef(1)});
EXPECT_EQ(replica_group_across_a.num_replica_groups(), 5);
EXPECT_EQ(replica_group_across_a.num_devices_per_group(), 3);
EXPECT_EQ(replica_group_across_b.num_replica_groups(), 3);
EXPECT_EQ(replica_group_across_b.num_devices_per_group(), 5);
Mesh no_axes(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{2, 3, 5})),
/*axes_names=*/{"p1", "p2", "p3"});
testing::ScopedMockLog log(testing::kDoNotCaptureLogsYet);
EXPECT_CALL(log,
Log(base_logging::ERROR, testing::_,
testing::HasSubstr("has only one device per replica group")))
.Times(1);
log.StartCapturingLogs();
MeshAxesReplicaGroupList replica_group_across_no_axes(no_axes,
/*axes=*/{});
EXPECT_EQ(replica_group_across_no_axes.num_replica_groups(), 2 * 3 * 5);
EXPECT_EQ(replica_group_across_no_axes.num_devices_per_group(), 1);
}
TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSizeForSubaxes) {
Mesh mesh_one_subaxis(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{2, 6, 10})),
/*axes_names=*/{"axis1", "axis2", "axis3"});
MeshAxesReplicaGroupList replica_group_across_axis1_subaxis(
mesh_one_subaxis,
/*axes=*/{AxisRef(0, {1, 2})});
MeshAxesReplicaGroupList replica_group_across_axis2_subaxis(
mesh_one_subaxis,
/*axes=*/{AxisRef(1, {2, 3})});
EXPECT_EQ(replica_group_across_axis1_subaxis.num_replica_groups(), 60);
EXPECT_EQ(replica_group_across_axis1_subaxis.num_devices_per_group(), 2);
EXPECT_EQ(replica_group_across_axis2_subaxis.num_replica_groups(), 40);
EXPECT_EQ(replica_group_across_axis2_subaxis.num_devices_per_group(), 3);
Mesh mesh_multiple_subaxis(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{2 * 3, 5 * 7, 11 * 13})),
/*axes_names=*/{"alpha", "beta", "gamma"});
MeshAxesReplicaGroupList replica_group_across_multiple_subaxis1(
mesh_multiple_subaxis,
/*axes=*/{AxisRef(0, {1, 2}), AxisRef(1, {1, 5}), AxisRef(2, {1, 11})});
MeshAxesReplicaGroupList replica_group_across_multiple_subaxis2(
mesh_multiple_subaxis,
/*axes=*/{AxisRef(0, {2, 3}), AxisRef(1, {5, 7}), AxisRef(2, {11, 13})});
EXPECT_EQ(replica_group_across_multiple_subaxis1.num_replica_groups(),
3 * 7 * 13);
EXPECT_EQ(replica_group_across_multiple_subaxis1.num_devices_per_group(),
2 * 5 * 11);
EXPECT_EQ(replica_group_across_multiple_subaxis2.num_replica_groups(),
2 * 5 * 11);
EXPECT_EQ(replica_group_across_multiple_subaxis2.num_devices_per_group(),
3 * 7 * 13);
}
TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) {
// No subaxes and iota device assignment.
Mesh mesh_uvw(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{10, 12, 15})),
/*axes_names=*/{"u", "v", "w"});
MeshAxesReplicaGroupList replica_group_across_none(mesh_uvw, /*axes=*/{});
EXPECT_EQ(replica_group_across_none.ToString(), "@mesh<u=10,v=12,w=15> {}");
MeshAxesReplicaGroupList replica_group_across_uv(
mesh_uvw,
/*axes=*/{AxisRef(0), AxisRef(1)});
EXPECT_EQ(replica_group_across_uv.ToString(), "@mesh<u=10,v=12,w=15> {u,v}");
// Subaxes and replica group v2 iota style device assignment.
Mesh mesh_abcd(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16},
/*transpose_perm=*/{2, 3, 0, 1})),
/*axes_names=*/{"a", "b", "c", "d"});
MeshAxesReplicaGroupList rg_abcd_across_none(mesh_abcd, /*axes=*/{});
EXPECT_EQ(rg_abcd_across_none.ToString(),
"@mesh<a=2,b=4,c=4,d=2>([4,16]T(1,0)) {}");
MeshAxesReplicaGroupList rg_abcd_across_multiple_axes_and_subaxes(
mesh_abcd, /*axes=*/{AxisRef(0), AxisRef(1, {1, 2}), AxisRef(3)});
EXPECT_EQ(rg_abcd_across_multiple_axes_and_subaxes.ToString(),
"@mesh<a=2,b=4,c=4,d=2>([4,16]T(1,0)) {a,b:(1)2,d}");
// Subaxes and random device assignment.
Array<int64_t> array({{8, 3, 7, 5, 4, 2, 6, 0, 1, 9}});
array.Reshape({10});
TileAssignment tile_assignment(std::make_shared<Array<int64_t>>(array));
Mesh mesh_ooo(tile_assignment, /*axes_names=*/{"ooo"});
MeshAxesReplicaGroupList rg_ooo_across_none(mesh_ooo, /*axes=*/{});
EXPECT_EQ(rg_ooo_across_none.ToString(),
"@mesh<ooo=10>(8,3,7,5,4,2,6,0,1,9) {}");
MeshAxesReplicaGroupList rg_ooo_across_ooo_5_2(mesh_ooo,
/*axes=*/{AxisRef(0, {5, 2})});
EXPECT_EQ(rg_ooo_across_ooo_5_2.ToString(),
"@mesh<ooo=10>(8,3,7,5,4,2,6,0,1,9) {ooo:(5)2}");
}
TEST(CollectiveDeviceListTest, DefaultListToString) { TEST(CollectiveDeviceListTest, DefaultListToString) {
EXPECT_EQ(CollectiveDeviceList().ToString(true), "{}"); EXPECT_EQ(CollectiveDeviceList().ToString(true), "{}");
EXPECT_EQ(CollectiveDeviceList().ToString(false), "{}"); EXPECT_EQ(CollectiveDeviceList().ToString(false), "{}");

View File

@ -1192,6 +1192,19 @@ message ReplicaGroup {
repeated int64 replica_ids = 1; repeated int64 replica_ids = 1;
} }
// Represents a list of replica groups (a list of list of devices) via a mesh
// and list of axes. The replica groups correspond to the partitions of the
// device ids which would arise if a collective operation was performed over the
// specified axes.
message MeshAxesReplicaGroupListProto {
// The mesh used to define the full set of axes and devices ids.
MeshProto mesh = 1;
// The axes defining the replica groups. These groups are implicitly defined
// by the device ids which would communicate together if a collective
// operation is performed over these axes.
repeated AxisRefProto axes = 2;
}
// Represents a list of replica groups (a list of list of devices) with // Represents a list of replica groups (a list of list of devices) with
// reshaping and transposing an iota array (iota tile assignment). Can be used // reshaping and transposing an iota array (iota tile assignment). Can be used
// to represent certain common patterns of device lists in a compact, scalable // to represent certain common patterns of device lists in a compact, scalable