Refactor mesh and axis representation.

PiperOrigin-RevId: 826647907
This commit is contained in:
Zixuan Jiang 2025-10-31 15:18:53 -07:00 committed by TensorFlower Gardener
parent 9c620f90b8
commit bf84442f21
4 changed files with 29 additions and 48 deletions

View File

@ -75,6 +75,8 @@ class Mesh {
axes_names_ == other.axes_names_;
}
bool operator!=(const Mesh& other) const { return !(*this == other); }
std::string ToString() const {
std::string mesh_str = "@mesh";
// Add the mesh axes names and sizes.
@ -97,8 +99,6 @@ class Mesh {
return mesh_str;
}
bool operator!=(const Mesh& other) const { return !(*this == other); }
bool DeviceAssignmentEquals(const Mesh& other) const {
return device_assignment_ == other.device_assignment_;
}
@ -167,6 +167,8 @@ class AxisRef {
return true;
}
bool operator!=(const xla::AxisRef& other) const { return !(*this == other); }
std::string ToString(const Mesh& mesh) const {
CHECK_GE(mesh_axis_index_, 0);
CHECK_LT(mesh_axis_index_, mesh.axis_names().size());
@ -178,8 +180,6 @@ class AxisRef {
return axis_str;
}
bool operator!=(const xla::AxisRef& other) const { return !(*this == other); }
AxisRefProto ToProto() const;
static AxisRef FromProto(const AxisRefProto& proto);

View File

@ -16,7 +16,6 @@ limitations under the License.
#include "xla/hlo/ir/mesh_and_axis.h"
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
@ -123,9 +122,8 @@ TEST(MeshAndAxisTest, MeshToProtoIotaTilingWithReshapeDims) {
std::vector<std::string> axes_names = {"axis1", "axis2", "axis3"};
EXPECT_THAT(
Mesh(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{4, 4, 1},
/*reshape_dims=*/{4, 2, 2}, /*transpose_perm=*/{1, 0, 2})),
Mesh(TileAssignment(/*dims=*/{4, 4, 1}, /*reshape_dims=*/{4, 2, 2},
/*transpose_perm=*/{1, 0, 2}),
axes_names)
.ToProto(),
EqualsProto(expected));
@ -180,15 +178,15 @@ TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) {
Mesh mesh_uvw({10, 12, 15}, {"u", "v", "w"});
EXPECT_EQ(mesh_uvw.ToString(), "@mesh<u=10,v=12,w=15>");
Mesh mesh_abcd(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16},
/*transpose_perm=*/{2, 3, 0, 1})),
/*axes_names=*/{"a", "b", "c", "d"});
Mesh mesh_abcd(
TileAssignment(/*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16},
/*transpose_perm=*/{2, 3, 0, 1}),
{"a", "b", "c", "d"});
EXPECT_EQ(mesh_abcd.ToString(), "@mesh<a=2,b=4,c=4,d=2>([4,16]T(1,0))");
Array<int64_t> array({{8, 3, 7, 5, 4, 2, 6, 0, 1, 9}});
array.Reshape({10});
Mesh mesh_ooo(array, /*axes_names=*/{"ooo"});
Mesh mesh_ooo(array, {"ooo"});
EXPECT_EQ(mesh_ooo.ToString(), "@mesh<ooo=10>(8,3,7,5,4,2,6,0,1,9)");
}

View File

@ -20,16 +20,12 @@ limitations under the License.
#include <cstddef>
#include <cstdint>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/types/span.h"
#include "google/protobuf/repeated_ptr_field.h"
#include "xla/array.h"

View File

