Add proto serialization for RocmComputeCapability

PiperOrigin-RevId: 825103988
This commit is contained in:
Henning Becker 2025-10-28 10:32:22 -07:00 committed by TensorFlower Gardener
parent fa61547732
commit 4dfbd3bd0c
3 changed files with 60 additions and 0 deletions

View File

@ -38,6 +38,17 @@ cc_library(
],
)
xla_cc_test(
name = "rocm_compute_capability_test",
srcs = ["rocm_compute_capability_test.cc"],
deps = [
":rocm_compute_capability",
"//xla/tsl/util/proto:proto_matchers",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
],
)
cc_library(
name = "rocm_context",
srcs = ["rocm_context.cc"],

View File

@ -61,6 +61,11 @@ class RocmComputeCapability {
return proto;
}
static RocmComputeCapability FromProto(
const RocmComputeCapabilityProto& proto) {
return RocmComputeCapability{proto.gcn_arch_name()};
}
bool operator==(const RocmComputeCapability& other) const {
return gcn_arch_name_ == other.gcn_arch_name_;
}

View File

@ -0,0 +1,44 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/stream_executor/rocm/rocm_compute_capability.h"
#include <string>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/strings/string_view.h"
#include "xla/tsl/util/proto/proto_matchers.h"
namespace stream_executor::rocm {
namespace {
constexpr absl::string_view kExampleGcnArchName = "gfx1010:xnack-";
TEST(RocmComputeCapabilityTest, FromProto) {
RocmComputeCapabilityProto proto;
proto.set_gcn_arch_name(kExampleGcnArchName);
RocmComputeCapability cc = RocmComputeCapability::FromProto(proto);
EXPECT_EQ(cc, RocmComputeCapability{std::string(kExampleGcnArchName)});
}
TEST(RocmComputeCapabilityTest, ToProto) {
RocmComputeCapability cc{std::string(kExampleGcnArchName)};
EXPECT_THAT(cc.ToProto(), tsl::proto_testing::EqualsProto(
R"pb(gcn_arch_name: "gfx1010:xnack-")pb"));
}
} // namespace
} // namespace stream_executor::rocm