mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Add proto serialization for GpuComputeCapability
PiperOrigin-RevId: 825657032
This commit is contained in:
parent
9cbe7bd184
commit
757f0ac980
4
third_party/xla/xla/stream_executor/BUILD
vendored
4
third_party/xla/xla/stream_executor/BUILD
vendored
|
|
@ -68,6 +68,7 @@ cc_library(
|
|||
"//xla/tsl/platform:statusor",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/log:check",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
)
|
||||
|
|
@ -924,6 +925,9 @@ xla_cc_test(
|
|||
deps = [
|
||||
":device_description",
|
||||
":semantic_version",
|
||||
"//xla/stream_executor/cuda:cuda_compute_capability",
|
||||
"//xla/stream_executor/rocm:rocm_compute_capability",
|
||||
"@com_google_absl//absl/status:status_matchers",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -17,13 +17,14 @@ limitations under the License.
|
|||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <variant>
|
||||
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
#include "absl/status/status.h"
|
||||
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
|
||||
#include "xla/stream_executor/device_description.pb.h"
|
||||
#include "xla/stream_executor/launch_dim.h"
|
||||
#include "xla/stream_executor/rocm/rocm_compute_capability.h"
|
||||
#include "xla/tsl/lib/math/math_util.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
|
||||
|
|
@ -150,4 +151,33 @@ void CalculateDimensionality(const DeviceDescription &device_description,
|
|||
}
|
||||
}
|
||||
|
||||
// Serializes this compute capability into a GpuComputeCapabilityProto.
//
// Exactly one member of the proto is populated: the CUDA member when
// IsCuda() is true, the ROCm member otherwise.
GpuComputeCapabilityProto GpuComputeCapability::ToProto() const {
  GpuComputeCapabilityProto serialized;
  if (!IsCuda()) {
    // Not CUDA, so the ROCm capability is the one being serialized.
    *serialized.mutable_rocm_compute_capability() =
        rocm_compute_capability()->ToProto();
    return serialized;
  }
  *serialized.mutable_cuda_compute_capability() =
      cuda_compute_capability()->ToProto();
  return serialized;
}
|
||||
|
||||
// Builds a GpuComputeCapability from its serialized form.
//
// Returns an InvalidArgument error when `proto` carries neither a CUDA nor a
// ROCm compute capability. When the CUDA member is present, any error from
// CudaComputeCapability::FromProto is propagated to the caller.
absl::StatusOr<GpuComputeCapability> GpuComputeCapability::FromProto(
    const GpuComputeCapabilityProto& proto) {
  const bool has_cuda = proto.has_cuda_compute_capability();
  const bool has_rocm = proto.has_rocm_compute_capability();

  if (!has_cuda && !has_rocm) {
    return absl::InvalidArgumentError(
        "The serialized GpuComputeCapability has no compute capability set.");
  }

  // The CUDA member takes precedence, matching the order of the checks above.
  if (has_cuda) {
    TF_ASSIGN_OR_RETURN(
        CudaComputeCapability parsed_cuda,
        CudaComputeCapability::FromProto(proto.cuda_compute_capability()));
    return GpuComputeCapability(parsed_cuda);
  }

  return GpuComputeCapability(
      RocmComputeCapability::FromProto(proto.rocm_compute_capability()));
}
|
||||
} // namespace stream_executor
|
||||
|
|
|
|||
|
|
@ -78,6 +78,21 @@ class GpuComputeCapability {
|
|||
return rocm_compute_capability()->ToString();
|
||||
}
|
||||
|
||||
GpuComputeCapabilityProto ToProto() const;
|
||||
|
||||
static absl::StatusOr<GpuComputeCapability> FromProto(
|
||||
const GpuComputeCapabilityProto& proto);
|
||||
|
||||
friend bool operator==(const GpuComputeCapability& lhs,
|
||||
const GpuComputeCapability& rhs) {
|
||||
return lhs.compute_capability_ == rhs.compute_capability_;
|
||||
}
|
||||
|
||||
friend bool operator!=(const GpuComputeCapability& lhs,
|
||||
const GpuComputeCapability& rhs) {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
private:
|
||||
std::variant<CudaComputeCapability, RocmComputeCapability>
|
||||
compute_capability_;
|
||||
|
|
|
|||
|
|
@ -24,6 +24,13 @@ message RocmComputeCapabilityProto {
|
|||
string gcn_arch_name = 1;
|
||||
}
|
||||
|
||||
// Serialized form of a GPU compute capability. The `compute_capability` oneof
// holds at most one of the two vendor-specific messages: CUDA (field 1) or
// ROCm (field 2).
message GpuComputeCapabilityProto {
|
||||
oneof compute_capability {
|
||||
CudaComputeCapabilityProto cuda_compute_capability = 1;
|
||||
RocmComputeCapabilityProto rocm_compute_capability = 2;
|
||||
}
|
||||
}
|
||||
|
||||
message GpuDeviceInfoProto {
|
||||
int32 threads_per_block_limit = 1;
|
||||
int32 threads_per_warp = 2;
|
||||
|
|
|
|||
|
|
@ -16,11 +16,16 @@ limitations under the License.
|
|||
|
||||
#include <string>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include "absl/status/status_matchers.h"
|
||||
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
|
||||
#include "xla/stream_executor/rocm/rocm_compute_capability.h"
|
||||
#include "xla/stream_executor/semantic_version.h"
|
||||
|
||||
namespace stream_executor {
|
||||
namespace {
|
||||
using absl_testing::IsOkAndHolds;
|
||||
|
||||
TEST(DeviceDescription, DefaultConstruction) {
|
||||
DeviceDescription desc;
|
||||
|
|
@ -116,5 +121,16 @@ TEST(RocmComputeCapability, Accessors) {
|
|||
EXPECT_TRUE(RocmComputeCapability{"gfx1103"}.has_hipblaslt());
|
||||
}
|
||||
|
||||
// Checks that a GpuComputeCapability survives a ToProto()/FromProto() round
// trip unchanged, once with a CUDA capability and once with a ROCm one.
TEST(GpuComputeCapability, ProtoConversion) {
  const GpuComputeCapability volta_capability =
      GpuComputeCapability(CudaComputeCapability::Volta());
  EXPECT_THAT(GpuComputeCapability::FromProto(volta_capability.ToProto()),
              IsOkAndHolds(volta_capability));

  const GpuComputeCapability gfx900_capability =
      GpuComputeCapability(RocmComputeCapability("gfx900"));
  EXPECT_THAT(GpuComputeCapability::FromProto(gfx900_capability.ToProto()),
              IsOkAndHolds(gfx900_capability));
}
|
||||
|
||||
} // namespace
|
||||
} // namespace stream_executor
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user