[xla:cpu] Move buffer allocation info encoding to tf2xla

PiperOrigin-RevId: 825732652
Authored by Eugene Zhulenev on 2025-10-29 15:26:44 -07:00, committed by TensorFlower Gardener
parent 756a72760a
commit 0f559dec93
9 changed files with 131 additions and 87 deletions
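
Note: this commit moves the EncodedBufferAllocationInfo struct verbatim from xla/backends/cpu/buffer_allocation_info.h into a new tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h header, and moves its test along with it. As a minimal round-trip sketch against the API visible in the hunks below (the concrete values and the main() wrapper are illustrative, not part of this commit):

#include <cassert>

#include "tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h"
#include "xla/backends/cpu/buffer_allocation_info.h"

int main() {
  // An entry-parameter allocation: 1024 bytes backing entry parameter #3.
  auto info = xla::cpu::BufferAllocationInfo::EntryParameter(1024, 3);

  // Encode into the lightweight POD used on the AOT (tfcompile) path ...
  xla::cpu::EncodedBufferAllocationInfo encoded(info);

  // ... and reconstruct the original allocation info from it.
  auto decoded = static_cast<xla::cpu::BufferAllocationInfo>(encoded);
  assert(decoded.is_entry_parameter());
  assert(decoded.entry_parameter_number() == 3);
  assert(decoded.size() == 1024);
  return 0;
}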

@@ -96,6 +96,7 @@ cc_library(
":thunk_proto_execution_deserializer",
"//tensorflow/compiler/tf2xla",
"//tensorflow/compiler/tf2xla:allocator",
"//tensorflow/compiler/tf2xla:encoded_buffer_allocation_info",
"//tensorflow/compiler/tf2xla:mlir_tf2xla", # fixdeps: keep
"//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
"//tensorflow/compiler/tf2xla:tf2xla_util",

@@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
#include "tensorflow/compiler/aot/thunk_proto_execution_deserializer.h"
#include "tensorflow/compiler/tf2xla/allocator.h"
#include "tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "xla/backends/cpu/buffer_allocation_info.h"

@@ -138,6 +138,25 @@ cc_library(
],
)
cc_library(
name = "encoded_buffer_allocation_info",
hdrs = ["encoded_buffer_allocation_info.h"],
visibility = [":friends"],
deps = [
"@local_xla//xla/backends/cpu:buffer_allocation_info",
],
)
tf_cc_test(
name = "encoded_buffer_allocation_info_test",
srcs = ["encoded_buffer_allocation_info_test.cc"],
deps = [
":encoded_buffer_allocation_info",
"@com_google_googletest//:gtest_main",
"@local_xla//xla/backends/cpu:buffer_allocation_info",
],
)
cc_library(
name = "tf2xla",
srcs = ["tf2xla.cc"],
@@ -218,6 +237,7 @@ filegroup(
name = "xla_compiled_cpu_runtime_hdrs",
srcs = [
"allocator.h",
"encoded_buffer_allocation_info.h",
"xla_compiled_cpu_function.h",
"//tensorflow/core/kernels:xla_cpu_runtime_hdrs",
"//tensorflow/core/platform:xla_cpu_runtime_srcs",
@@ -437,6 +457,7 @@ cc_library(
":allocator",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/types:span",
":encoded_buffer_allocation_info",
"@local_xla//xla/service:custom_call_status_internal",
"@local_xla//xla/backends/cpu/runtime:rng_state_lib",
"@local_xla//xla/backends/cpu:alignment",
@@ -502,6 +523,7 @@ cc_library(
hdrs = ["xla_jit_compiled_cpu_function.h"],
visibility = ["//visibility:public"],
deps = [
":encoded_buffer_allocation_info",
":tf2xla",
":tf2xla_proto_cc",
":xla_compiled_cpu_function",

@@ -0,0 +1,99 @@
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_ENCODED_BUFFER_ALLOCATION_INFO_H_
#define TENSORFLOW_COMPILER_TF2XLA_ENCODED_BUFFER_ALLOCATION_INFO_H_
#include <cstdint>
#include "xla/backends/cpu/buffer_allocation_info.h"
namespace xla {
namespace cpu {
// Encoded version of `BufferAllocationInfo`, which can be used to reconstruct
// the `BufferAllocationInfo` later. It's used in the AOT compiler, to
// represent buffer allocation info as a lightweight struct.
struct EncodedBufferAllocationInfo {
EncodedBufferAllocationInfo(uint64_t packed_kind_and_size,
uint32_t entry_param_number,
uint32_t result_number)
: packed_kind_and_size(packed_kind_and_size),
entry_param_number(entry_param_number),
result_number(result_number) {}
// Encodes BufferAllocationInfo into the struct that can be used to
// reconstruct the BufferAllocationInfo later using the constructor. We need
// this because we use BufferAllocationInfo in places where using protocol
// buffers would negatively impact binary size.
explicit EncodedBufferAllocationInfo(
const BufferAllocationInfo& buffer_info) {
packed_kind_and_size = Pack(buffer_info.kind(), buffer_info.size());
entry_param_number = buffer_info.is_entry_parameter()
? buffer_info.entry_parameter_number()
: -1;
result_number = buffer_info.is_result() ? buffer_info.result_number() : -1;
}
explicit operator BufferAllocationInfo() const {
auto kind = UnpackKind(packed_kind_and_size);
auto size = UnpackSize(packed_kind_and_size);
int32_t entry_param_number = static_cast<int32_t>(this->entry_param_number);
int32_t result_number = static_cast<int32_t>(this->result_number);
switch (kind) {
case BufferAllocationInfo::Kind::kConstant:
return BufferAllocationInfo::Constant(size);
case BufferAllocationInfo::Kind::kTemp:
return BufferAllocationInfo::Temp(size);
case BufferAllocationInfo::Kind::kParameter:
if (entry_param_number >= 0 && result_number >= 0) {
return BufferAllocationInfo::InOutParameter(size, entry_param_number,
result_number);
}
if (entry_param_number >= 0) {
return BufferAllocationInfo::EntryParameter(size, entry_param_number);
}
return BufferAllocationInfo::Result(size, result_number);
case BufferAllocationInfo::Kind::kThreadLocal:
return BufferAllocationInfo::ThreadLocal(size);
}
}
static uint64_t Pack(BufferAllocationInfo::Kind kind, uint64_t size) {
return (static_cast<uint64_t>(size) << 2) | static_cast<uint64_t>(kind);
}
static constexpr BufferAllocationInfo::Kind UnpackKind(uint64_t packed) {
return static_cast<BufferAllocationInfo::Kind>((packed << 62) >> 62);
}
static constexpr uint64_t UnpackSize(uint64_t packed) { return packed >> 2; }
uint64_t packed_kind_and_size = 0;
uint32_t entry_param_number = -1;
uint32_t result_number = -1;
};
} // namespace cpu
// TODO(ezhulenev): This is a temporary hack to keep `tfcompile` code working.
namespace cpu_function_runtime {
using BufferInfo = ::xla::cpu::BufferAllocationInfo;
using EncodedBufferInfo = ::xla::cpu::EncodedBufferAllocationInfo;
} // namespace cpu_function_runtime
} // namespace xla
#endif // TENSORFLOW_COMPILER_TF2XLA_ENCODED_BUFFER_ALLOCATION_INFO_H_
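
A note on the packed layout and the compatibility aliases in the header above: Pack() stores the kind in the low 2 bits of packed_kind_and_size and the size in the high 62 bits, and the cpu_function_runtime names are plain type aliases, so existing tfcompile call sites keep compiling. A small sketch under the same assumptions as the earlier example (the concrete size is illustrative):

#include <cassert>
#include <cstdint>
#include <type_traits>

#include "tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h"

int main() {
  using xla::cpu::BufferAllocationInfo;
  using xla::cpu::EncodedBufferAllocationInfo;

  // Kind lives in bits [0, 2), size in bits [2, 64), so any size below 2^62
  // survives the pack/unpack round trip.
  constexpr uint64_t kSize = 4096;
  uint64_t packed =
      EncodedBufferAllocationInfo::Pack(BufferAllocationInfo::Kind::kTemp, kSize);
  assert(EncodedBufferAllocationInfo::UnpackSize(packed) == kSize);
  assert(EncodedBufferAllocationInfo::UnpackKind(packed) ==
         BufferAllocationInfo::Kind::kTemp);

  // The temporary cpu_function_runtime aliases keep existing tfcompile call
  // sites compiling against the old names.
  static_assert(std::is_same_v<xla::cpu_function_runtime::BufferInfo,
                               BufferAllocationInfo>);
  static_assert(std::is_same_v<xla::cpu_function_runtime::EncodedBufferInfo,
                               EncodedBufferAllocationInfo>);
  return 0;
}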

@@ -1,4 +1,4 @@
/* Copyright 2025 The OpenXLA Authors.
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/backends/cpu/buffer_allocation_info.h"
#include "tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h"
#include <gtest/gtest.h>
#include "xla/backends/cpu/buffer_allocation_info.h"
namespace xla::cpu {
namespace {
TEST(BufferAllocationInfoTest, RoundTrip) {
TEST(EncodedBufferAllocationInfoTest, RoundTrip) {
auto round_trip = [](const BufferAllocationInfo& buffer_info) {
EncodedBufferAllocationInfo encoded(buffer_info);
BufferAllocationInfo round_trip(encoded);

@@ -28,6 +28,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h"
#include "xla/backends/cpu/alignment.h"
#include "xla/backends/cpu/buffer_allocation_info.h"
#include "xla/backends/cpu/runtime/rng_state_lib.h"

@@ -22,6 +22,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
#include "tensorflow/compiler/tf2xla/encoded_buffer_allocation_info.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function_thunks.h"
#include "xla/backends/cpu/buffer_allocation_info.h"

@@ -60,15 +60,6 @@ cc_library(
],
)
xla_cc_test(
name = "buffer_allocation_info_test",
srcs = ["buffer_allocation_info_test.cc"],
deps = [
":buffer_allocation_info",
"@com_google_googletest//:gtest_main",
],
)
onednn_graph_cc_library(
name = "onednn_emitter",
srcs = ["onednn_emitter.cc"],

@@ -19,8 +19,7 @@ limitations under the License.
#include <cassert>
#include <cstdint>
namespace xla {
namespace cpu {
namespace xla::cpu {
// `BufferAllocationInfo` stores information about buffer allocations required
// by an XLA:CPU executable at run time. It corresponds to a `BufferAllocation`
@@ -117,78 +116,6 @@ class BufferAllocationInfo {
int32_t result_number_ = -1;
};
// Encoded version of `BufferAllocationInfo`, which can be used to reconstruct
// the `BufferAllocationInfo` later. It's used in the AOT compiler, to
// represent buffer allocation info as a lightweight struct.
struct EncodedBufferAllocationInfo {
EncodedBufferAllocationInfo(uint64_t packed_kind_and_size,
uint32_t entry_param_number,
uint32_t result_number)
: packed_kind_and_size(packed_kind_and_size),
entry_param_number(entry_param_number),
result_number(result_number) {}
// Encodes BufferAllocationInfo into the struct that can be used to
// reconstruct the BufferAllocationInfo later using the constructor. We need
// this because we use BufferAllocationInfo in places where using protocol
// buffers would negatively impact binary size.
explicit EncodedBufferAllocationInfo(
const BufferAllocationInfo& buffer_info) {
packed_kind_and_size = Pack(buffer_info.kind(), buffer_info.size());
entry_param_number = buffer_info.is_entry_parameter()
? buffer_info.entry_parameter_number()
: -1;
result_number = buffer_info.is_result() ? buffer_info.result_number() : -1;
}
explicit operator BufferAllocationInfo() const {
auto kind = UnpackKind(packed_kind_and_size);
auto size = UnpackSize(packed_kind_and_size);
int32_t entry_param_number = static_cast<int32_t>(this->entry_param_number);
int32_t result_number = static_cast<int32_t>(this->result_number);
switch (kind) {
case BufferAllocationInfo::Kind::kConstant:
return BufferAllocationInfo::Constant(size);
case BufferAllocationInfo::Kind::kTemp:
return BufferAllocationInfo::Temp(size);
case BufferAllocationInfo::Kind::kParameter:
if (entry_param_number >= 0 && result_number >= 0) {
return BufferAllocationInfo::InOutParameter(size, entry_param_number,
result_number);
}
if (entry_param_number >= 0) {
return BufferAllocationInfo::EntryParameter(size, entry_param_number);
}
return BufferAllocationInfo::Result(size, result_number);
case BufferAllocationInfo::Kind::kThreadLocal:
return BufferAllocationInfo::ThreadLocal(size);
}
}
static uint64_t Pack(BufferAllocationInfo::Kind kind, uint64_t size) {
return (static_cast<uint64_t>(size) << 2) | static_cast<uint64_t>(kind);
}
static constexpr BufferAllocationInfo::Kind UnpackKind(uint64_t packed) {
return static_cast<BufferAllocationInfo::Kind>((packed << 62) >> 62);
}
static constexpr uint64_t UnpackSize(uint64_t packed) { return packed >> 2; }
uint64_t packed_kind_and_size = 0;
uint32_t entry_param_number = -1;
uint32_t result_number = -1;
};
} // namespace cpu
// TODO(ezhulenev): This is a temporary hack to keep `tfcompile` code working.
namespace cpu_function_runtime {
using BufferInfo = ::xla::cpu::BufferAllocationInfo;
using EncodedBufferInfo = ::xla::cpu::EncodedBufferAllocationInfo;
} // namespace cpu_function_runtime
} // namespace xla
} // namespace xla::cpu
#endif // XLA_BACKENDS_CPU_BUFFER_ALLOCATION_INFO_H_