mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
[ReplicaGroupV3][MeshAxesReplicaGroupList][1/2] Add initial class definition for V3 replica group.
PiperOrigin-RevId: 826334561
This commit is contained in:
parent
d9c76aafeb
commit
cef240807a
1
third_party/xla/xla/array.h
vendored
1
third_party/xla/xla/array.h
vendored
|
|
@ -22,7 +22,6 @@ limitations under the License.
|
|||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <initializer_list>
|
||||
#include <iterator>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
|
|
|
|||
9
third_party/xla/xla/hlo/ir/BUILD
vendored
9
third_party/xla/xla/hlo/ir/BUILD
vendored
|
|
@ -60,6 +60,7 @@ cc_library(
|
|||
deps = [
|
||||
":backend_config",
|
||||
":hlo_sharding",
|
||||
":mesh_and_axis",
|
||||
":named_sharding",
|
||||
":ptrvec",
|
||||
":tile_assignment",
|
||||
|
|
@ -123,6 +124,7 @@ cc_library(
|
|||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@com_google_protobuf//:protobuf_lite",
|
||||
"@highwayhash",
|
||||
"@highwayhash//:arch_specific",
|
||||
"@highwayhash//:hh_types",
|
||||
|
|
@ -188,6 +190,7 @@ cc_library(
|
|||
"//xla:xla_data_proto_cc",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/log:check",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm-project//llvm:Support",
|
||||
],
|
||||
|
|
@ -415,9 +418,13 @@ xla_cc_test(
|
|||
srcs = ["replica_group_test.cc"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":mesh_and_axis",
|
||||
":tile_assignment",
|
||||
"//xla:array",
|
||||
"//xla:xla_data_proto_cc",
|
||||
"//xla/service:hlo_proto_cc",
|
||||
"//xla/tsl/platform:test_main",
|
||||
"@com_google_absl//absl/base:log_severity",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
|
@ -433,8 +440,6 @@ xla_cc_test(
|
|||
"//xla/tsl/util/proto:proto_matchers",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
"@local_tsl//tsl/platform:status_matchers",
|
||||
"@local_tsl//tsl/platform:test",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
36
third_party/xla/xla/hlo/ir/mesh_and_axis.h
vendored
36
third_party/xla/xla/hlo/ir/mesh_and_axis.h
vendored
|
|
@ -23,6 +23,8 @@ limitations under the License.
|
|||
#include <vector>
|
||||
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "xla/hlo/ir/tile_assignment.h"
|
||||
#include "xla/xla_data.pb.h"
|
||||
|
|
@ -55,6 +57,28 @@ class Mesh {
|
|||
axes_names_ == other.axes_names_;
|
||||
}
|
||||
|
||||
std::string ToString() const {
|
||||
std::string mesh_str = "@mesh";
|
||||
// Add the mesh axes names and sizes.
|
||||
std::vector<std::string> formatted_axes_names;
|
||||
formatted_axes_names.reserve(axes_names_.size());
|
||||
for (int64_t i = 0; i < axes_names_.size(); ++i) {
|
||||
formatted_axes_names.push_back(
|
||||
absl::StrCat(axes_names_[i], "=", device_assignment_.dim(i)));
|
||||
}
|
||||
|
||||
// Add the device assignment if it is not an iota case.
|
||||
std::optional<IotaTileAssignment> iota = device_assignment_.iota();
|
||||
std::string device_assignment_str = "";
|
||||
if (!(iota.has_value() && iota->reshape_dims().size() == 1)) {
|
||||
device_assignment_str =
|
||||
absl::StrCat("(", device_assignment_.ArrayToString(), ")");
|
||||
}
|
||||
absl::StrAppend(&mesh_str, "<", absl::StrJoin(formatted_axes_names, ","),
|
||||
">", device_assignment_str);
|
||||
return mesh_str;
|
||||
}
|
||||
|
||||
// Inequality is defined as the logical negation of operator==.
bool operator!=(const Mesh& other) const { return !(*this == other); }
|
||||
|
||||
MeshProto ToProto() const;
|
||||
|
|
@ -62,6 +86,7 @@ class Mesh {
|
|||
static Mesh FromProto(const MeshProto& proto);
|
||||
|
||||
// Returns the mesh's device assignment by value.
TileAssignment device_assignment() const { return device_assignment_; }
|
||||
// Returns the mesh's axis names.
// NOTE(review): returned by value, so every call copies the whole vector;
// callers that need it more than once should keep a local.
std::vector<std::string> axis_names() const { return axes_names_; }
|
||||
|
||||
private:
|
||||
// Dimensions of the `device_assignment_` array correspond to the axes of the
|
||||
|
|
@ -113,6 +138,17 @@ class AxisRef {
|
|||
return true;
|
||||
}
|
||||
|
||||
std::string ToString(const Mesh& mesh) const {
|
||||
CHECK_GE(mesh_axis_index_, 0);
|
||||
CHECK_LT(mesh_axis_index_, mesh.axis_names().size());
|
||||
std::string axis_str = mesh.axis_names()[mesh_axis_index()];
|
||||
if (sub_axis_info_.has_value()) {
|
||||
absl::StrAppend(&axis_str, ":(", sub_axis_info_->pre_size, ")",
|
||||
sub_axis_info_->size);
|
||||
}
|
||||
return axis_str;
|
||||
}
|
||||
|
||||
// Inequality is defined as the logical negation of operator==.
bool operator!=(const xla::AxisRef& other) const { return !(*this == other); }
|
||||
|
||||
AxisRefProto ToProto() const;
|
||||
|
|
|
|||
19
third_party/xla/xla/hlo/ir/mesh_and_axis_test.cc
vendored
19
third_party/xla/xla/hlo/ir/mesh_and_axis_test.cc
vendored
|
|
@ -176,4 +176,23 @@ TEST(MeshAndAxisTest, MeshRoundtripProto) {
|
|||
EXPECT_THAT(mesh_non_iota, Mesh::FromProto(mesh_non_iota.ToProto()));
|
||||
}
|
||||
|
||||
// Verifies Mesh::ToString() for a plain iota device assignment, a
// reshaped/transposed iota device assignment, and an explicit device list.
// Named after the Mesh class it exercises (it does not touch
// MeshAxesReplicaGroupList), matching the other tests in this file.
TEST(MeshAndAxisTest, MeshToString) {
  // Plain iota: only the axis names and sizes are printed.
  Mesh mesh_uvw(TileAssignment(IotaTileAssignment::Create(
                    /*dims=*/{10, 12, 15})),
                /*axes_names=*/{"u", "v", "w"});
  EXPECT_EQ(mesh_uvw.ToString(), "@mesh<u=10,v=12,w=15>");

  // Reshaped and transposed iota: the device assignment is appended.
  Mesh mesh_abcd(TileAssignment(IotaTileAssignment::Create(
                     /*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16},
                     /*transpose_perm=*/{2, 3, 0, 1})),
                 /*axes_names=*/{"a", "b", "c", "d"});
  EXPECT_EQ(mesh_abcd.ToString(), "@mesh<a=2,b=4,c=4,d=2>([4,16]T(1,0))");

  // Explicit (non-iota) device list: the ids are printed verbatim.
  Array<int64_t> array({{8, 3, 7, 5, 4, 2, 6, 0, 1, 9}});
  array.Reshape({10});
  TileAssignment tile_assignment(std::make_shared<Array<int64_t>>(array));
  Mesh mesh_ooo(tile_assignment, /*axes_names=*/{"ooo"});
  EXPECT_EQ(mesh_ooo.ToString(), "@mesh<ooo=10>(8,3,7,5,4,2,6,0,1,9)");
}
|
||||
|
||||
} // namespace xla
|
||||
|
|
|
|||
61
third_party/xla/xla/hlo/ir/replica_group.cc
vendored
61
third_party/xla/xla/hlo/ir/replica_group.cc
vendored
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -26,6 +27,8 @@ limitations under the License.
|
|||
#include "absl/strings/str_join.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "xla/array.h"
|
||||
#include "xla/hlo/ir/mesh_and_axis.h"
|
||||
#include "xla/hlo/ir/tile_assignment.h"
|
||||
#include "xla/printer.h"
|
||||
#include "xla/service/hlo.pb.h"
|
||||
#include "xla/tsl/platform/logging.h" // IWYU pragma: keep
|
||||
|
|
@ -45,6 +48,63 @@ std::string ReplicaGroupsToString(
|
|||
return absl::StrCat("{", absl::StrJoin(replica_group_str, ","), "}");
|
||||
}
|
||||
|
||||
/************** MeshAxesReplicaGroupList implementation ***********************/
|
||||
int64_t MeshAxesReplicaGroupList::num_replica_groups() const {
|
||||
return mesh_.device_assignment().num_elements() / num_devices_per_group();
|
||||
}
|
||||
|
||||
int64_t MeshAxesReplicaGroupList::num_devices_per_group() const {
|
||||
// Number of devices per replica group is equal to the product of the sizes of
|
||||
// all axes.
|
||||
int64_t devices_per_group = 1;
|
||||
for (const AxisRef& axis : axes_) {
|
||||
int64_t axis_size =
|
||||
axis.sub_axis_info().has_value()
|
||||
? axis.sub_axis_info()->size
|
||||
: mesh_.device_assignment().dim(axis.mesh_axis_index());
|
||||
devices_per_group *= axis_size;
|
||||
}
|
||||
return devices_per_group;
|
||||
}
|
||||
|
||||
// Appends the ToString() representation of this list to `printer`.
void MeshAxesReplicaGroupList::Print(Printer* printer) const {
  const std::string rendered = ToString();
  printer->Append(rendered);
}
|
||||
|
||||
std::string MeshAxesReplicaGroupList::ToString() const {
|
||||
std::string rg_str = "";
|
||||
// Add the axes defining the replica group, using names from the mesh.
|
||||
std::vector<std::string> group_axes_str;
|
||||
group_axes_str.reserve(axes_.size());
|
||||
for (const AxisRef& axis : axes_) {
|
||||
std::string axis_str = axis.ToString(mesh_);
|
||||
group_axes_str.push_back(axis_str);
|
||||
}
|
||||
absl::StrAppend(&rg_str, mesh_.ToString(), " {",
|
||||
absl::StrJoin(group_axes_str, ","), "}");
|
||||
return rg_str;
|
||||
}
|
||||
|
||||
// Serializes the mesh and every axis of this list, in order, into a
// MeshAxesReplicaGroupListProto.
MeshAxesReplicaGroupListProto MeshAxesReplicaGroupList::ToProto() const {
  MeshAxesReplicaGroupListProto result;
  *result.mutable_mesh() = mesh_.ToProto();
  for (const AxisRef& axis_ref : axes_) {
    AxisRefProto* slot = result.add_axes();
    *slot = axis_ref.ToProto();
  }
  return result;
}
|
||||
|
||||
// Builds a MeshAxesReplicaGroupList from `proto` by deserializing its mesh and
// each of its axes, preserving the order of the axes.
MeshAxesReplicaGroupList MeshAxesReplicaGroupList::FromProto(
    const MeshAxesReplicaGroupListProto& proto) {
  Mesh mesh = Mesh::FromProto(proto.mesh());
  std::vector<AxisRef> axes;
  axes.reserve(proto.axes_size());
  for (const AxisRefProto& axis_proto : proto.axes()) {
    axes.push_back(AxisRef::FromProto(axis_proto));
  }
  // The constructor takes its arguments by value; move the locals in so the
  // mesh and the axis vector are not deep-copied.
  return MeshAxesReplicaGroupList(std::move(mesh), std::move(axes));
}
|
||||
|
||||
/************** IotaReplicaGroupList implementation ***************************/
|
||||
int64_t IotaReplicaGroupList::num_replica_groups() const {
|
||||
DCHECK_GE(num_replica_groups_, 0);
|
||||
return num_replica_groups_;
|
||||
|
|
@ -121,6 +181,7 @@ std::shared_ptr<std::vector<ReplicaGroup>> ExpandIota(
|
|||
}
|
||||
} // namespace
|
||||
|
||||
/************** CollectiveDeviceList implementation ***************************/
|
||||
const std::vector<ReplicaGroup>& CollectiveDeviceList::replica_groups() const {
|
||||
if (replica_groups_ == nullptr) {
|
||||
CHECK(iota_replica_group_list_.has_value());
|
||||
|
|
|
|||
39
third_party/xla/xla/hlo/ir/replica_group.h
vendored
39
third_party/xla/xla/hlo/ir/replica_group.h
vendored
|
|
@ -24,8 +24,11 @@ limitations under the License.
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "google/protobuf/repeated_ptr_field.h"
|
||||
#include "xla/array.h"
|
||||
#include "xla/hlo/ir/mesh_and_axis.h"
|
||||
#include "xla/hlo/ir/tile_assignment.h"
|
||||
#include "xla/printer.h"
|
||||
#include "xla/service/hlo.pb.h"
|
||||
|
|
@ -34,6 +37,42 @@ limitations under the License.
|
|||
|
||||
namespace xla {
|
||||
|
||||
class MeshAxesReplicaGroupList {
|
||||
public:
|
||||
explicit MeshAxesReplicaGroupList(Mesh mesh, std::vector<AxisRef> axes)
|
||||
: mesh_(std::move(mesh)), axes_(std::move(axes)) {
|
||||
if (num_devices_per_group() == 1) {
|
||||
LOG(ERROR) << "MeshAxesReplicaGroupList: " << ToString()
|
||||
<< " has only one device per replica group.";
|
||||
}
|
||||
}
|
||||
|
||||
bool operator==(const MeshAxesReplicaGroupList& other) const {
|
||||
return mesh_ == other.mesh_ && axes_ == other.axes_;
|
||||
}
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const MeshAxesReplicaGroupList& c) {
|
||||
return H::combine(std::move(h), c.mesh_, c.axes_);
|
||||
}
|
||||
|
||||
int64_t num_replica_groups() const;
|
||||
int64_t num_devices_per_group() const;
|
||||
|
||||
void Print(Printer* printer) const;
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
MeshAxesReplicaGroupListProto ToProto() const;
|
||||
|
||||
static MeshAxesReplicaGroupList FromProto(
|
||||
const MeshAxesReplicaGroupListProto& proto);
|
||||
|
||||
private:
|
||||
Mesh mesh_;
|
||||
std::vector<AxisRef> axes_;
|
||||
};
|
||||
|
||||
std::string ReplicaGroupsToString(
|
||||
absl::Span<const ReplicaGroup> replica_groups);
|
||||
|
||||
|
|
|
|||
116
third_party/xla/xla/hlo/ir/replica_group_test.cc
vendored
116
third_party/xla/xla/hlo/ir/replica_group_test.cc
vendored
|
|
@ -16,10 +16,16 @@ limitations under the License.
|
|||
#include "xla/hlo/ir/replica_group.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "testing/base/public/mock-log.h"
|
||||
#include "absl/base/log_severity.h"
|
||||
#include "xla/array.h"
|
||||
#include "xla/hlo/ir/mesh_and_axis.h"
|
||||
#include "xla/hlo/ir/tile_assignment.h"
|
||||
#include "xla/service/hlo.pb.h"
|
||||
#include "xla/xla_data.pb.h"
|
||||
|
||||
|
|
@ -37,6 +43,116 @@ CollectiveDeviceListProto CreateDeviceListProto(
|
|||
return proto;
|
||||
}
|
||||
|
||||
// Checks the group count and group size when grouping over all, one, or none
// of the axes of an iota mesh.
TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSize) {
  // All axes of a 4x4 mesh: a single group holding every device.
  Mesh mesh_4x4(TileAssignment(IotaTileAssignment::Create(
                    /*dims=*/{4, 4})),
                /*axes_names=*/{"x", "y"});
  MeshAxesReplicaGroupList groups_over_xy(mesh_4x4,
                                          /*axes=*/{AxisRef(0), AxisRef(1)});
  EXPECT_EQ(groups_over_xy.num_replica_groups(), 1);
  EXPECT_EQ(groups_over_xy.num_devices_per_group(), 16);

  // A single axis of a 3x5 mesh: the group size is the size of that axis.
  Mesh mesh_3x5(TileAssignment(IotaTileAssignment::Create(
                    /*dims=*/{3, 5})),
                /*axes_names=*/{"a", "b"});
  MeshAxesReplicaGroupList groups_over_a(mesh_3x5, /*axes=*/{AxisRef(0)});
  EXPECT_EQ(groups_over_a.num_replica_groups(), 5);
  EXPECT_EQ(groups_over_a.num_devices_per_group(), 3);
  MeshAxesReplicaGroupList groups_over_b(mesh_3x5, /*axes=*/{AxisRef(1)});
  EXPECT_EQ(groups_over_b.num_replica_groups(), 3);
  EXPECT_EQ(groups_over_b.num_devices_per_group(), 5);

  // No axes: every device is its own group, and construction is expected to
  // log exactly one ERROR about the single-device groups.
  Mesh mesh_2x3x5(TileAssignment(IotaTileAssignment::Create(
                      /*dims=*/{2, 3, 5})),
                  /*axes_names=*/{"p1", "p2", "p3"});
  testing::ScopedMockLog log(testing::kDoNotCaptureLogsYet);
  EXPECT_CALL(log,
              Log(base_logging::ERROR, testing::_,
                  testing::HasSubstr("has only one device per replica group")))
      .Times(1);
  log.StartCapturingLogs();
  MeshAxesReplicaGroupList groups_over_none(mesh_2x3x5, /*axes=*/{});
  EXPECT_EQ(groups_over_none.num_replica_groups(), 2 * 3 * 5);
  EXPECT_EQ(groups_over_none.num_devices_per_group(), 1);
}
|
||||
|
||||
// Checks the group count and group size when the groups are defined over
// sub-axes instead of whole mesh axes.
TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSizeForSubaxes) {
  // One sub-axis at a time on a 2x6x10 mesh.
  Mesh mesh_2x6x10(TileAssignment(IotaTileAssignment::Create(
                       /*dims=*/{2, 6, 10})),
                   /*axes_names=*/{"axis1", "axis2", "axis3"});
  MeshAxesReplicaGroupList groups_axis1_sub(mesh_2x6x10,
                                            /*axes=*/{AxisRef(0, {1, 2})});
  EXPECT_EQ(groups_axis1_sub.num_replica_groups(), 60);
  EXPECT_EQ(groups_axis1_sub.num_devices_per_group(), 2);
  MeshAxesReplicaGroupList groups_axis2_sub(mesh_2x6x10,
                                            /*axes=*/{AxisRef(1, {2, 3})});
  EXPECT_EQ(groups_axis2_sub.num_replica_groups(), 40);
  EXPECT_EQ(groups_axis2_sub.num_devices_per_group(), 3);

  // Several sub-axes at once; every mesh axis size is a product of two primes.
  Mesh mesh_primes(TileAssignment(IotaTileAssignment::Create(
                       /*dims=*/{2 * 3, 5 * 7, 11 * 13})),
                   /*axes_names=*/{"alpha", "beta", "gamma"});
  MeshAxesReplicaGroupList groups_first_factors(
      mesh_primes,
      /*axes=*/{AxisRef(0, {1, 2}), AxisRef(1, {1, 5}), AxisRef(2, {1, 11})});
  EXPECT_EQ(groups_first_factors.num_replica_groups(), 3 * 7 * 13);
  EXPECT_EQ(groups_first_factors.num_devices_per_group(), 2 * 5 * 11);
  MeshAxesReplicaGroupList groups_second_factors(
      mesh_primes,
      /*axes=*/{AxisRef(0, {2, 3}), AxisRef(1, {5, 7}), AxisRef(2, {11, 13})});
  EXPECT_EQ(groups_second_factors.num_replica_groups(), 2 * 5 * 11);
  EXPECT_EQ(groups_second_factors.num_devices_per_group(), 3 * 7 * 13);
}
|
||||
|
||||
// Checks MeshAxesReplicaGroupList::ToString() for iota, reshaped/transposed
// iota, and explicit device assignments, with and without sub-axes.
TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) {
  // Iota device assignment, whole axes only.
  Mesh uvw_mesh(TileAssignment(IotaTileAssignment::Create(
                    /*dims=*/{10, 12, 15})),
                /*axes_names=*/{"u", "v", "w"});
  MeshAxesReplicaGroupList uvw_over_none(uvw_mesh, /*axes=*/{});
  EXPECT_EQ(uvw_over_none.ToString(), "@mesh<u=10,v=12,w=15> {}");
  MeshAxesReplicaGroupList uvw_over_uv(uvw_mesh,
                                       /*axes=*/{AxisRef(0), AxisRef(1)});
  EXPECT_EQ(uvw_over_uv.ToString(), "@mesh<u=10,v=12,w=15> {u,v}");

  // Reshaped and transposed iota device assignment, with a sub-axis.
  Mesh abcd_mesh(TileAssignment(IotaTileAssignment::Create(
                     /*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16},
                     /*transpose_perm=*/{2, 3, 0, 1})),
                 /*axes_names=*/{"a", "b", "c", "d"});
  MeshAxesReplicaGroupList abcd_over_none(abcd_mesh, /*axes=*/{});
  EXPECT_EQ(abcd_over_none.ToString(),
            "@mesh<a=2,b=4,c=4,d=2>([4,16]T(1,0)) {}");
  MeshAxesReplicaGroupList abcd_over_mixed(
      abcd_mesh, /*axes=*/{AxisRef(0), AxisRef(1, {1, 2}), AxisRef(3)});
  EXPECT_EQ(abcd_over_mixed.ToString(),
            "@mesh<a=2,b=4,c=4,d=2>([4,16]T(1,0)) {a,b:(1)2,d}");

  // Explicit (non-iota) device list, with a sub-axis.
  Array<int64_t> device_ids({{8, 3, 7, 5, 4, 2, 6, 0, 1, 9}});
  device_ids.Reshape({10});
  TileAssignment explicit_assignment(
      std::make_shared<Array<int64_t>>(device_ids));
  Mesh ooo_mesh(explicit_assignment, /*axes_names=*/{"ooo"});
  MeshAxesReplicaGroupList ooo_over_none(ooo_mesh, /*axes=*/{});
  EXPECT_EQ(ooo_over_none.ToString(),
            "@mesh<ooo=10>(8,3,7,5,4,2,6,0,1,9) {}");
  MeshAxesReplicaGroupList ooo_over_sub(ooo_mesh,
                                        /*axes=*/{AxisRef(0, {5, 2})});
  EXPECT_EQ(ooo_over_sub.ToString(),
            "@mesh<ooo=10>(8,3,7,5,4,2,6,0,1,9) {ooo:(5)2}");
}
|
||||
|
||||
TEST(CollectiveDeviceListTest, DefaultListToString) {
|
||||
EXPECT_EQ(CollectiveDeviceList().ToString(true), "{}");
|
||||
EXPECT_EQ(CollectiveDeviceList().ToString(false), "{}");
|
||||
|
|
|
|||
13
third_party/xla/xla/xla_data.proto
vendored
13
third_party/xla/xla/xla_data.proto
vendored
|
|
@ -1192,6 +1192,19 @@ message ReplicaGroup {
|
|||
repeated int64 replica_ids = 1;
|
||||
}
|
||||
|
||||
// Represents a list of replica groups (a list of list of devices) via a mesh
|
||||
// and list of axes. The replica groups correspond to the partitions of the
|
||||
// device ids which would arise if a collective operation was performed over the
|
||||
// specified axes.
|
||||
message MeshAxesReplicaGroupListProto {
|
||||
// The mesh used to define the full set of axes and device ids.
|
||||
MeshProto mesh = 1;
|
||||
// The axes defining the replica groups. These groups are implicitly defined
|
||||
// by the device ids which would communicate together if a collective
|
||||
// operation is performed over these axes.
|
||||
repeated AxisRefProto axes = 2;
|
||||
}
|
||||
|
||||
// Represents a list of replica groups (a list of list of devices) with
|
||||
// reshaping and transposing an iota array (iota tile assignment). Can be used
|
||||
// to represent certain common patterns of device lists in a compact, scalable
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user