mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Add proto serialization for GpuExecutable
This is adding `GpuExecutuable::ToProto` and `GpuExecutable::FromProto` which allow us to [de]serialize an instance of `GpuExecutable` and later reconstruct it. PiperOrigin-RevId: 826470601
This commit is contained in:
parent
f73a954906
commit
26d0882419
2
third_party/xla/xla/service/gpu/BUILD
vendored
2
third_party/xla/xla/service/gpu/BUILD
vendored
|
|
@ -743,6 +743,7 @@ cc_library(
|
|||
"//xla/backends/gpu/runtime:thunk",
|
||||
"//xla/backends/gpu/runtime:thunk_buffer_debug_pass",
|
||||
"//xla/backends/gpu/runtime:thunk_pass_pipeline",
|
||||
"//xla/backends/gpu/runtime:thunk_proto_deserialization",
|
||||
"//xla/core/collectives:clique_key",
|
||||
"//xla/hlo/ir:hlo",
|
||||
"//xla/service:buffer_assignment",
|
||||
|
|
@ -773,6 +774,7 @@ cc_library(
|
|||
"//xla/tsl/platform:logging",
|
||||
"//xla/tsl/platform:status",
|
||||
"//xla/tsl/platform:statusor",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/base:nullability",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
|
|
|
|||
139
third_party/xla/xla/service/gpu/gpu_executable.cc
vendored
139
third_party/xla/xla/service/gpu/gpu_executable.cc
vendored
|
|
@ -47,6 +47,7 @@ limitations under the License.
|
|||
#include "xla/backends/gpu/runtime/thunk.h"
|
||||
#include "xla/backends/gpu/runtime/thunk_buffer_debug_pass.h"
|
||||
#include "xla/backends/gpu/runtime/thunk_pass_pipeline.h"
|
||||
#include "xla/backends/gpu/runtime/thunk_proto_deserialization.h"
|
||||
#include "xla/core/collectives/clique_key.h"
|
||||
#include "xla/executable_run_options.h"
|
||||
#include "xla/hlo/ir/hlo_input_output_alias_config.h"
|
||||
|
|
@ -60,6 +61,7 @@ limitations under the License.
|
|||
#include "xla/service/gpu/buffer_allocations.h"
|
||||
#include "xla/service/gpu/gpu_constants.h"
|
||||
#include "xla/service/gpu/gpu_executable_run_options.h"
|
||||
#include "xla/service/gpu/ir_emission_utils.h"
|
||||
#include "xla/service/gpu/resource_requests.h"
|
||||
#include "xla/service/gpu/stream_executor_util.h"
|
||||
#include "xla/service/hlo_value.h"
|
||||
|
|
@ -1104,8 +1106,8 @@ GetOutputInfo(const HloModule& hlo_module, const BufferAssignment& assignment) {
|
|||
return output;
|
||||
}
|
||||
|
||||
OutputInfoProto GpuExecutable::OutputInfo::ToProto() const {
|
||||
OutputInfoProto proto;
|
||||
GpuExecutableProto::OutputInfoProto GpuExecutable::OutputInfo::ToProto() const {
|
||||
GpuExecutableProto::OutputInfoProto proto;
|
||||
proto.set_allocation_index(allocation_index);
|
||||
proto.set_passthrough(passthrough);
|
||||
|
||||
|
|
@ -1130,7 +1132,7 @@ OutputInfoProto GpuExecutable::OutputInfo::ToProto() const {
|
|||
}
|
||||
|
||||
absl::StatusOr<GpuExecutable::OutputInfo> GpuExecutable::OutputInfo::FromProto(
|
||||
const OutputInfoProto& proto) {
|
||||
const GpuExecutableProto::OutputInfoProto& proto) {
|
||||
OutputInfo output_info;
|
||||
output_info.allocation_index = proto.allocation_index();
|
||||
output_info.passthrough = proto.passthrough();
|
||||
|
|
@ -1155,5 +1157,136 @@ absl::StatusOr<GpuExecutable::OutputInfo> GpuExecutable::OutputInfo::FromProto(
|
|||
}
|
||||
return output_info;
|
||||
}
|
||||
|
||||
GpuExecutableProto::ConstantInfoProto GpuExecutable::ConstantInfo::ToProto()
|
||||
const {
|
||||
GpuExecutableProto::ConstantInfoProto proto;
|
||||
proto.set_symbol_name(symbol_name);
|
||||
*proto.mutable_content() = content.ToProto();
|
||||
proto.set_allocation_index(allocation_index);
|
||||
return proto;
|
||||
}
|
||||
|
||||
GpuExecutable::ConstantInfo GpuExecutable::ConstantInfo::FromProto(
|
||||
const GpuExecutableProto::ConstantInfoProto& proto) {
|
||||
return ConstantInfo{
|
||||
/*symbol_name=*/proto.symbol_name(),
|
||||
/*content=*/DenseDataIntermediate::FromProto(proto.content()),
|
||||
/*allocation_index=*/static_cast<int>(proto.allocation_index())};
|
||||
}
|
||||
|
||||
absl::StatusOr<GpuExecutableProto> GpuExecutable::ToProto() const {
|
||||
GpuExecutableProto proto;
|
||||
proto.set_binary(binary_.data(), binary_.size());
|
||||
proto.set_asm_text(text_);
|
||||
proto.mutable_dnn_compiled_graphs()->insert(dnn_compiled_graphs_.cbegin(),
|
||||
dnn_compiled_graphs_.cend());
|
||||
|
||||
*proto.mutable_gpu_compute_capability() = gpu_version_.ToProto();
|
||||
|
||||
TF_ASSIGN_OR_RETURN(*proto.mutable_thunk(), thunks_->ToProto());
|
||||
|
||||
proto.set_module_name(module_name_);
|
||||
*proto.mutable_program_shape() = program_shape_.ToProto();
|
||||
|
||||
absl::Span<const BufferAllocation* const> allocations = GetAllocations();
|
||||
proto.mutable_buffer_allocations()->Reserve(allocations.size());
|
||||
for (const auto& allocation : allocations) {
|
||||
proto.mutable_buffer_allocations()->Add(allocation->ToProto());
|
||||
}
|
||||
|
||||
if (hlo_module_ != nullptr) {
|
||||
*proto.mutable_hlo_module() = hlo_module_->ToProtoWithConfig();
|
||||
}
|
||||
|
||||
proto.mutable_output_info_map()->Reserve(output_info_.size());
|
||||
for (const auto& [shape_index, output_info] : output_info_) {
|
||||
auto map_entry = proto.add_output_info_map();
|
||||
*map_entry->mutable_shape_index() = shape_index.ToProto();
|
||||
*map_entry->mutable_output_info() = output_info.ToProto();
|
||||
}
|
||||
|
||||
proto.mutable_constants()->Reserve(constants_.size());
|
||||
for (const auto& constant : constants_) {
|
||||
*proto.add_constants() = constant.ToProto();
|
||||
}
|
||||
|
||||
return proto;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::unique_ptr<GpuExecutable>> GpuExecutable::FromProto(
|
||||
const GpuExecutableProto& proto,
|
||||
const se::DeviceDescription& device_description) {
|
||||
Params params;
|
||||
params.enable_debug_info_manager = false;
|
||||
params.asm_text = proto.asm_text();
|
||||
const std::string& binary = proto.binary();
|
||||
params.binary.assign(binary.begin(), binary.end());
|
||||
params.buffer_assignment = nullptr;
|
||||
if (proto.has_hlo_module()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
params.debug_module,
|
||||
HloModule::CreateFromProtoWithConfig(proto.hlo_module()));
|
||||
}
|
||||
|
||||
params.mlir_allocations.emplace();
|
||||
params.mlir_allocations->reserve(proto.buffer_allocations().size());
|
||||
for (const BufferAllocationProto& allocation_proto :
|
||||
proto.buffer_allocations()) {
|
||||
params.mlir_allocations->push_back(
|
||||
BufferAllocation::FromProto(allocation_proto));
|
||||
}
|
||||
|
||||
for (const auto& [key, value] : proto.dnn_compiled_graphs()) {
|
||||
params.dnn_compiled_graphs.emplace(key, value);
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
stream_executor::GpuComputeCapability gpu_compute_capability,
|
||||
stream_executor::GpuComputeCapability::FromProto(
|
||||
proto.gpu_compute_capability()));
|
||||
|
||||
if (gpu_compute_capability != device_description.gpu_compute_capability()) {
|
||||
return absl::InvalidArgumentError(absl::StrFormat(
|
||||
"GPU compute capability of serialized executable doesn't match target "
|
||||
"device capability. (serialized: %s, target: %s)",
|
||||
gpu_compute_capability.ToString(),
|
||||
device_description.gpu_compute_capability().ToString()));
|
||||
}
|
||||
|
||||
params.device_description = device_description;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<Thunk> thunk,
|
||||
DeserializeThunkProto(proto.thunk(), params.mlir_allocations.value()));
|
||||
|
||||
if (dynamic_cast<const SequentialThunk*>(thunk.get()) == nullptr) {
|
||||
return absl::InvalidArgumentError(
|
||||
"The top-most serialized thunk in the GPU Executable is not a "
|
||||
"SequentialThunk!");
|
||||
}
|
||||
|
||||
params.executable = unique_ptr_down_cast<SequentialThunk>(std::move(thunk));
|
||||
|
||||
params.constants.reserve(proto.constants().size());
|
||||
for (const auto& constant_proto : proto.constants()) {
|
||||
params.constants.push_back(ConstantInfo::FromProto(constant_proto));
|
||||
}
|
||||
|
||||
params.output_info.reserve(proto.output_info_map().size());
|
||||
for (const auto& output_info_proto : proto.output_info_map()) {
|
||||
ShapeIndex shape_index =
|
||||
ShapeIndex::FromProto(output_info_proto.shape_index());
|
||||
TF_ASSIGN_OR_RETURN(OutputInfo output_info,
|
||||
OutputInfo::FromProto(output_info_proto.output_info()));
|
||||
params.output_info.emplace(std::move(shape_index), std::move(output_info));
|
||||
}
|
||||
|
||||
params.module_name = proto.module_name();
|
||||
TF_ASSIGN_OR_RETURN(params.program_shape,
|
||||
ProgramShape::FromProto(proto.program_shape()));
|
||||
|
||||
return Create(std::move(params));
|
||||
}
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
|
|
|||
16
third_party/xla/xla/service/gpu/gpu_executable.h
vendored
16
third_party/xla/xla/service/gpu/gpu_executable.h
vendored
|
|
@ -70,6 +70,11 @@ class GpuExecutable : public Executable {
|
|||
std::string symbol_name;
|
||||
DenseDataIntermediate content;
|
||||
int allocation_index = -1;
|
||||
|
||||
GpuExecutableProto::ConstantInfoProto ToProto() const;
|
||||
|
||||
static ConstantInfo FromProto(
|
||||
const GpuExecutableProto::ConstantInfoProto& proto);
|
||||
};
|
||||
|
||||
struct OutputInfo {
|
||||
|
|
@ -83,8 +88,9 @@ class GpuExecutable : public Executable {
|
|||
// would indicate the aliased parameter), and what kind of alias it is.
|
||||
std::optional<HloInputOutputAliasConfig::Alias> alias_config;
|
||||
|
||||
OutputInfoProto ToProto() const;
|
||||
static absl::StatusOr<OutputInfo> FromProto(const OutputInfoProto& proto);
|
||||
GpuExecutableProto::OutputInfoProto ToProto() const;
|
||||
static absl::StatusOr<OutputInfo> FromProto(
|
||||
const GpuExecutableProto::OutputInfoProto& proto);
|
||||
|
||||
friend bool operator==(const OutputInfo& lhs, const OutputInfo& rhs) {
|
||||
return std::tie(lhs.allocation_index, lhs.passthrough,
|
||||
|
|
@ -210,6 +216,12 @@ class GpuExecutable : public Executable {
|
|||
|
||||
absl::Status VerboseAllocationError(absl::Status s);
|
||||
|
||||
static absl::StatusOr<std::unique_ptr<GpuExecutable>> FromProto(
|
||||
const GpuExecutableProto&,
|
||||
const se::DeviceDescription& device_description);
|
||||
|
||||
absl::StatusOr<GpuExecutableProto> ToProto() const;
|
||||
|
||||
private:
|
||||
// Use GpuExecutable::Create() to create an instance.
|
||||
explicit GpuExecutable(Params params,
|
||||
|
|
|
|||
|
|
@ -2,9 +2,50 @@ syntax = "proto3";
|
|||
|
||||
package xla.gpu;
|
||||
|
||||
import "xla/backends/gpu/runtime/thunk.proto";
|
||||
import "xla/service/gpu/ir_emission_utils.proto";
|
||||
import "xla/service/hlo.proto";
|
||||
import "xla/shape_util.proto";
|
||||
import "xla/stream_executor/cuda/cuda_compute_capability.proto";
|
||||
import "xla/stream_executor/device_description.proto";
|
||||
import "xla/xla.proto";
|
||||
import "xla/xla_data.proto";
|
||||
|
||||
message OutputInfoProto {
|
||||
message GpuExecutableProto {
|
||||
// The binary of the executable.
|
||||
//
|
||||
// For CUDA, this is a cubin binary.
|
||||
// For ROCm, this is a hsaco binary.
|
||||
bytes binary = 1;
|
||||
|
||||
// The PTX of the executable. (Only applicable to CUDA)
|
||||
string asm_text = 2;
|
||||
|
||||
// The DNN compiled graphs of the executable.
|
||||
//
|
||||
// The key is the DNN kernel name, and the value is the compiled graph
|
||||
// serialized to JSON. (Only applicable to cuDNN)
|
||||
map<string, string> dnn_compiled_graphs = 3;
|
||||
|
||||
// The target compute capability of the executable.
|
||||
stream_executor.GpuComputeCapabilityProto gpu_compute_capability = 4;
|
||||
|
||||
// The HLO module of the executable - for debugging purposes only.
|
||||
xla.HloModuleProtoWithConfig hlo_module = 5;
|
||||
|
||||
// The thunk tree of the executable.
|
||||
ThunkProto thunk = 6;
|
||||
|
||||
// The name of the HLO module - for debugging purposes only.
|
||||
string module_name = 7;
|
||||
|
||||
// The shape of the program (parameters and result).
|
||||
xla.ProgramShapeProto program_shape = 8;
|
||||
|
||||
// The buffer allocations of the executable.
|
||||
repeated BufferAllocationProto buffer_allocations = 9;
|
||||
|
||||
message OutputInfoProto {
|
||||
// This output is part of the following buffer allocation
|
||||
int64 allocation_index = 1;
|
||||
|
||||
|
|
@ -13,4 +54,27 @@ message OutputInfoProto {
|
|||
|
||||
// Describes whether and how this output aliases with an input parameter
|
||||
optional xla.HloInputOutputAliasProto.AliasEntryProto alias_config = 3;
|
||||
}
|
||||
|
||||
message OutputInfoMapEntry {
|
||||
xla.ShapeIndexProto shape_index = 1;
|
||||
OutputInfoProto output_info = 2;
|
||||
}
|
||||
|
||||
// Map from output shape index to output info.
|
||||
repeated OutputInfoMapEntry output_info_map = 10;
|
||||
|
||||
message ConstantInfoProto {
|
||||
// The name of the constant in the HLO module.
|
||||
string symbol_name = 1;
|
||||
|
||||
// The content of the constant - this can be large.
|
||||
DenseDataIntermediateProto content = 2;
|
||||
|
||||
// The index of the buffer allocation for this constant.
|
||||
int64 allocation_index = 3;
|
||||
}
|
||||
|
||||
// The constants used by the executable.
|
||||
repeated ConstantInfoProto constants = 11;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -62,8 +62,11 @@ namespace xla::gpu {
|
|||
namespace {
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::ElementsAreArray;
|
||||
using ::testing::Pair;
|
||||
using ::testing::Pointee;
|
||||
using ::testing::Property;
|
||||
using ::testing::SizeIs;
|
||||
using ::testing::UnorderedElementsAre;
|
||||
using ::tsl::proto_testing::EqualsProto;
|
||||
|
||||
TEST(GpuExecutableTest, OuputInfoToAndFromProto) {
|
||||
|
|
@ -433,5 +436,56 @@ TEST(GpuExecutableTest, DumpsMetadataListProto) {
|
|||
)pb"));
|
||||
}
|
||||
|
||||
TEST(GpuExecutableTest, ProtoConversion) {
|
||||
se::DeviceDescription device_description;
|
||||
device_description.set_gpu_compute_capability(
|
||||
se::GpuComputeCapability{se::CudaComputeCapability::Volta()});
|
||||
device_description.set_driver_version({12, 3, 0});
|
||||
device_description.set_runtime_version({12, 3, 0});
|
||||
|
||||
Thunk::ThunkInfo thunk_info;
|
||||
thunk_info.thunk_id = 123;
|
||||
|
||||
ThunkSequence thunk_sequence;
|
||||
thunk_sequence.push_back(std::make_unique<KernelThunk>(
|
||||
thunk_info,
|
||||
/*kernel_name=*/"test_kernel", emitters::KernelArguments({}),
|
||||
LaunchDimensions(),
|
||||
/*cluster_dim=*/std::nullopt,
|
||||
/*shmem_bytes=*/0, se::gpu::TmaMetadata()));
|
||||
|
||||
GpuExecutable::Params params;
|
||||
params.asm_text = "test_asm_text";
|
||||
params.binary = {1, 2, 3};
|
||||
params.dnn_compiled_graphs = {{"test_dnn_compiled_graph", "test_json"}};
|
||||
|
||||
thunk_info.thunk_id = 456;
|
||||
params.executable =
|
||||
std::make_unique<SequentialThunk>(thunk_info, std::move(thunk_sequence));
|
||||
params.device_description = device_description;
|
||||
|
||||
params.module_name = "test_module";
|
||||
params.enable_debug_info_manager = false;
|
||||
params.mlir_allocations = {BufferAllocation(0, 1024, 0)};
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GpuExecutable> reference_executable,
|
||||
GpuExecutable::Create(std::move(params)));
|
||||
TF_ASSERT_OK_AND_ASSIGN(GpuExecutableProto proto,
|
||||
reference_executable->ToProto());
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<GpuExecutable> reconstructed_executable,
|
||||
GpuExecutable::FromProto(proto, device_description));
|
||||
EXPECT_THAT(reconstructed_executable->text(), "test_asm_text");
|
||||
EXPECT_THAT(reconstructed_executable->binary(), ElementsAre(1, 2, 3));
|
||||
EXPECT_THAT(
|
||||
reconstructed_executable->dnn_compiled_graphs(),
|
||||
UnorderedElementsAre(Pair("test_dnn_compiled_graph", "test_json")));
|
||||
EXPECT_THAT(reconstructed_executable->GetThunk().thunks(),
|
||||
ElementsAre(Pointee(Property(&Thunk::kind, Thunk::kKernel))));
|
||||
EXPECT_THAT(reconstructed_executable->GetAllocations(),
|
||||
ElementsAre(Pointee(Property(&BufferAllocation::size, 1024))));
|
||||
EXPECT_THAT(reconstructed_executable->name(), "test_module");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla::gpu
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user