diff --git a/third_party/xla/xla/hlo/ir/mesh_and_axis.h b/third_party/xla/xla/hlo/ir/mesh_and_axis.h index 5dc8c9651db..effba677d0f 100644 --- a/third_party/xla/xla/hlo/ir/mesh_and_axis.h +++ b/third_party/xla/xla/hlo/ir/mesh_and_axis.h @@ -75,6 +75,8 @@ class Mesh { axes_names_ == other.axes_names_; } + bool operator!=(const Mesh& other) const { return !(*this == other); } + std::string ToString() const { std::string mesh_str = "@mesh"; // Add the mesh axes names and sizes. @@ -97,8 +99,6 @@ class Mesh { return mesh_str; } - bool operator!=(const Mesh& other) const { return !(*this == other); } - bool DeviceAssignmentEquals(const Mesh& other) const { return device_assignment_ == other.device_assignment_; } @@ -167,6 +167,8 @@ class AxisRef { return true; } + bool operator!=(const xla::AxisRef& other) const { return !(*this == other); } + std::string ToString(const Mesh& mesh) const { CHECK_GE(mesh_axis_index_, 0); CHECK_LT(mesh_axis_index_, mesh.axis_names().size()); @@ -178,8 +180,6 @@ class AxisRef { return axis_str; } - bool operator!=(const xla::AxisRef& other) const { return !(*this == other); } - AxisRefProto ToProto() const; static AxisRef FromProto(const AxisRefProto& proto); diff --git a/third_party/xla/xla/hlo/ir/mesh_and_axis_test.cc b/third_party/xla/xla/hlo/ir/mesh_and_axis_test.cc index 8038d6f18da..b99e0bb8f9b 100644 --- a/third_party/xla/xla/hlo/ir/mesh_and_axis_test.cc +++ b/third_party/xla/xla/hlo/ir/mesh_and_axis_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/hlo/ir/mesh_and_axis.h" #include -#include #include #include @@ -123,9 +122,8 @@ TEST(MeshAndAxisTest, MeshToProtoIotaTilingWithReshapeDims) { std::vector axes_names = {"axis1", "axis2", "axis3"}; EXPECT_THAT( - Mesh(TileAssignment(IotaTileAssignment::Create( - /*dims=*/{4, 4, 1}, - /*reshape_dims=*/{4, 2, 2}, /*transpose_perm=*/{1, 0, 2})), + Mesh(TileAssignment(/*dims=*/{4, 4, 1}, /*reshape_dims=*/{4, 2, 2}, + /*transpose_perm=*/{1, 0, 2}), axes_names) .ToProto(), EqualsProto(expected)); @@ -180,15 +178,15 @@ TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) { Mesh mesh_uvw({10, 12, 15}, {"u", "v", "w"}); EXPECT_EQ(mesh_uvw.ToString(), "@mesh"); - Mesh mesh_abcd(TileAssignment(IotaTileAssignment::Create( - /*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16}, - /*transpose_perm=*/{2, 3, 0, 1})), - /*axes_names=*/{"a", "b", "c", "d"}); + Mesh mesh_abcd( + TileAssignment(/*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16}, + /*transpose_perm=*/{2, 3, 0, 1}), + {"a", "b", "c", "d"}); EXPECT_EQ(mesh_abcd.ToString(), "@mesh([4,16]T(1,0))"); Array array({{8, 3, 7, 5, 4, 2, 6, 0, 1, 9}}); array.Reshape({10}); - Mesh mesh_ooo(array, /*axes_names=*/{"ooo"}); + Mesh mesh_ooo(array, {"ooo"}); EXPECT_EQ(mesh_ooo.ToString(), "@mesh(8,3,7,5,4,2,6,0,1,9)"); } diff --git a/third_party/xla/xla/hlo/ir/replica_group.h b/third_party/xla/xla/hlo/ir/replica_group.h index 06497aa0ddf..d56902d10b5 100644 --- a/third_party/xla/xla/hlo/ir/replica_group.h +++ b/third_party/xla/xla/hlo/ir/replica_group.h @@ -20,16 +20,12 @@ limitations under the License. #include #include #include -#include #include #include #include #include #include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/log/check.h" -#include "absl/log/log.h" #include "absl/types/span.h" #include "google/protobuf/repeated_ptr_field.h" #include "xla/array.h" diff --git a/third_party/xla/xla/hlo/ir/replica_group_test.cc b/third_party/xla/xla/hlo/ir/replica_group_test.cc index 9ae886a8e92..2815984412c 100644 --- a/third_party/xla/xla/hlo/ir/replica_group_test.cc +++ b/third_party/xla/xla/hlo/ir/replica_group_test.cc @@ -43,7 +43,7 @@ CollectiveDeviceListProto CreateDeviceListProto( } TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroups) { - Mesh mesh_xy(TileAssignment({2, 2}), /*axes_names=*/{"x", "y"}); + Mesh mesh_xy({2, 2}, {"x", "y"}); MeshAxesReplicaGroupList replica_group_none(mesh_xy, {}); std::vector> expected_replica_groups_none = { @@ -70,8 +70,7 @@ TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroups) { } TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroupsWithSubaxes) { - Mesh mesh(TileAssignment(IotaTileAssignment::Create(/*dims=*/{6, 6})), - /*axes_names=*/{"a", "b"}); + Mesh mesh({6, 6}, {"a", "b"}); // a:(1)2 MeshAxesReplicaGroupList replica_group_a_1_2(mesh, {AxisRef(0, {1, 2})}); @@ -172,8 +171,7 @@ TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroupsWithSubaxes) { } TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroupsMatchExpectedV2) { - Mesh mesh(TileAssignment(IotaTileAssignment::Create(/*dims=*/{8})), - /*axes_names=*/{"a"}); + Mesh mesh({8}, {"a"}); // a:(1)2 -> replica_groups=[4,2]<=[2,4]T(1,0) MeshAxesReplicaGroupList v3_subaxis_1_2(mesh, {AxisRef(0, {1, 2})}); @@ -231,7 +229,7 @@ TEST(MeshAxesReplicaGroupListTest, // Create a mesh with non-iota device ordering. Array2D array({{3, 1}, {0, 2}}); TileAssignment tile_assignment(std::make_shared>(array)); - Mesh mesh_xy(tile_assignment, /*axes_names=*/{"x", "y"}); + Mesh mesh_xy(tile_assignment, {"x", "y"}); // Reduce along x axis. MeshAxesReplicaGroupList replica_group_x(mesh_xy, {AxisRef(0)}); @@ -253,17 +251,13 @@ TEST(MeshAxesReplicaGroupListTest, } TEST(MeshAxesReplicaGroupListTest, NumReplicaGroups) { - Mesh all_axes(TileAssignment(IotaTileAssignment::Create( - /*dims=*/{4, 4})), - /*axes_names=*/{"x", "y"}); + Mesh all_axes({4, 4}, {"x", "y"}); MeshAxesReplicaGroupList replica_group_across_all_axes( all_axes, {AxisRef(0), AxisRef(1)}); EXPECT_EQ(replica_group_across_all_axes.num_replica_groups(), 1); EXPECT_EQ(replica_group_across_all_axes.num_devices_per_group(), 16); - Mesh one_axes(TileAssignment(IotaTileAssignment::Create( - /*dims=*/{3, 5})), - /*axes_names=*/{"a", "b"}); + Mesh one_axes({3, 5}, {"a", "b"}); MeshAxesReplicaGroupList replica_group_across_a(one_axes, {AxisRef(0)}); MeshAxesReplicaGroupList replica_group_across_b(one_axes, {AxisRef(1)}); EXPECT_EQ(replica_group_across_a.num_replica_groups(), 5); @@ -271,22 +265,20 @@ TEST(MeshAxesReplicaGroupListTest, NumReplicaGroups) { EXPECT_EQ(replica_group_across_b.num_replica_groups(), 3); EXPECT_EQ(replica_group_across_b.num_devices_per_group(), 5); - Mesh no_axes(TileAssignment(IotaTileAssignment::Create( - /*dims=*/{2, 3, 5})), - /*axes_names=*/{"p1", "p2", "p3"}); + Mesh no_axes({2, 3, 5}, {"p1", "p2", "p3"}); MeshAxesReplicaGroupList replica_group_across_no_axes(no_axes, /*axes=*/{}); EXPECT_EQ(replica_group_across_no_axes.num_replica_groups(), 2 * 3 * 5); EXPECT_EQ(replica_group_across_no_axes.num_devices_per_group(), 1); } TEST(MeshAxesReplicaGroupListTest, ValidateSubAxesCoexistenceCheck) { - Mesh mesh(TileAssignment({8}), /*axes_names=*/{"1"}); + Mesh mesh({8}, {"a"}); MeshAxesReplicaGroupList replica_group_multiple_subaxes1( mesh, {AxisRef(0, {1, 2}), AxisRef(0, {4, 2})}); MeshAxesReplicaGroupList replica_group_multiple_subaxes2( mesh, {AxisRef(0, {4, 2}), AxisRef(0, {1, 2})}); - Mesh overlap_mesh(TileAssignment({2 * 3 * 5}), /*axes_names=*/{"u"}); + Mesh overlap_mesh({2 * 3 * 5}, {"u"}); EXPECT_DEATH( { MeshAxesReplicaGroupList overlapping_subaxes( @@ -296,9 +288,7 @@ TEST(MeshAxesReplicaGroupListTest, ValidateSubAxesCoexistenceCheck) { } TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSizeForSubaxes) { - Mesh mesh_one_subaxis(TileAssignment(IotaTileAssignment::Create( - /*dims=*/{2, 6, 10})), - /*axes_names=*/{"axis1", "axis2", "axis3"}); + Mesh mesh_one_subaxis({2, 6, 10}, {"axis1", "axis2", "axis3"}); MeshAxesReplicaGroupList replica_group_across_axis1_subaxis( mesh_one_subaxis, {AxisRef(0, {1, 2})}); MeshAxesReplicaGroupList replica_group_across_axis2_subaxis( @@ -308,9 +298,8 @@ TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSizeForSubaxes) { EXPECT_EQ(replica_group_across_axis2_subaxis.num_replica_groups(), 40); EXPECT_EQ(replica_group_across_axis2_subaxis.num_devices_per_group(), 3); - Mesh mesh_multiple_subaxis(TileAssignment(IotaTileAssignment::Create( - /*dims=*/{2 * 3, 5 * 7, 11 * 13})), - /*axes_names=*/{"alpha", "beta", "gamma"}); + Mesh mesh_multiple_subaxis({2 * 3, 5 * 7, 11 * 13}, + {"alpha", "beta", "gamma"}); MeshAxesReplicaGroupList replica_group_across_multiple_subaxis1( mesh_multiple_subaxis, {AxisRef(0, {1, 2}), AxisRef(1, {1, 5}), AxisRef(2, {1, 11})}); @@ -329,9 +318,7 @@ TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSizeForSubaxes) { TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) { // No subaxes and iota device assignment. - Mesh mesh_uvw(TileAssignment(IotaTileAssignment::Create( - /*dims=*/{10, 12, 15})), - /*axes_names=*/{"u", "v", "w"}); + Mesh mesh_uvw({10, 12, 15}, {"u", "v", "w"}); MeshAxesReplicaGroupList replica_group_across_none(mesh_uvw, {}); EXPECT_EQ(replica_group_across_none.ToString(), "@mesh {}"); MeshAxesReplicaGroupList replica_group_across_uv(mesh_uvw, @@ -339,10 +326,10 @@ TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) { EXPECT_EQ(replica_group_across_uv.ToString(), "@mesh {u,v}"); // Subaxes and replica group v2 iota style device assignment. - Mesh mesh_abcd(TileAssignment(IotaTileAssignment::Create( - /*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16}, - /*transpose_perm=*/{2, 3, 0, 1})), - /*axes_names=*/{"a", "b", "c", "d"}); + Mesh mesh_abcd( + TileAssignment(/*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16}, + /*transpose_perm=*/{2, 3, 0, 1}), + {"a", "b", "c", "d"}); MeshAxesReplicaGroupList rg_abcd_across_none(mesh_abcd, {}); EXPECT_EQ(rg_abcd_across_none.ToString(), "@mesh([4,16]T(1,0)) {}"); @@ -355,7 +342,7 @@ TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) { Array array({{8, 3, 7, 5, 4, 2, 6, 0, 1, 9}}); array.Reshape({10}); TileAssignment tile_assignment(std::make_shared>(array)); - Mesh mesh_ooo(tile_assignment, /*axes_names=*/{"ooo"}); + Mesh mesh_ooo(tile_assignment, {"ooo"}); MeshAxesReplicaGroupList rg_ooo_across_none(mesh_ooo, {}); EXPECT_EQ(rg_ooo_across_none.ToString(), "@mesh(8,3,7,5,4,2,6,0,1,9) {}");