@ -43,7 +43,7 @@ CollectiveDeviceListProto CreateDeviceListProto(
}
TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroups) {
Mesh mesh_xy(TileAssignment({2, 2}), /*axes_names=*/{"x", "y"});
Mesh mesh_xy({2, 2}, {"x", "y"});
MeshAxesReplicaGroupList replica_group_none(mesh_xy, {});
std::vector<std::vector<int64_t>> expected_replica_groups_none = {
@ -70,8 +70,7 @@ TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroups) {
}
TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroupsWithSubaxes) {
Mesh mesh(TileAssignment(IotaTileAssignment::Create(/*dims=*/{6, 6})),
/*axes_names=*/{"a", "b"});
Mesh mesh({6, 6}, {"a", "b"});
// a:(1)2
MeshAxesReplicaGroupList replica_group_a_1_2(mesh, {AxisRef(0, {1, 2})});
@ -172,8 +171,7 @@ TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroupsWithSubaxes) {
}
TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroupsMatchExpectedV2) {
Mesh mesh(TileAssignment(IotaTileAssignment::Create(/*dims=*/{8})),
/*axes_names=*/{"a"});
Mesh mesh({8}, {"a"});
// a:(1)2 -> replica_groups=[4,2]<=[2,4]T(1,0)
MeshAxesReplicaGroupList v3_subaxis_1_2(mesh, {AxisRef(0, {1, 2})});
@ -231,7 +229,7 @@ TEST(MeshAxesReplicaGroupListTest,
// Create a mesh with non-iota device ordering.
Array2D<int64_t> array({{3, 1}, {0, 2}});
TileAssignment tile_assignment(std::make_shared<Array<int64_t>>(array));
Mesh mesh_xy(tile_assignment, /*axes_names=*/{"x", "y"});
Mesh mesh_xy(tile_assignment, {"x", "y"});
// Reduce along x axis.
MeshAxesReplicaGroupList replica_group_x(mesh_xy, {AxisRef(0)});
@ -253,17 +251,13 @@ TEST(MeshAxesReplicaGroupListTest,
}
TEST(MeshAxesReplicaGroupListTest, NumReplicaGroups) {
Mesh all_axes(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{4, 4})),
/*axes_names=*/{"x", "y"});
Mesh all_axes({4, 4}, {"x", "y"});
MeshAxesReplicaGroupList replica_group_across_all_axes(
all_axes, {AxisRef(0), AxisRef(1)});
EXPECT_EQ(replica_group_across_all_axes.num_replica_groups(), 1);
EXPECT_EQ(replica_group_across_all_axes.num_devices_per_group(), 16);
Mesh one_axes(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{3, 5})),
/*axes_names=*/{"a", "b"});
Mesh one_axes({3, 5}, {"a", "b"});
MeshAxesReplicaGroupList replica_group_across_a(one_axes, {AxisRef(0)});
MeshAxesReplicaGroupList replica_group_across_b(one_axes, {AxisRef(1)});
EXPECT_EQ(replica_group_across_a.num_replica_groups(), 5);
@ -271,22 +265,20 @@ TEST(MeshAxesReplicaGroupListTest, NumReplicaGroups) {
EXPECT_EQ(replica_group_across_b.num_replica_groups(), 3);
EXPECT_EQ(replica_group_across_b.num_devices_per_group(), 5);
Mesh no_axes(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{2, 3, 5})),
/*axes_names=*/{"p1", "p2", "p3"});
Mesh no_axes({2, 3, 5}, {"p1", "p2", "p3"});
MeshAxesReplicaGroupList replica_group_across_no_axes(no_axes, /*axes=*/{});
EXPECT_EQ(replica_group_across_no_axes.num_replica_groups(), 2 * 3 * 5);
EXPECT_EQ(replica_group_across_no_axes.num_devices_per_group(), 1);
}
TEST(MeshAxesReplicaGroupListTest, ValidateSubAxesCoexistenceCheck) {
Mesh mesh(TileAssignment({8}), /*axes_names=*/{"1"});
Mesh mesh({8}, {"a"});
MeshAxesReplicaGroupList replica_group_multiple_subaxes1(
mesh, {AxisRef(0, {1, 2}), AxisRef(0, {4, 2})});
MeshAxesReplicaGroupList replica_group_multiple_subaxes2(
mesh, {AxisRef(0, {4, 2}), AxisRef(0, {1, 2})});
Mesh overlap_mesh(TileAssignment({2 * 3 * 5}), /*axes_names=*/{"u"});
Mesh overlap_mesh({2 * 3 * 5}, {"u"});
EXPECT_DEATH(
{
MeshAxesReplicaGroupList overlapping_subaxes(
@ -296,9 +288,7 @@ TEST(MeshAxesReplicaGroupListTest, ValidateSubAxesCoexistenceCheck) {
}
TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSizeForSubaxes) {
Mesh mesh_one_subaxis(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{2, 6, 10})),
/*axes_names=*/{"axis1", "axis2", "axis3"});
Mesh mesh_one_subaxis({2, 6, 10}, {"axis1", "axis2", "axis3"});
MeshAxesReplicaGroupList replica_group_across_axis1_subaxis(
mesh_one_subaxis, {AxisRef(0, {1, 2})});
MeshAxesReplicaGroupList replica_group_across_axis2_subaxis(
@ -308,9 +298,8 @@ TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSizeForSubaxes) {
EXPECT_EQ(replica_group_across_axis2_subaxis.num_replica_groups(), 40);
EXPECT_EQ(replica_group_across_axis2_subaxis.num_devices_per_group(), 3);
Mesh mesh_multiple_subaxis(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{2 * 3, 5 * 7, 11 * 13})),
/*axes_names=*/{"alpha", "beta", "gamma"});
Mesh mesh_multiple_subaxis({2 * 3, 5 * 7, 11 * 13},
{"alpha", "beta", "gamma"});
MeshAxesReplicaGroupList replica_group_across_multiple_subaxis1(
mesh_multiple_subaxis,
{AxisRef(0, {1, 2}), AxisRef(1, {1, 5}), AxisRef(2, {1, 11})});
@ -329,9 +318,7 @@ TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSizeForSubaxes) {
TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) {
// No subaxes and iota device assignment.
Mesh mesh_uvw(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{10, 12, 15})),
/*axes_names=*/{"u", "v", "w"});
Mesh mesh_uvw({10, 12, 15}, {"u", "v", "w"});
MeshAxesReplicaGroupList replica_group_across_none(mesh_uvw, {});
EXPECT_EQ(replica_group_across_none.ToString(), "@mesh<u=10,v=12,w=15> {}");
MeshAxesReplicaGroupList replica_group_across_uv(mesh_uvw,
@ -339,10 +326,10 @@ TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) {
EXPECT_EQ(replica_group_across_uv.ToString(), "@mesh<u=10,v=12,w=15> {u,v}");
// Subaxes and replica group v2 iota style device assignment.
Mesh mesh_abcd(TileAssignment(IotaTileAssignment::Create(
/*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16},
/*transpose_perm=*/{2, 3, 0, 1})),
/*axes_names=*/{"a", "b", "c", "d"});
Mesh mesh_abcd(
TileAssignment(/*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16},
/*transpose_perm=*/{2, 3, 0, 1}),
{"a", "b", "c", "d"});
MeshAxesReplicaGroupList rg_abcd_across_none(mesh_abcd, {});
EXPECT_EQ(rg_abcd_across_none.ToString(),
"@mesh<a=2,b=4,c=4,d=2>([4,16]T(1,0)) {}");
@ -355,7 +342,7 @@ TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) {
Array<int64_t> array({{8, 3, 7, 5, 4, 2, 6, 0, 1, 9}});
array.Reshape({10});
TileAssignment tile_assignment(std::make_shared<Array<int64_t>>(array));
Mesh mesh_ooo(tile_assignment, /*axes_names=*/{"ooo"});
Mesh mesh_ooo(tile_assignment, {"ooo"});
MeshAxesReplicaGroupList rg_ooo_across_none(mesh_ooo, {});
EXPECT_EQ(rg_ooo_across_none.ToString(),
"@mesh<ooo=10>(8,3,7,5,4,2,6,0,1,9) {}");