Add proto serialization for GpuComputeCapability

PiperOrigin-RevId: 825657032
Authored by Henning Becker on 2025-10-29 12:20:29 -07:00, committed by TensorFlower Gardener
parent 9cbe7bd184
commit 757f0ac980
5 changed files with 73 additions and 1 deletion

xla/stream_executor/BUILD

@@ -68,6 +68,7 @@ cc_library(
         "//xla/tsl/platform:statusor",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
     ],
 )
@@ -924,6 +925,9 @@ xla_cc_test(
     deps = [
         ":device_description",
         ":semantic_version",
+        "//xla/stream_executor/cuda:cuda_compute_capability",
+        "//xla/stream_executor/rocm:rocm_compute_capability",
+        "@com_google_absl//absl/status:status_matchers",
         "@com_google_googletest//:gtest_main",
     ],
 )

xla/stream_executor/device_description.cc

@@ -17,13 +17,14 @@ limitations under the License.
 #include <cstdint>
 #include <string>
-#include <variant>
 
 #include "absl/log/check.h"
 #include "absl/log/log.h"
+#include "absl/status/status.h"
 
 #include "xla/stream_executor/cuda/cuda_compute_capability.h"
 #include "xla/stream_executor/device_description.pb.h"
 #include "xla/stream_executor/launch_dim.h"
+#include "xla/stream_executor/rocm/rocm_compute_capability.h"
 #include "xla/tsl/lib/math/math_util.h"
 #include "xla/tsl/platform/statusor.h"
@@ -150,4 +151,33 @@ void CalculateDimensionality(const DeviceDescription &device_description,
   }
 }
 
+GpuComputeCapabilityProto GpuComputeCapability::ToProto() const {
+  GpuComputeCapabilityProto proto;
+  if (IsCuda()) {
+    *proto.mutable_cuda_compute_capability() =
+        cuda_compute_capability()->ToProto();
+  } else {
+    *proto.mutable_rocm_compute_capability() =
+        rocm_compute_capability()->ToProto();
+  }
+  return proto;
+}
+
+absl::StatusOr<GpuComputeCapability> GpuComputeCapability::FromProto(
+    const GpuComputeCapabilityProto& proto) {
+  if (proto.has_cuda_compute_capability()) {
+    TF_ASSIGN_OR_RETURN(
+        CudaComputeCapability cuda_compute_capability,
+        CudaComputeCapability::FromProto(proto.cuda_compute_capability()));
+    return GpuComputeCapability(cuda_compute_capability);
+  }
+  if (proto.has_rocm_compute_capability()) {
+    return GpuComputeCapability(
+        RocmComputeCapability::FromProto(proto.rocm_compute_capability()));
+  }
+  return absl::InvalidArgumentError(
+      "The serialized GpuComputeCapability has no compute capability set.");
+}
+
 }  // namespace stream_executor

xla/stream_executor/device_description.h

@@ -78,6 +78,21 @@ class GpuComputeCapability {
     return rocm_compute_capability()->ToString();
   }
 
+  GpuComputeCapabilityProto ToProto() const;
+  static absl::StatusOr<GpuComputeCapability> FromProto(
+      const GpuComputeCapabilityProto& proto);
+
+  friend bool operator==(const GpuComputeCapability& lhs,
+                         const GpuComputeCapability& rhs) {
+    return lhs.compute_capability_ == rhs.compute_capability_;
+  }
+
+  friend bool operator!=(const GpuComputeCapability& lhs,
+                         const GpuComputeCapability& rhs) {
+    return !(lhs == rhs);
+  }
+
  private:
  std::variant<CudaComputeCapability, RocmComputeCapability>
      compute_capability_;
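
For context, a minimal usage sketch of the new API (not part of this commit): it round-trips a capability through the proto form. The helper name RoundTripGpuComputeCapability and the exact include paths are illustrative assumptions.

#include "absl/status/statusor.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_description.pb.h"

namespace stream_executor {

// Serializes a GpuComputeCapability and immediately deserializes it again,
// returning the restored value or the error reported by FromProto.
absl::StatusOr<GpuComputeCapability> RoundTripGpuComputeCapability(
    const GpuComputeCapability& capability) {
  GpuComputeCapabilityProto proto = capability.ToProto();
  return GpuComputeCapability::FromProto(proto);
}

}  // namespace stream_executor

With the operator== added above, a caller can then verify the round trip with *RoundTripGpuComputeCapability(cc) == cc, which is essentially what the new test at the bottom of this commit does.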

xla/stream_executor/device_description.proto

@@ -24,6 +24,13 @@ message RocmComputeCapabilityProto {
   string gcn_arch_name = 1;
 }
 
+message GpuComputeCapabilityProto {
+  oneof compute_capability {
+    CudaComputeCapabilityProto cuda_compute_capability = 1;
+    RocmComputeCapabilityProto rocm_compute_capability = 2;
+  }
+}
+
 message GpuDeviceInfoProto {
   int32 threads_per_block_limit = 1;
   int32 threads_per_warp = 2;
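
A rough sketch of the oneof semantics from a caller's perspective (hypothetical snippet, not part of the commit; the header paths and the assumption that CudaComputeCapabilityProto carries major/minor version fields are mine): an empty proto makes FromProto return the InvalidArgumentError shown in device_description.cc above, while populating exactly one branch selects the corresponding capability type.

#include "absl/log/check.h"
#include "absl/status/status.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_description.pb.h"

namespace stream_executor {

void OneofIllustration() {
  GpuComputeCapabilityProto proto;  // Neither branch of the oneof is set.
  // An empty proto is rejected with InvalidArgumentError.
  CHECK_EQ(GpuComputeCapability::FromProto(proto).status().code(),
           absl::StatusCode::kInvalidArgument);

  // Setting the CUDA branch (assumed major/minor fields) yields a CUDA
  // capability after deserialization.
  proto.mutable_cuda_compute_capability()->set_major(9);
  proto.mutable_cuda_compute_capability()->set_minor(0);
  CHECK(GpuComputeCapability::FromProto(proto)->IsCuda());
}

}  // namespace stream_executor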

xla/stream_executor/device_description_test.cc

@@ -16,11 +16,16 @@ limitations under the License.
 #include <string>
 
+#include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "absl/status/status_matchers.h"
+#include "xla/stream_executor/cuda/cuda_compute_capability.h"
+#include "xla/stream_executor/rocm/rocm_compute_capability.h"
 #include "xla/stream_executor/semantic_version.h"
 
 namespace stream_executor {
 namespace {
 
+using absl_testing::IsOkAndHolds;
+
 TEST(DeviceDescription, DefaultConstruction) {
   DeviceDescription desc;
@@ -116,5 +121,16 @@ TEST(RocmComputeCapability, Accessors) {
   EXPECT_TRUE(RocmComputeCapability{"gfx1103"}.has_hipblaslt());
 }
 
+TEST(GpuComputeCapability, ProtoConversion) {
+  EXPECT_THAT(
+      GpuComputeCapability::FromProto(
+          GpuComputeCapability(CudaComputeCapability::Volta()).ToProto()),
+      IsOkAndHolds(GpuComputeCapability(CudaComputeCapability::Volta())));
+  EXPECT_THAT(
+      GpuComputeCapability::FromProto(
+          GpuComputeCapability(RocmComputeCapability("gfx900")).ToProto()),
+      IsOkAndHolds(GpuComputeCapability(RocmComputeCapability("gfx900"))));
+}
+
 }  // namespace
 }  // namespace stream_executor