mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Add proto serialization for GpuComputeCapability
PiperOrigin-RevId: 825657032
This commit is contained in:
parent
9cbe7bd184
commit
757f0ac980
4
third_party/xla/xla/stream_executor/BUILD
vendored
4
third_party/xla/xla/stream_executor/BUILD
vendored
|
|
@ -68,6 +68,7 @@ cc_library(
|
||||||
"//xla/tsl/platform:statusor",
|
"//xla/tsl/platform:statusor",
|
||||||
"@com_google_absl//absl/log",
|
"@com_google_absl//absl/log",
|
||||||
"@com_google_absl//absl/log:check",
|
"@com_google_absl//absl/log:check",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
"@com_google_absl//absl/status:statusor",
|
"@com_google_absl//absl/status:statusor",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
@ -924,6 +925,9 @@ xla_cc_test(
|
||||||
deps = [
|
deps = [
|
||||||
":device_description",
|
":device_description",
|
||||||
":semantic_version",
|
":semantic_version",
|
||||||
|
"//xla/stream_executor/cuda:cuda_compute_capability",
|
||||||
|
"//xla/stream_executor/rocm:rocm_compute_capability",
|
||||||
|
"@com_google_absl//absl/status:status_matchers",
|
||||||
"@com_google_googletest//:gtest_main",
|
"@com_google_googletest//:gtest_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -17,13 +17,14 @@ limitations under the License.
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <variant>
|
|
||||||
|
|
||||||
#include "absl/log/check.h"
|
#include "absl/log/check.h"
|
||||||
#include "absl/log/log.h"
|
#include "absl/log/log.h"
|
||||||
|
#include "absl/status/status.h"
|
||||||
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
|
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
|
||||||
#include "xla/stream_executor/device_description.pb.h"
|
#include "xla/stream_executor/device_description.pb.h"
|
||||||
#include "xla/stream_executor/launch_dim.h"
|
#include "xla/stream_executor/launch_dim.h"
|
||||||
|
#include "xla/stream_executor/rocm/rocm_compute_capability.h"
|
||||||
#include "xla/tsl/lib/math/math_util.h"
|
#include "xla/tsl/lib/math/math_util.h"
|
||||||
#include "xla/tsl/platform/statusor.h"
|
#include "xla/tsl/platform/statusor.h"
|
||||||
|
|
||||||
|
|
@ -150,4 +151,33 @@ void CalculateDimensionality(const DeviceDescription &device_description,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GpuComputeCapabilityProto GpuComputeCapability::ToProto() const {
|
||||||
|
GpuComputeCapabilityProto proto;
|
||||||
|
if (IsCuda()) {
|
||||||
|
*proto.mutable_cuda_compute_capability() =
|
||||||
|
cuda_compute_capability()->ToProto();
|
||||||
|
} else {
|
||||||
|
*proto.mutable_rocm_compute_capability() =
|
||||||
|
rocm_compute_capability()->ToProto();
|
||||||
|
}
|
||||||
|
return proto;
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::StatusOr<GpuComputeCapability> GpuComputeCapability::FromProto(
|
||||||
|
const GpuComputeCapabilityProto& proto) {
|
||||||
|
if (proto.has_cuda_compute_capability()) {
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
CudaComputeCapability cuda_compute_capability,
|
||||||
|
CudaComputeCapability::FromProto(proto.cuda_compute_capability()));
|
||||||
|
return GpuComputeCapability(cuda_compute_capability);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (proto.has_rocm_compute_capability()) {
|
||||||
|
return GpuComputeCapability(
|
||||||
|
RocmComputeCapability::FromProto(proto.rocm_compute_capability()));
|
||||||
|
}
|
||||||
|
|
||||||
|
return absl::InvalidArgumentError(
|
||||||
|
"The serialized GpuComputeCapability has no compute capability set.");
|
||||||
|
}
|
||||||
} // namespace stream_executor
|
} // namespace stream_executor
|
||||||
|
|
|
||||||
|
|
@ -78,6 +78,21 @@ class GpuComputeCapability {
|
||||||
return rocm_compute_capability()->ToString();
|
return rocm_compute_capability()->ToString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
GpuComputeCapabilityProto ToProto() const;
|
||||||
|
|
||||||
|
static absl::StatusOr<GpuComputeCapability> FromProto(
|
||||||
|
const GpuComputeCapabilityProto& proto);
|
||||||
|
|
||||||
|
friend bool operator==(const GpuComputeCapability& lhs,
|
||||||
|
const GpuComputeCapability& rhs) {
|
||||||
|
return lhs.compute_capability_ == rhs.compute_capability_;
|
||||||
|
}
|
||||||
|
|
||||||
|
friend bool operator!=(const GpuComputeCapability& lhs,
|
||||||
|
const GpuComputeCapability& rhs) {
|
||||||
|
return !(lhs == rhs);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::variant<CudaComputeCapability, RocmComputeCapability>
|
std::variant<CudaComputeCapability, RocmComputeCapability>
|
||||||
compute_capability_;
|
compute_capability_;
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,13 @@ message RocmComputeCapabilityProto {
|
||||||
string gcn_arch_name = 1;
|
string gcn_arch_name = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message GpuComputeCapabilityProto {
|
||||||
|
oneof compute_capability {
|
||||||
|
CudaComputeCapabilityProto cuda_compute_capability = 1;
|
||||||
|
RocmComputeCapabilityProto rocm_compute_capability = 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
message GpuDeviceInfoProto {
|
message GpuDeviceInfoProto {
|
||||||
int32 threads_per_block_limit = 1;
|
int32 threads_per_block_limit = 1;
|
||||||
int32 threads_per_warp = 2;
|
int32 threads_per_warp = 2;
|
||||||
|
|
|
||||||
|
|
@ -16,11 +16,16 @@ limitations under the License.
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include <gmock/gmock.h>
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
#include "absl/status/status_matchers.h"
|
||||||
|
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
|
||||||
|
#include "xla/stream_executor/rocm/rocm_compute_capability.h"
|
||||||
#include "xla/stream_executor/semantic_version.h"
|
#include "xla/stream_executor/semantic_version.h"
|
||||||
|
|
||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
namespace {
|
namespace {
|
||||||
|
using absl_testing::IsOkAndHolds;
|
||||||
|
|
||||||
TEST(DeviceDescription, DefaultConstruction) {
|
TEST(DeviceDescription, DefaultConstruction) {
|
||||||
DeviceDescription desc;
|
DeviceDescription desc;
|
||||||
|
|
@ -116,5 +121,16 @@ TEST(RocmComputeCapability, Accessors) {
|
||||||
EXPECT_TRUE(RocmComputeCapability{"gfx1103"}.has_hipblaslt());
|
EXPECT_TRUE(RocmComputeCapability{"gfx1103"}.has_hipblaslt());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(GpuComputeCapability, ProtoConversion) {
|
||||||
|
EXPECT_THAT(
|
||||||
|
GpuComputeCapability::FromProto(
|
||||||
|
GpuComputeCapability(CudaComputeCapability::Volta()).ToProto()),
|
||||||
|
IsOkAndHolds(GpuComputeCapability(CudaComputeCapability::Volta())));
|
||||||
|
EXPECT_THAT(
|
||||||
|
GpuComputeCapability::FromProto(
|
||||||
|
GpuComputeCapability(RocmComputeCapability("gfx900")).ToProto()),
|
||||||
|
IsOkAndHolds(GpuComputeCapability(RocmComputeCapability("gfx900"))));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace stream_executor
|
} // namespace stream_executor
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user