mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Refactor mesh and axis representation.
PiperOrigin-RevId: 826647907
This commit is contained in:
parent
9c620f90b8
commit
bf84442f21
8
third_party/xla/xla/hlo/ir/mesh_and_axis.h
vendored
8
third_party/xla/xla/hlo/ir/mesh_and_axis.h
vendored
|
|
@ -75,6 +75,8 @@ class Mesh {
|
|||
axes_names_ == other.axes_names_;
|
||||
}
|
||||
|
||||
bool operator!=(const Mesh& other) const { return !(*this == other); }
|
||||
|
||||
std::string ToString() const {
|
||||
std::string mesh_str = "@mesh";
|
||||
// Add the mesh axes names and sizes.
|
||||
|
|
@ -97,8 +99,6 @@ class Mesh {
|
|||
return mesh_str;
|
||||
}
|
||||
|
||||
bool operator!=(const Mesh& other) const { return !(*this == other); }
|
||||
|
||||
bool DeviceAssignmentEquals(const Mesh& other) const {
|
||||
return device_assignment_ == other.device_assignment_;
|
||||
}
|
||||
|
|
@ -167,6 +167,8 @@ class AxisRef {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool operator!=(const xla::AxisRef& other) const { return !(*this == other); }
|
||||
|
||||
std::string ToString(const Mesh& mesh) const {
|
||||
CHECK_GE(mesh_axis_index_, 0);
|
||||
CHECK_LT(mesh_axis_index_, mesh.axis_names().size());
|
||||
|
|
@ -178,8 +180,6 @@ class AxisRef {
|
|||
return axis_str;
|
||||
}
|
||||
|
||||
bool operator!=(const xla::AxisRef& other) const { return !(*this == other); }
|
||||
|
||||
AxisRefProto ToProto() const;
|
||||
|
||||
static AxisRef FromProto(const AxisRefProto& proto);
|
||||
|
|
|
|||
16
third_party/xla/xla/hlo/ir/mesh_and_axis_test.cc
vendored
16
third_party/xla/xla/hlo/ir/mesh_and_axis_test.cc
vendored
|
|
@ -16,7 +16,6 @@ limitations under the License.
|
|||
#include "xla/hlo/ir/mesh_and_axis.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -123,9 +122,8 @@ TEST(MeshAndAxisTest, MeshToProtoIotaTilingWithReshapeDims) {
|
|||
|
||||
std::vector<std::string> axes_names = {"axis1", "axis2", "axis3"};
|
||||
EXPECT_THAT(
|
||||
Mesh(TileAssignment(IotaTileAssignment::Create(
|
||||
/*dims=*/{4, 4, 1},
|
||||
/*reshape_dims=*/{4, 2, 2}, /*transpose_perm=*/{1, 0, 2})),
|
||||
Mesh(TileAssignment(/*dims=*/{4, 4, 1}, /*reshape_dims=*/{4, 2, 2},
|
||||
/*transpose_perm=*/{1, 0, 2}),
|
||||
axes_names)
|
||||
.ToProto(),
|
||||
EqualsProto(expected));
|
||||
|
|
@ -180,15 +178,15 @@ TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) {
|
|||
Mesh mesh_uvw({10, 12, 15}, {"u", "v", "w"});
|
||||
EXPECT_EQ(mesh_uvw.ToString(), "@mesh<u=10,v=12,w=15>");
|
||||
|
||||
Mesh mesh_abcd(TileAssignment(IotaTileAssignment::Create(
|
||||
/*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16},
|
||||
/*transpose_perm=*/{2, 3, 0, 1})),
|
||||
/*axes_names=*/{"a", "b", "c", "d"});
|
||||
Mesh mesh_abcd(
|
||||
TileAssignment(/*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16},
|
||||
/*transpose_perm=*/{2, 3, 0, 1}),
|
||||
{"a", "b", "c", "d"});
|
||||
EXPECT_EQ(mesh_abcd.ToString(), "@mesh<a=2,b=4,c=4,d=2>([4,16]T(1,0))");
|
||||
|
||||
Array<int64_t> array({{8, 3, 7, 5, 4, 2, 6, 0, 1, 9}});
|
||||
array.Reshape({10});
|
||||
Mesh mesh_ooo(array, /*axes_names=*/{"ooo"});
|
||||
Mesh mesh_ooo(array, {"ooo"});
|
||||
EXPECT_EQ(mesh_ooo.ToString(), "@mesh<ooo=10>(8,3,7,5,4,2,6,0,1,9)");
|
||||
}
|
||||
|
||||
|
|
|
|||
4
third_party/xla/xla/hlo/ir/replica_group.h
vendored
4
third_party/xla/xla/hlo/ir/replica_group.h
vendored
|
|
@ -20,16 +20,12 @@ limitations under the License.
|
|||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "google/protobuf/repeated_ptr_field.h"
|
||||
#include "xla/array.h"
|
||||
|
|
|
|||
49
third_party/xla/xla/hlo/ir/replica_group_test.cc
vendored
49
third_party/xla/xla/hlo/ir/replica_group_test.cc
vendored
|
|
@ -43,7 +43,7 @@ CollectiveDeviceListProto CreateDeviceListProto(
|
|||
}
|
||||
|
||||
TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroups) {
|
||||
Mesh mesh_xy(TileAssignment({2, 2}), /*axes_names=*/{"x", "y"});
|
||||
Mesh mesh_xy({2, 2}, {"x", "y"});
|
||||
|
||||
MeshAxesReplicaGroupList replica_group_none(mesh_xy, {});
|
||||
std::vector<std::vector<int64_t>> expected_replica_groups_none = {
|
||||
|
|
@ -70,8 +70,7 @@ TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroups) {
|
|||
}
|
||||
|
||||
TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroupsWithSubaxes) {
|
||||
Mesh mesh(TileAssignment(IotaTileAssignment::Create(/*dims=*/{6, 6})),
|
||||
/*axes_names=*/{"a", "b"});
|
||||
Mesh mesh({6, 6}, {"a", "b"});
|
||||
|
||||
// a:(1)2
|
||||
MeshAxesReplicaGroupList replica_group_a_1_2(mesh, {AxisRef(0, {1, 2})});
|
||||
|
|
@ -172,8 +171,7 @@ TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroupsWithSubaxes) {
|
|||
}
|
||||
|
||||
TEST(MeshAxesReplicaGroupListTest, MaterializedReplicaGroupsMatchExpectedV2) {
|
||||
Mesh mesh(TileAssignment(IotaTileAssignment::Create(/*dims=*/{8})),
|
||||
/*axes_names=*/{"a"});
|
||||
Mesh mesh({8}, {"a"});
|
||||
|
||||
// a:(1)2 -> replica_groups=[4,2]<=[2,4]T(1,0)
|
||||
MeshAxesReplicaGroupList v3_subaxis_1_2(mesh, {AxisRef(0, {1, 2})});
|
||||
|
|
@ -231,7 +229,7 @@ TEST(MeshAxesReplicaGroupListTest,
|
|||
// Create a mesh with non-iota device ordering.
|
||||
Array2D<int64_t> array({{3, 1}, {0, 2}});
|
||||
TileAssignment tile_assignment(std::make_shared<Array<int64_t>>(array));
|
||||
Mesh mesh_xy(tile_assignment, /*axes_names=*/{"x", "y"});
|
||||
Mesh mesh_xy(tile_assignment, {"x", "y"});
|
||||
|
||||
// Reduce along x axis.
|
||||
MeshAxesReplicaGroupList replica_group_x(mesh_xy, {AxisRef(0)});
|
||||
|
|
@ -253,17 +251,13 @@ TEST(MeshAxesReplicaGroupListTest,
|
|||
}
|
||||
|
||||
TEST(MeshAxesReplicaGroupListTest, NumReplicaGroups) {
|
||||
Mesh all_axes(TileAssignment(IotaTileAssignment::Create(
|
||||
/*dims=*/{4, 4})),
|
||||
/*axes_names=*/{"x", "y"});
|
||||
Mesh all_axes({4, 4}, {"x", "y"});
|
||||
MeshAxesReplicaGroupList replica_group_across_all_axes(
|
||||
all_axes, {AxisRef(0), AxisRef(1)});
|
||||
EXPECT_EQ(replica_group_across_all_axes.num_replica_groups(), 1);
|
||||
EXPECT_EQ(replica_group_across_all_axes.num_devices_per_group(), 16);
|
||||
|
||||
Mesh one_axes(TileAssignment(IotaTileAssignment::Create(
|
||||
/*dims=*/{3, 5})),
|
||||
/*axes_names=*/{"a", "b"});
|
||||
Mesh one_axes({3, 5}, {"a", "b"});
|
||||
MeshAxesReplicaGroupList replica_group_across_a(one_axes, {AxisRef(0)});
|
||||
MeshAxesReplicaGroupList replica_group_across_b(one_axes, {AxisRef(1)});
|
||||
EXPECT_EQ(replica_group_across_a.num_replica_groups(), 5);
|
||||
|
|
@ -271,22 +265,20 @@ TEST(MeshAxesReplicaGroupListTest, NumReplicaGroups) {
|
|||
EXPECT_EQ(replica_group_across_b.num_replica_groups(), 3);
|
||||
EXPECT_EQ(replica_group_across_b.num_devices_per_group(), 5);
|
||||
|
||||
Mesh no_axes(TileAssignment(IotaTileAssignment::Create(
|
||||
/*dims=*/{2, 3, 5})),
|
||||
/*axes_names=*/{"p1", "p2", "p3"});
|
||||
Mesh no_axes({2, 3, 5}, {"p1", "p2", "p3"});
|
||||
MeshAxesReplicaGroupList replica_group_across_no_axes(no_axes, /*axes=*/{});
|
||||
EXPECT_EQ(replica_group_across_no_axes.num_replica_groups(), 2 * 3 * 5);
|
||||
EXPECT_EQ(replica_group_across_no_axes.num_devices_per_group(), 1);
|
||||
}
|
||||
|
||||
TEST(MeshAxesReplicaGroupListTest, ValidateSubAxesCoexistenceCheck) {
|
||||
Mesh mesh(TileAssignment({8}), /*axes_names=*/{"1"});
|
||||
Mesh mesh({8}, {"a"});
|
||||
MeshAxesReplicaGroupList replica_group_multiple_subaxes1(
|
||||
mesh, {AxisRef(0, {1, 2}), AxisRef(0, {4, 2})});
|
||||
MeshAxesReplicaGroupList replica_group_multiple_subaxes2(
|
||||
mesh, {AxisRef(0, {4, 2}), AxisRef(0, {1, 2})});
|
||||
|
||||
Mesh overlap_mesh(TileAssignment({2 * 3 * 5}), /*axes_names=*/{"u"});
|
||||
Mesh overlap_mesh({2 * 3 * 5}, {"u"});
|
||||
EXPECT_DEATH(
|
||||
{
|
||||
MeshAxesReplicaGroupList overlapping_subaxes(
|
||||
|
|
@ -296,9 +288,7 @@ TEST(MeshAxesReplicaGroupListTest, ValidateSubAxesCoexistenceCheck) {
|
|||
}
|
||||
|
||||
TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSizeForSubaxes) {
|
||||
Mesh mesh_one_subaxis(TileAssignment(IotaTileAssignment::Create(
|
||||
/*dims=*/{2, 6, 10})),
|
||||
/*axes_names=*/{"axis1", "axis2", "axis3"});
|
||||
Mesh mesh_one_subaxis({2, 6, 10}, {"axis1", "axis2", "axis3"});
|
||||
MeshAxesReplicaGroupList replica_group_across_axis1_subaxis(
|
||||
mesh_one_subaxis, {AxisRef(0, {1, 2})});
|
||||
MeshAxesReplicaGroupList replica_group_across_axis2_subaxis(
|
||||
|
|
@ -308,9 +298,8 @@ TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSizeForSubaxes) {
|
|||
EXPECT_EQ(replica_group_across_axis2_subaxis.num_replica_groups(), 40);
|
||||
EXPECT_EQ(replica_group_across_axis2_subaxis.num_devices_per_group(), 3);
|
||||
|
||||
Mesh mesh_multiple_subaxis(TileAssignment(IotaTileAssignment::Create(
|
||||
/*dims=*/{2 * 3, 5 * 7, 11 * 13})),
|
||||
/*axes_names=*/{"alpha", "beta", "gamma"});
|
||||
Mesh mesh_multiple_subaxis({2 * 3, 5 * 7, 11 * 13},
|
||||
{"alpha", "beta", "gamma"});
|
||||
MeshAxesReplicaGroupList replica_group_across_multiple_subaxis1(
|
||||
mesh_multiple_subaxis,
|
||||
{AxisRef(0, {1, 2}), AxisRef(1, {1, 5}), AxisRef(2, {1, 11})});
|
||||
|
|
@ -329,9 +318,7 @@ TEST(MeshAxesReplicaGroupListTest, ReplicaGroupsCountAndSizeForSubaxes) {
|
|||
|
||||
TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) {
|
||||
// No subaxes and iota device assignment.
|
||||
Mesh mesh_uvw(TileAssignment(IotaTileAssignment::Create(
|
||||
/*dims=*/{10, 12, 15})),
|
||||
/*axes_names=*/{"u", "v", "w"});
|
||||
Mesh mesh_uvw({10, 12, 15}, {"u", "v", "w"});
|
||||
MeshAxesReplicaGroupList replica_group_across_none(mesh_uvw, {});
|
||||
EXPECT_EQ(replica_group_across_none.ToString(), "@mesh<u=10,v=12,w=15> {}");
|
||||
MeshAxesReplicaGroupList replica_group_across_uv(mesh_uvw,
|
||||
|
|
@ -339,10 +326,10 @@ TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) {
|
|||
EXPECT_EQ(replica_group_across_uv.ToString(), "@mesh<u=10,v=12,w=15> {u,v}");
|
||||
|
||||
// Subaxes and replica group v2 iota style device assignment.
|
||||
Mesh mesh_abcd(TileAssignment(IotaTileAssignment::Create(
|
||||
/*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16},
|
||||
/*transpose_perm=*/{2, 3, 0, 1})),
|
||||
/*axes_names=*/{"a", "b", "c", "d"});
|
||||
Mesh mesh_abcd(
|
||||
TileAssignment(/*dims=*/{2, 4, 4, 2}, /*reshape_dims=*/{1, 4, 1, 16},
|
||||
/*transpose_perm=*/{2, 3, 0, 1}),
|
||||
{"a", "b", "c", "d"});
|
||||
MeshAxesReplicaGroupList rg_abcd_across_none(mesh_abcd, {});
|
||||
EXPECT_EQ(rg_abcd_across_none.ToString(),
|
||||
"@mesh<a=2,b=4,c=4,d=2>([4,16]T(1,0)) {}");
|
||||
|
|
@ -355,7 +342,7 @@ TEST(MeshAxesReplicaGroupListTest, MeshAxesToString) {
|
|||
Array<int64_t> array({{8, 3, 7, 5, 4, 2, 6, 0, 1, 9}});
|
||||
array.Reshape({10});
|
||||
TileAssignment tile_assignment(std::make_shared<Array<int64_t>>(array));
|
||||
Mesh mesh_ooo(tile_assignment, /*axes_names=*/{"ooo"});
|
||||
Mesh mesh_ooo(tile_assignment, {"ooo"});
|
||||
MeshAxesReplicaGroupList rg_ooo_across_none(mesh_ooo, {});
|
||||
EXPECT_EQ(rg_ooo_across_none.ToString(),
|
||||
"@mesh<ooo=10>(8,3,7,5,4,2,6,0,1,9) {}");
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user