Add proto serialization for GpuExecutable

This adds `GpuExecutable::ToProto` and `GpuExecutable::FromProto`, which allow us to serialize an instance of `GpuExecutable` and later reconstruct it through deserialization.

PiperOrigin-RevId: 826470601
This commit is contained in:
Henning Becker 2025-10-31 06:51:54 -07:00 committed by TensorFlower Gardener
parent f73a954906
commit 26d0882419
5 changed files with 277 additions and 12 deletions

View File

@ -743,6 +743,7 @@ cc_library(
"//xla/backends/gpu/runtime:thunk",
"//xla/backends/gpu/runtime:thunk_buffer_debug_pass",
"//xla/backends/gpu/runtime:thunk_pass_pipeline",
"//xla/backends/gpu/runtime:thunk_proto_deserialization",
"//xla/core/collectives:clique_key",
"//xla/hlo/ir:hlo",
"//xla/service:buffer_assignment",
@ -773,6 +774,7 @@ cc_library(
"//xla/tsl/platform:logging",
"//xla/tsl/platform:status",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/container:flat_hash_map",

View File

@ -47,6 +47,7 @@ limitations under the License.
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/backends/gpu/runtime/thunk_buffer_debug_pass.h"
#include "xla/backends/gpu/runtime/thunk_pass_pipeline.h"
#include "xla/backends/gpu/runtime/thunk_proto_deserialization.h"
#include "xla/core/collectives/clique_key.h"
#include "xla/executable_run_options.h"
#include "xla/hlo/ir/hlo_input_output_alias_config.h"
@ -60,6 +61,7 @@ limitations under the License.
#include "xla/service/gpu/buffer_allocations.h"
#include "xla/service/gpu/gpu_constants.h"
#include "xla/service/gpu/gpu_executable_run_options.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/resource_requests.h"
#include "xla/service/gpu/stream_executor_util.h"
#include "xla/service/hlo_value.h"
@ -1104,8 +1106,8 @@ GetOutputInfo(const HloModule& hlo_module, const BufferAssignment& assignment) {
return output;
}
OutputInfoProto GpuExecutable::OutputInfo::ToProto() const {
OutputInfoProto proto;
GpuExecutableProto::OutputInfoProto GpuExecutable::OutputInfo::ToProto() const {
GpuExecutableProto::OutputInfoProto proto;
proto.set_allocation_index(allocation_index);
proto.set_passthrough(passthrough);
@ -1130,7 +1132,7 @@ OutputInfoProto GpuExecutable::OutputInfo::ToProto() const {
}
absl::StatusOr<GpuExecutable::OutputInfo> GpuExecutable::OutputInfo::FromProto(
const OutputInfoProto& proto) {
const GpuExecutableProto::OutputInfoProto& proto) {
OutputInfo output_info;
output_info.allocation_index = proto.allocation_index();
output_info.passthrough = proto.passthrough();
@ -1155,5 +1157,136 @@ absl::StatusOr<GpuExecutable::OutputInfo> GpuExecutable::OutputInfo::FromProto(
}
return output_info;
}
// Serializes this ConstantInfo into its proto form. Field-for-field copy;
// cannot fail.
GpuExecutableProto::ConstantInfoProto GpuExecutable::ConstantInfo::ToProto()
    const {
  GpuExecutableProto::ConstantInfoProto result;
  result.set_allocation_index(allocation_index);
  result.set_symbol_name(symbol_name);
  // The constant payload can be large; DenseDataIntermediate owns the bytes.
  *result.mutable_content() = content.ToProto();
  return result;
}
// Rebuilds a ConstantInfo from its proto form. Inverse of ToProto(); every
// field maps directly, so this cannot fail.
GpuExecutable::ConstantInfo GpuExecutable::ConstantInfo::FromProto(
    const GpuExecutableProto::ConstantInfoProto& proto) {
  ConstantInfo result{
      /*symbol_name=*/proto.symbol_name(),
      // The proto stores the allocation index as int64; ConstantInfo uses int.
      /*content=*/DenseDataIntermediate::FromProto(proto.content()),
      /*allocation_index=*/static_cast<int>(proto.allocation_index())};
  return result;
}
// Serializes this executable (backend binary, asm text, thunk tree, buffer
// layout, constants and debug metadata) into a GpuExecutableProto. Returns an
// error only if the thunk tree fails to serialize.
absl::StatusOr<GpuExecutableProto> GpuExecutable::ToProto() const {
  GpuExecutableProto result;
  result.set_binary(binary_.data(), binary_.size());
  result.set_asm_text(text_);
  result.mutable_dnn_compiled_graphs()->insert(dnn_compiled_graphs_.cbegin(),
                                               dnn_compiled_graphs_.cend());
  *result.mutable_gpu_compute_capability() = gpu_version_.ToProto();
  // Thunk serialization is the only fallible step in this function.
  TF_ASSIGN_OR_RETURN(*result.mutable_thunk(), thunks_->ToProto());
  result.set_module_name(module_name_);
  *result.mutable_program_shape() = program_shape_.ToProto();

  absl::Span<const BufferAllocation* const> allocations = GetAllocations();
  result.mutable_buffer_allocations()->Reserve(allocations.size());
  for (const BufferAllocation* allocation : allocations) {
    *result.add_buffer_allocations() = allocation->ToProto();
  }

  // The HLO module is optional and only kept for debugging purposes.
  if (hlo_module_ != nullptr) {
    *result.mutable_hlo_module() = hlo_module_->ToProtoWithConfig();
  }

  // output_info_ is a map keyed by ShapeIndex; it is flattened into repeated
  // (shape_index, output_info) entries.
  result.mutable_output_info_map()->Reserve(output_info_.size());
  for (const auto& [shape_index, output_info] : output_info_) {
    auto* entry = result.add_output_info_map();
    *entry->mutable_shape_index() = shape_index.ToProto();
    *entry->mutable_output_info() = output_info.ToProto();
  }

  result.mutable_constants()->Reserve(constants_.size());
  for (const ConstantInfo& constant : constants_) {
    *result.add_constants() = constant.ToProto();
  }
  return result;
}
// Reconstructs a GpuExecutable from `proto` for execution on a device
// described by `device_description`.
//
// Fails if the serialized compute capability does not match the target
// device, if the thunk tree cannot be deserialized, or if the top-level
// thunk is not a SequentialThunk.
//
// NOTE(review): the steps below are order-sensitive — buffer allocations
// must be populated before DeserializeThunkProto consumes them, and the
// capability check deliberately runs before the (expensive) thunk
// deserialization.
absl::StatusOr<std::unique_ptr<GpuExecutable>> GpuExecutable::FromProto(
    const GpuExecutableProto& proto,
    const se::DeviceDescription& device_description) {
  Params params;
  params.enable_debug_info_manager = false;
  params.asm_text = proto.asm_text();
  // The binary is stored as proto `bytes` (std::string); copy it into the
  // byte vector expected by Params.
  const std::string& binary = proto.binary();
  params.binary.assign(binary.begin(), binary.end());
  // No BufferAssignment is serialized; allocations are restored below.
  params.buffer_assignment = nullptr;
  // The HLO module is optional in the proto (debug-only information).
  if (proto.has_hlo_module()) {
    TF_ASSIGN_OR_RETURN(
        params.debug_module,
        HloModule::CreateFromProtoWithConfig(proto.hlo_module()));
  }
  // Restore the buffer allocations; DeserializeThunkProto below needs them
  // to resolve buffer slices.
  params.mlir_allocations.emplace();
  params.mlir_allocations->reserve(proto.buffer_allocations().size());
  for (const BufferAllocationProto& allocation_proto :
       proto.buffer_allocations()) {
    params.mlir_allocations->push_back(
        BufferAllocation::FromProto(allocation_proto));
  }
  for (const auto& [key, value] : proto.dnn_compiled_graphs()) {
    params.dnn_compiled_graphs.emplace(key, value);
  }
  // Reject executables compiled for a different GPU architecture before
  // doing any further (expensive) deserialization work.
  TF_ASSIGN_OR_RETURN(
      stream_executor::GpuComputeCapability gpu_compute_capability,
      stream_executor::GpuComputeCapability::FromProto(
          proto.gpu_compute_capability()));
  if (gpu_compute_capability != device_description.gpu_compute_capability()) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "GPU compute capability of serialized executable doesn't match target "
        "device capability. (serialized: %s, target: %s)",
        gpu_compute_capability.ToString(),
        device_description.gpu_compute_capability().ToString()));
  }
  params.device_description = device_description;
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Thunk> thunk,
      DeserializeThunkProto(proto.thunk(), params.mlir_allocations.value()));
  // Params expects a SequentialThunk at the root; verify the dynamic type
  // before downcasting.
  if (dynamic_cast<const SequentialThunk*>(thunk.get()) == nullptr) {
    return absl::InvalidArgumentError(
        "The top-most serialized thunk in the GPU Executable is not a "
        "SequentialThunk!");
  }
  params.executable = unique_ptr_down_cast<SequentialThunk>(std::move(thunk));
  params.constants.reserve(proto.constants().size());
  for (const auto& constant_proto : proto.constants()) {
    params.constants.push_back(ConstantInfo::FromProto(constant_proto));
  }
  // Rebuild the ShapeIndex -> OutputInfo map from the flattened repeated
  // entries.
  params.output_info.reserve(proto.output_info_map().size());
  for (const auto& output_info_proto : proto.output_info_map()) {
    ShapeIndex shape_index =
        ShapeIndex::FromProto(output_info_proto.shape_index());
    TF_ASSIGN_OR_RETURN(OutputInfo output_info,
                        OutputInfo::FromProto(output_info_proto.output_info()));
    params.output_info.emplace(std::move(shape_index), std::move(output_info));
  }
  params.module_name = proto.module_name();
  TF_ASSIGN_OR_RETURN(params.program_shape,
                      ProgramShape::FromProto(proto.program_shape()));
  return Create(std::move(params));
}
} // namespace gpu
} // namespace xla

