mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Add PjRtDeviceDimensions struct and proto.
PiperOrigin-RevId: 820440467
This commit is contained in:
parent
c986bf166c
commit
c3ce8a9881
28
third_party/xla/xla/pjrt/BUILD
vendored
28
third_party/xla/xla/pjrt/BUILD
vendored
|
|
@ -369,6 +369,34 @@ cc_library(
|
|||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "pjrt_device_dimensions",
|
||||
srcs = ["pjrt_device_dimensions.cc"],
|
||||
hdrs = ["pjrt_device_dimensions.h"],
|
||||
visibility = internal_visibility(["//xla:friends"]),
|
||||
deps = [
|
||||
"//xla/pjrt/proto:pjrt_device_dimensions_proto_cc",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
xla_cc_test(
|
||||
name = "pjrt_device_dimensions_test",
|
||||
srcs = ["pjrt_device_dimensions_test.cc"],
|
||||
deps = [
|
||||
":pjrt_device_dimensions",
|
||||
"//xla/pjrt/proto:pjrt_device_dimensions_proto_cc",
|
||||
"//xla/tsl/platform:statusor",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "pjrt_executable",
|
||||
srcs = ["pjrt_executable.cc"],
|
||||
|
|
|
|||
83
third_party/xla/xla/pjrt/pjrt_device_dimensions.cc
vendored
Normal file
83
third_party/xla/xla/pjrt/pjrt_device_dimensions.cc
vendored
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
/* Copyright 2025 The OpenXLA Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "xla/pjrt/pjrt_device_dimensions.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
PjRtDeviceDimensionsProto PjRtDeviceDimensions::ToProto() const {
|
||||
PjRtDeviceDimensionsProto proto;
|
||||
for (int32_t dim : dimensions_) {
|
||||
proto.add_dimensions(dim);
|
||||
}
|
||||
return proto;
|
||||
}
|
||||
|
||||
std::string PjRtDeviceDimensions::ToString(absl::string_view sep) const {
|
||||
return absl::StrJoin(dimensions_, sep);
|
||||
}
|
||||
|
||||
absl::StatusOr<PjRtDeviceDimensions> PjRtDeviceDimensions::FromString(
|
||||
absl::string_view text) {
|
||||
if (text.empty()) {
|
||||
return PjRtDeviceDimensions({});
|
||||
}
|
||||
std::vector<std::string> bounds_str = absl::StrSplit(text, ',');
|
||||
|
||||
absl::InlinedVector<int32_t, 3> dims;
|
||||
for (auto const& b : bounds_str) {
|
||||
int32_t bound;
|
||||
if (!absl::SimpleAtoi(b, &bound)) {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrFormat("Number parsing error for pjrt device dimensions %s "
|
||||
"while parsing %s.",
|
||||
text, b));
|
||||
}
|
||||
dims.push_back(bound);
|
||||
}
|
||||
|
||||
return PjRtDeviceDimensions(dims);
|
||||
}
|
||||
|
||||
bool AbslParseFlag(absl::string_view text, PjRtDeviceDimensions* bounds,
|
||||
std::string* err) {
|
||||
const auto status_or_dimensions = PjRtDeviceDimensions::FromString(text);
|
||||
if (!status_or_dimensions.ok()) {
|
||||
*err = status_or_dimensions.status().ToString();
|
||||
return false;
|
||||
}
|
||||
*bounds = status_or_dimensions.value();
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string AbslUnparseFlag(PjRtDeviceDimensions bounds) {
|
||||
return bounds.ToString();
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
91
third_party/xla/xla/pjrt/pjrt_device_dimensions.h
vendored
Normal file
91
third_party/xla/xla/pjrt/pjrt_device_dimensions.h
vendored
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
/* Copyright 2025 The OpenXLA Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef XLA_PJRT_PJRT_DEVICE_DIMENSIONS_H_
|
||||
#define XLA_PJRT_PJRT_DEVICE_DIMENSIONS_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <initializer_list>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "xla/pjrt/proto/pjrt_device_dimensions.pb.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Represents device dimensions (e.g., mesh bounds or chip coordinates).
|
||||
class PjRtDeviceDimensions {
|
||||
public:
|
||||
PjRtDeviceDimensions() = default;
|
||||
PjRtDeviceDimensions(std::initializer_list<int32_t> dims)
|
||||
: dimensions_(dims) {}
|
||||
explicit PjRtDeviceDimensions(absl::Span<const int32_t> dims)
|
||||
: dimensions_(dims.begin(), dims.end()) {}
|
||||
|
||||
int32_t& operator[](size_t i) { return dimensions_[i]; }
|
||||
const int32_t& operator[](size_t i) const { return dimensions_[i]; }
|
||||
|
||||
size_t size() const { return dimensions_.size(); }
|
||||
|
||||
friend bool operator==(const PjRtDeviceDimensions& a,
|
||||
const PjRtDeviceDimensions& b) {
|
||||
return a.dimensions_ == b.dimensions_;
|
||||
}
|
||||
|
||||
friend bool operator!=(const PjRtDeviceDimensions& a,
|
||||
const PjRtDeviceDimensions& b) {
|
||||
return !(a == b);
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os,
|
||||
const PjRtDeviceDimensions& d) {
|
||||
return os << d.ToString();
|
||||
}
|
||||
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, const PjRtDeviceDimensions& c) {
|
||||
return H::combine(std::move(h), c.dimensions_);
|
||||
}
|
||||
|
||||
static absl::StatusOr<PjRtDeviceDimensions> FromProto(
|
||||
const PjRtDeviceDimensionsProto& proto) {
|
||||
return PjRtDeviceDimensions(proto.dimensions());
|
||||
}
|
||||
|
||||
PjRtDeviceDimensionsProto ToProto() const;
|
||||
|
||||
std::string ToString(absl::string_view sep = ",") const;
|
||||
|
||||
static absl::StatusOr<PjRtDeviceDimensions> FromString(
|
||||
absl::string_view text);
|
||||
|
||||
private:
|
||||
absl::InlinedVector<int32_t, 3> dimensions_;
|
||||
};
|
||||
|
||||
// Support for absl flags.
|
||||
bool AbslParseFlag(absl::string_view text, PjRtDeviceDimensions* bounds,
|
||||
std::string* err);
|
||||
std::string AbslUnparseFlag(PjRtDeviceDimensions bounds);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // XLA_PJRT_PJRT_DEVICE_DIMENSIONS_H_
|
||||
129
third_party/xla/xla/pjrt/pjrt_device_dimensions_test.cc
vendored
Normal file
129
third_party/xla/xla/pjrt/pjrt_device_dimensions_test.cc
vendored
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
/* Copyright 2025 The OpenXLA Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "xla/pjrt/pjrt_device_dimensions.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "xla/pjrt/proto/pjrt_device_dimensions.pb.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
TEST(PjRtDeviceDimensionsTest, Equality) {
|
||||
EXPECT_EQ((PjRtDeviceDimensions{1, 2, 3}), (PjRtDeviceDimensions{1, 2, 3}));
|
||||
EXPECT_NE((PjRtDeviceDimensions{1, 2, 3}), (PjRtDeviceDimensions{1, 2, 4}));
|
||||
}
|
||||
|
||||
TEST(PjRtDeviceDimensionsTest, Ostream) {
|
||||
std::stringstream ss;
|
||||
ss << PjRtDeviceDimensions{1, 2, 3};
|
||||
EXPECT_EQ(ss.str(), "1,2,3");
|
||||
}
|
||||
|
||||
TEST(PjRtDeviceDimensionsTest, AbslHashValue) {
|
||||
absl::flat_hash_set<PjRtDeviceDimensions> hash_set;
|
||||
hash_set.insert({1, 2, 3});
|
||||
hash_set.insert({0, 0, 0});
|
||||
hash_set.insert({1, 2, 3}); // Inserting again should not change size
|
||||
|
||||
EXPECT_EQ(hash_set.size(), 2);
|
||||
EXPECT_TRUE(hash_set.contains({1, 2, 3}));
|
||||
EXPECT_TRUE(hash_set.contains({0, 0, 0}));
|
||||
EXPECT_FALSE(hash_set.contains({1, 2, 4}));
|
||||
}
|
||||
|
||||
TEST(PjRtDeviceDimensionsTest, FromProto) {
|
||||
PjRtDeviceDimensionsProto proto;
|
||||
proto.add_dimensions(1);
|
||||
proto.add_dimensions(2);
|
||||
proto.add_dimensions(3);
|
||||
TF_ASSERT_OK_AND_ASSIGN(PjRtDeviceDimensions dims,
|
||||
PjRtDeviceDimensions::FromProto(proto));
|
||||
EXPECT_EQ(dims, PjRtDeviceDimensions({1, 2, 3}));
|
||||
}
|
||||
|
||||
TEST(PjRtDeviceDimensionsTest, ToProto) {
|
||||
PjRtDeviceDimensions bounds = {1, 2, 3};
|
||||
PjRtDeviceDimensionsProto proto = bounds.ToProto();
|
||||
EXPECT_THAT(proto.dimensions(), testing::ElementsAre(1, 2, 3));
|
||||
}
|
||||
|
||||
TEST(AbslParseFlagTest, ValidInputs) {
|
||||
PjRtDeviceDimensions bounds;
|
||||
std::string err;
|
||||
|
||||
EXPECT_TRUE(AbslParseFlag("1,2,3", &bounds, &err));
|
||||
EXPECT_EQ(bounds, (PjRtDeviceDimensions{1, 2, 3}));
|
||||
EXPECT_EQ(err, "");
|
||||
|
||||
EXPECT_TRUE(AbslParseFlag("1,2", &bounds, &err));
|
||||
EXPECT_EQ(bounds, (PjRtDeviceDimensions{1, 2}));
|
||||
EXPECT_EQ(err, "");
|
||||
|
||||
EXPECT_TRUE(AbslParseFlag("1,2,3,4", &bounds, &err));
|
||||
EXPECT_EQ(bounds, (PjRtDeviceDimensions{1, 2, 3, 4}));
|
||||
EXPECT_EQ(err, "");
|
||||
|
||||
EXPECT_TRUE(AbslParseFlag("", &bounds, &err));
|
||||
EXPECT_EQ(bounds, (PjRtDeviceDimensions{}));
|
||||
EXPECT_EQ(err, "");
|
||||
}
|
||||
|
||||
TEST(AbslParseFlagTest, InvalidInputs) {
|
||||
PjRtDeviceDimensions bounds;
|
||||
std::string err;
|
||||
|
||||
EXPECT_FALSE(AbslParseFlag("1,a,3", &bounds, &err));
|
||||
EXPECT_THAT(err, HasSubstr("Number parsing error"));
|
||||
|
||||
EXPECT_FALSE(AbslParseFlag("1,2.5,3", &bounds, &err));
|
||||
EXPECT_THAT(err, HasSubstr("Number parsing error"));
|
||||
}
|
||||
|
||||
TEST(AbslUnparseFlagTest, ConvertsCorrectly) {
|
||||
EXPECT_EQ(AbslUnparseFlag(PjRtDeviceDimensions{1, 2, 3}), "1,2,3");
|
||||
EXPECT_EQ(AbslUnparseFlag(PjRtDeviceDimensions{0, 0, 0}), "0,0,0");
|
||||
}
|
||||
|
||||
TEST(PjRtDeviceDimensionsTest, SubscriptAccess) {
|
||||
PjRtDeviceDimensions dims = {10, 20, 30};
|
||||
EXPECT_EQ(dims[0], 10);
|
||||
EXPECT_EQ(dims[1], 20);
|
||||
EXPECT_EQ(dims[2], 30);
|
||||
|
||||
dims[1] = 25;
|
||||
EXPECT_EQ(dims[1], 25);
|
||||
EXPECT_EQ(dims, (PjRtDeviceDimensions{10, 25, 30}));
|
||||
|
||||
const PjRtDeviceDimensions const_dims = {1, 2, 3};
|
||||
EXPECT_EQ(const_dims[0], 1);
|
||||
}
|
||||
|
||||
TEST(PjRtDeviceDimensionsTest, Size) {
|
||||
EXPECT_EQ((PjRtDeviceDimensions{1, 2, 3}).size(), 3);
|
||||
EXPECT_EQ((PjRtDeviceDimensions{1, 2}).size(), 2);
|
||||
EXPECT_EQ((PjRtDeviceDimensions{}).size(), 0);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
7
third_party/xla/xla/pjrt/proto/BUILD
vendored
7
third_party/xla/xla/pjrt/proto/BUILD
vendored
|
|
@ -41,6 +41,13 @@ tf_proto_library(
|
|||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
tf_proto_library(
|
||||
name = "pjrt_device_dimensions_proto",
|
||||
srcs = ["pjrt_device_dimensions.proto"],
|
||||
compatible_with = (get_compatible_with_libtpu_portable() + get_compatible_with_portable()),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
tf_proto_library(
|
||||
name = "pjrt_value_type_proto",
|
||||
srcs = ["pjrt_value_type.proto"],
|
||||
|
|
|
|||
10
third_party/xla/xla/pjrt/proto/pjrt_device_dimensions.proto
vendored
Normal file
10
third_party/xla/xla/pjrt/proto/pjrt_device_dimensions.proto
vendored
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
edition = "2023";
|
||||
|
||||
package xla;
|
||||
|
||||
option features.field_presence = IMPLICIT;
|
||||
option java_multiple_files = true;
|
||||
|
||||
message PjRtDeviceDimensionsProto {
|
||||
repeated int32 dimensions = 1;
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user