Add PjRtDeviceDimensions struct and proto.

PiperOrigin-RevId: 820440467
This commit is contained in:
Haibo Huang 2025-10-16 16:30:45 -07:00 committed by TensorFlower Gardener
parent c986bf166c
commit c3ce8a9881
6 changed files with 348 additions and 0 deletions

View File

@ -369,6 +369,34 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "pjrt_device_dimensions",
srcs = ["pjrt_device_dimensions.cc"],
hdrs = ["pjrt_device_dimensions.h"],
visibility = internal_visibility(["//xla:friends"]),
deps = [
"//xla/pjrt/proto:pjrt_device_dimensions_proto_cc",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
],
)
xla_cc_test(
name = "pjrt_device_dimensions_test",
srcs = ["pjrt_device_dimensions_test.cc"],
deps = [
":pjrt_device_dimensions",
"//xla/pjrt/proto:pjrt_device_dimensions_proto_cc",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "pjrt_executable",
srcs = ["pjrt_executable.cc"],

View File

@ -0,0 +1,83 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/pjrt/pjrt_device_dimensions.h"
#include <cstdint>
#include <string>
#include <vector>
#include "absl/container/inlined_vector.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
namespace xla {
PjRtDeviceDimensionsProto PjRtDeviceDimensions::ToProto() const {
PjRtDeviceDimensionsProto proto;
for (int32_t dim : dimensions_) {
proto.add_dimensions(dim);
}
return proto;
}
std::string PjRtDeviceDimensions::ToString(absl::string_view sep) const {
return absl::StrJoin(dimensions_, sep);
}
absl::StatusOr<PjRtDeviceDimensions> PjRtDeviceDimensions::FromString(
absl::string_view text) {
if (text.empty()) {
return PjRtDeviceDimensions({});
}
std::vector<std::string> bounds_str = absl::StrSplit(text, ',');
absl::InlinedVector<int32_t, 3> dims;
for (auto const& b : bounds_str) {
int32_t bound;
if (!absl::SimpleAtoi(b, &bound)) {
return absl::InvalidArgumentError(
absl::StrFormat("Number parsing error for pjrt device dimensions %s "
"while parsing %s.",
text, b));
}
dims.push_back(bound);
}
return PjRtDeviceDimensions(dims);
}
bool AbslParseFlag(absl::string_view text, PjRtDeviceDimensions* bounds,
std::string* err) {
const auto status_or_dimensions = PjRtDeviceDimensions::FromString(text);
if (!status_or_dimensions.ok()) {
*err = status_or_dimensions.status().ToString();
return false;
}
*bounds = status_or_dimensions.value();
return true;
}
std::string AbslUnparseFlag(PjRtDeviceDimensions bounds) {
return bounds.ToString();
}
} // namespace xla

View File

@ -0,0 +1,91 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_PJRT_PJRT_DEVICE_DIMENSIONS_H_
#define XLA_PJRT_PJRT_DEVICE_DIMENSIONS_H_
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <ostream>
#include <string>
#include <utility>
#include "absl/container/inlined_vector.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/pjrt/proto/pjrt_device_dimensions.pb.h"
namespace xla {
// Represents device dimensions (e.g., mesh bounds or chip coordinates).
class PjRtDeviceDimensions {
public:
PjRtDeviceDimensions() = default;
PjRtDeviceDimensions(std::initializer_list<int32_t> dims)
: dimensions_(dims) {}
explicit PjRtDeviceDimensions(absl::Span<const int32_t> dims)
: dimensions_(dims.begin(), dims.end()) {}
int32_t& operator[](size_t i) { return dimensions_[i]; }
const int32_t& operator[](size_t i) const { return dimensions_[i]; }
size_t size() const { return dimensions_.size(); }
friend bool operator==(const PjRtDeviceDimensions& a,
const PjRtDeviceDimensions& b) {
return a.dimensions_ == b.dimensions_;
}
friend bool operator!=(const PjRtDeviceDimensions& a,
const PjRtDeviceDimensions& b) {
return !(a == b);
}
friend std::ostream& operator<<(std::ostream& os,
const PjRtDeviceDimensions& d) {
return os << d.ToString();
}
template <typename H>
friend H AbslHashValue(H h, const PjRtDeviceDimensions& c) {
return H::combine(std::move(h), c.dimensions_);
}
static absl::StatusOr<PjRtDeviceDimensions> FromProto(
const PjRtDeviceDimensionsProto& proto) {
return PjRtDeviceDimensions(proto.dimensions());
}
PjRtDeviceDimensionsProto ToProto() const;
std::string ToString(absl::string_view sep = ",") const;
static absl::StatusOr<PjRtDeviceDimensions> FromString(
absl::string_view text);
private:
absl::InlinedVector<int32_t, 3> dimensions_;
};
// Support for absl flags.
bool AbslParseFlag(absl::string_view text, PjRtDeviceDimensions* bounds,
std::string* err);
std::string AbslUnparseFlag(PjRtDeviceDimensions bounds);
} // namespace xla
#endif // XLA_PJRT_PJRT_DEVICE_DIMENSIONS_H_

View File

@ -0,0 +1,129 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/pjrt/pjrt_device_dimensions.h"
#include <sstream>
#include <string>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/container/flat_hash_set.h"
#include "xla/pjrt/proto/pjrt_device_dimensions.pb.h"
#include "xla/tsl/platform/statusor.h"
namespace xla {
namespace {
using ::testing::HasSubstr;
TEST(PjRtDeviceDimensionsTest, Equality) {
EXPECT_EQ((PjRtDeviceDimensions{1, 2, 3}), (PjRtDeviceDimensions{1, 2, 3}));
EXPECT_NE((PjRtDeviceDimensions{1, 2, 3}), (PjRtDeviceDimensions{1, 2, 4}));
}
TEST(PjRtDeviceDimensionsTest, Ostream) {
std::stringstream ss;
ss << PjRtDeviceDimensions{1, 2, 3};
EXPECT_EQ(ss.str(), "1,2,3");
}
TEST(PjRtDeviceDimensionsTest, AbslHashValue) {
absl::flat_hash_set<PjRtDeviceDimensions> hash_set;
hash_set.insert({1, 2, 3});
hash_set.insert({0, 0, 0});
hash_set.insert({1, 2, 3}); // Inserting again should not change size
EXPECT_EQ(hash_set.size(), 2);
EXPECT_TRUE(hash_set.contains({1, 2, 3}));
EXPECT_TRUE(hash_set.contains({0, 0, 0}));
EXPECT_FALSE(hash_set.contains({1, 2, 4}));
}
TEST(PjRtDeviceDimensionsTest, FromProto) {
PjRtDeviceDimensionsProto proto;
proto.add_dimensions(1);
proto.add_dimensions(2);
proto.add_dimensions(3);
TF_ASSERT_OK_AND_ASSIGN(PjRtDeviceDimensions dims,
PjRtDeviceDimensions::FromProto(proto));
EXPECT_EQ(dims, PjRtDeviceDimensions({1, 2, 3}));
}
TEST(PjRtDeviceDimensionsTest, ToProto) {
PjRtDeviceDimensions bounds = {1, 2, 3};
PjRtDeviceDimensionsProto proto = bounds.ToProto();
EXPECT_THAT(proto.dimensions(), testing::ElementsAre(1, 2, 3));
}
TEST(AbslParseFlagTest, ValidInputs) {
PjRtDeviceDimensions bounds;
std::string err;
EXPECT_TRUE(AbslParseFlag("1,2,3", &bounds, &err));
EXPECT_EQ(bounds, (PjRtDeviceDimensions{1, 2, 3}));
EXPECT_EQ(err, "");
EXPECT_TRUE(AbslParseFlag("1,2", &bounds, &err));
EXPECT_EQ(bounds, (PjRtDeviceDimensions{1, 2}));
EXPECT_EQ(err, "");
EXPECT_TRUE(AbslParseFlag("1,2,3,4", &bounds, &err));
EXPECT_EQ(bounds, (PjRtDeviceDimensions{1, 2, 3, 4}));
EXPECT_EQ(err, "");
EXPECT_TRUE(AbslParseFlag("", &bounds, &err));
EXPECT_EQ(bounds, (PjRtDeviceDimensions{}));
EXPECT_EQ(err, "");
}
TEST(AbslParseFlagTest, InvalidInputs) {
PjRtDeviceDimensions bounds;
std::string err;
EXPECT_FALSE(AbslParseFlag("1,a,3", &bounds, &err));
EXPECT_THAT(err, HasSubstr("Number parsing error"));
EXPECT_FALSE(AbslParseFlag("1,2.5,3", &bounds, &err));
EXPECT_THAT(err, HasSubstr("Number parsing error"));
}
TEST(AbslUnparseFlagTest, ConvertsCorrectly) {
EXPECT_EQ(AbslUnparseFlag(PjRtDeviceDimensions{1, 2, 3}), "1,2,3");
EXPECT_EQ(AbslUnparseFlag(PjRtDeviceDimensions{0, 0, 0}), "0,0,0");
}
TEST(PjRtDeviceDimensionsTest, SubscriptAccess) {
PjRtDeviceDimensions dims = {10, 20, 30};
EXPECT_EQ(dims[0], 10);
EXPECT_EQ(dims[1], 20);
EXPECT_EQ(dims[2], 30);
dims[1] = 25;
EXPECT_EQ(dims[1], 25);
EXPECT_EQ(dims, (PjRtDeviceDimensions{10, 25, 30}));
const PjRtDeviceDimensions const_dims = {1, 2, 3};
EXPECT_EQ(const_dims[0], 1);
}
TEST(PjRtDeviceDimensionsTest, Size) {
EXPECT_EQ((PjRtDeviceDimensions{1, 2, 3}).size(), 3);
EXPECT_EQ((PjRtDeviceDimensions{1, 2}).size(), 2);
EXPECT_EQ((PjRtDeviceDimensions{}).size(), 0);
}
} // namespace
} // namespace xla

View File

@ -41,6 +41,13 @@ tf_proto_library(
visibility = ["//visibility:public"],
)
tf_proto_library(
name = "pjrt_device_dimensions_proto",
srcs = ["pjrt_device_dimensions.proto"],
compatible_with = (get_compatible_with_libtpu_portable() + get_compatible_with_portable()),
visibility = ["//visibility:public"],
)
tf_proto_library(
name = "pjrt_value_type_proto",
srcs = ["pjrt_value_type.proto"],

View File

@ -0,0 +1,10 @@
edition = "2023";
package xla;
option features.field_presence = IMPLICIT;
option java_multiple_files = true;
message PjRtDeviceDimensionsProto {
repeated int32 dimensions = 1;
}