View File

@ -70,6 +70,11 @@ class GpuExecutable : public Executable {
std::string symbol_name;
DenseDataIntermediate content;
int allocation_index = -1;
GpuExecutableProto::ConstantInfoProto ToProto() const;
static ConstantInfo FromProto(
const GpuExecutableProto::ConstantInfoProto& proto);
};
struct OutputInfo {
@ -83,8 +88,9 @@ class GpuExecutable : public Executable {
// would indicate the aliased parameter), and what kind of alias it is.
std::optional<HloInputOutputAliasConfig::Alias> alias_config;
OutputInfoProto ToProto() const;
static absl::StatusOr<OutputInfo> FromProto(const OutputInfoProto& proto);
GpuExecutableProto::OutputInfoProto ToProto() const;
static absl::StatusOr<OutputInfo> FromProto(
const GpuExecutableProto::OutputInfoProto& proto);
friend bool operator==(const OutputInfo& lhs, const OutputInfo& rhs) {
return std::tie(lhs.allocation_index, lhs.passthrough,
@ -210,6 +216,12 @@ class GpuExecutable : public Executable {
absl::Status VerboseAllocationError(absl::Status s);
static absl::StatusOr<std::unique_ptr<GpuExecutable>> FromProto(
const GpuExecutableProto&,
const se::DeviceDescription& device_description);
absl::StatusOr<GpuExecutableProto> ToProto() const;
private:
// Use GpuExecutable::Create() to create an instance.
explicit GpuExecutable(Params params,

View File

@ -2,15 +2,79 @@ syntax = "proto3";
package xla.gpu;
import "xla/backends/gpu/runtime/thunk.proto";
import "xla/service/gpu/ir_emission_utils.proto";
import "xla/service/hlo.proto";
import "xla/shape_util.proto";
import "xla/stream_executor/cuda/cuda_compute_capability.proto";
import "xla/stream_executor/device_description.proto";
import "xla/xla.proto";
import "xla/xla_data.proto";
message OutputInfoProto {
// This output is part of the following buffer allocation
int64 allocation_index = 1;
message GpuExecutableProto {
// The binary of the executable.
//
// For CUDA, this is a cubin binary.
// For ROCm, this is a hsaco binary.
bytes binary = 1;
// True when this output is passed through from an input parameter
bool passthrough = 2;
// The PTX of the executable. (Only applicable to CUDA)
string asm_text = 2;
// Describes whether and how this output aliases with an input parameter
optional xla.HloInputOutputAliasProto.AliasEntryProto alias_config = 3;
// The DNN compiled graphs of the executable.
//
// The key is the DNN kernel name, and the value is the compiled graph
// serialized to JSON. (Only applicable to cuDNN)
map<string, string> dnn_compiled_graphs = 3;
// The target compute capability of the executable.
stream_executor.GpuComputeCapabilityProto gpu_compute_capability = 4;
// The HLO module of the executable - for debugging purposes only.
xla.HloModuleProtoWithConfig hlo_module = 5;
// The thunk tree of the executable.
ThunkProto thunk = 6;
// The name of the HLO module - for debugging purposes only.
string module_name = 7;
// The shape of the program (parameters and result).
xla.ProgramShapeProto program_shape = 8;
// The buffer allocations of the executable.
repeated BufferAllocationProto buffer_allocations = 9;
message OutputInfoProto {
// This output is part of the following buffer allocation
int64 allocation_index = 1;
// True when this output is passed through from an input parameter
bool passthrough = 2;
// Describes whether and how this output aliases with an input parameter
optional xla.HloInputOutputAliasProto.AliasEntryProto alias_config = 3;
}
message OutputInfoMapEntry {
xla.ShapeIndexProto shape_index = 1;
OutputInfoProto output_info = 2;
}
// Map from output shape index to output info.
repeated OutputInfoMapEntry output_info_map = 10;
message ConstantInfoProto {
// The name of the constant in the HLO module.
string symbol_name = 1;
// The content of the constant - this can be large.
DenseDataIntermediateProto content = 2;
// The index of the buffer allocation for this constant.
int64 allocation_index = 3;
}
// The constants used by the executable.
repeated ConstantInfoProto constants = 11;
}

View File

@ -62,8 +62,11 @@ namespace xla::gpu {
namespace {
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
using ::testing::Pair;
using ::testing::Pointee;
using ::testing::Property;
using ::testing::SizeIs;
using ::testing::UnorderedElementsAre;
using ::tsl::proto_testing::EqualsProto;
TEST(GpuExecutableTest, OuputInfoToAndFromProto) {
@ -433,5 +436,56 @@ TEST(GpuExecutableTest, DumpsMetadataListProto) {
)pb"));
}
// Round-trip test: builds a minimal GpuExecutable, serializes it with
// ToProto(), reconstructs it with FromProto(), and checks that the visible
// properties (asm text, binary, DNN graphs, thunks, allocations, name)
// survive the round trip.
TEST(GpuExecutableTest, ProtoConversion) {
  // Target device description; FromProto() rejects a compute-capability
  // mismatch, so the same description is reused for reconstruction below.
  se::DeviceDescription device_description;
  device_description.set_gpu_compute_capability(
      se::GpuComputeCapability{se::CudaComputeCapability::Volta()});
  device_description.set_driver_version({12, 3, 0});
  device_description.set_runtime_version({12, 3, 0});
  Thunk::ThunkInfo thunk_info;
  thunk_info.thunk_id = 123;
  // A single kernel thunk with no arguments is enough to exercise thunk
  // [de]serialization.
  ThunkSequence thunk_sequence;
  thunk_sequence.push_back(std::make_unique<KernelThunk>(
      thunk_info,
      /*kernel_name=*/"test_kernel", emitters::KernelArguments({}),
      LaunchDimensions(),
      /*cluster_dim=*/std::nullopt,
      /*shmem_bytes=*/0, se::gpu::TmaMetadata()));
  GpuExecutable::Params params;
  params.asm_text = "test_asm_text";
  params.binary = {1, 2, 3};
  params.dnn_compiled_graphs = {{"test_dnn_compiled_graph", "test_json"}};
  // Distinct id for the root SequentialThunk vs. the kernel thunk above.
  thunk_info.thunk_id = 456;
  params.executable =
      std::make_unique<SequentialThunk>(thunk_info, std::move(thunk_sequence));
  params.device_description = device_description;
  params.module_name = "test_module";
  params.enable_debug_info_manager = false;
  params.mlir_allocations = {BufferAllocation(0, 1024, 0)};
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GpuExecutable> reference_executable,
                          GpuExecutable::Create(std::move(params)));

  // Serialize and reconstruct.
  TF_ASSERT_OK_AND_ASSIGN(GpuExecutableProto proto,
                          reference_executable->ToProto());
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<GpuExecutable> reconstructed_executable,
      GpuExecutable::FromProto(proto, device_description));

  // Everything observable on the reconstructed executable should match what
  // went into the reference executable.
  EXPECT_THAT(reconstructed_executable->text(), "test_asm_text");
  EXPECT_THAT(reconstructed_executable->binary(), ElementsAre(1, 2, 3));
  EXPECT_THAT(
      reconstructed_executable->dnn_compiled_graphs(),
      UnorderedElementsAre(Pair("test_dnn_compiled_graph", "test_json")));
  EXPECT_THAT(reconstructed_executable->GetThunk().thunks(),
              ElementsAre(Pointee(Property(&Thunk::kind, Thunk::kKernel))));
  EXPECT_THAT(reconstructed_executable->GetAllocations(),
              ElementsAre(Pointee(Property(&BufferAllocation::size, 1024))));
  EXPECT_THAT(reconstructed_executable->name(), "test_module");
}
} // namespace
} // namespace xla::gpu