Add (de)serialization for FftThunk

This one is a pretty direct mapping from the struct to the proto.

PiperOrigin-RevId: 819214943
This commit is contained in:
Eusebio Durán Montaña 2025-10-14 07:39:02 -07:00 committed by TensorFlower Gardener
parent 0064d2d1bb
commit d555ed2c74
6 changed files with 189 additions and 0 deletions

View File

@ -744,12 +744,14 @@ cc_library(
hdrs = ["fft_thunk.h"],
deps = [
":thunk",
":thunk_proto_cc",
"//xla:shape_util",
"//xla:status_macros",
"//xla:types",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/service:buffer_assignment",
"//xla/service:buffer_assignment_proto_cc",
"//xla/stream_executor:blas",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:device_memory_allocator",
@ -768,6 +770,22 @@ cc_library(
],
)
# Unit test covering FftThunk proto (de)serialization round-tripping.
xla_cc_test(
    name = "fft_thunk_test",
    srcs = ["fft_thunk_test.cc"],
    deps = [
        ":fft_thunk",
        ":thunk",
        ":thunk_proto_cc",
        "//xla/service:buffer_assignment",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/util/proto:parse_text_proto",
        "//xla/tsl/util/proto:proto_matchers",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest_main",
    ],
)
cc_library(
name = "gemm_thunk",
srcs = ["gemm_thunk.cc"],
@ -2397,6 +2415,7 @@ cc_library(
":convolution_thunk",
":copy_thunk",
":cudnn_thunk",
":fft_thunk",
":gemm_thunk",
":gpublas_lt_matmul_thunk",
":infeed_thunk",

View File

@ -18,6 +18,8 @@ limitations under the License.
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/log/check.h"
#include "absl/log/log.h"
@ -27,6 +29,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/buffer_assignment.pb.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
@ -81,6 +84,25 @@ std::string FftTypeToString(se::fft::Type type) {
}
}
// Maps a StreamExecutor FFT type to the XLA FftType proto enum.
// Single- and double-precision variants of the same logical transform
// collapse to one value (e.g. kC2CForward and kZ2ZForward both become FFT).
// Returns an Internal error for kInvalid or out-of-range values.
absl::StatusOr<FftType> SeTypeToFftType(se::fft::Type type) {
  switch (type) {
    case se::fft::Type::kC2CForward:
    case se::fft::Type::kZ2ZForward:
      return FftType::FFT;
    case se::fft::Type::kC2CInverse:
    case se::fft::Type::kZ2ZInverse:
      return FftType::IFFT;
    case se::fft::Type::kC2R:
    case se::fft::Type::kZ2D:
      return FftType::IRFFT;
    case se::fft::Type::kR2C:
    case se::fft::Type::kD2Z:
      return FftType::RFFT;
    case se::fft::Type::kInvalid:
      return Internal("Invalid fft type");
  }
  // Defensive fallback: the switch is exhaustive over the declared
  // enumerators, but an enum variable may legally hold other underlying
  // values; without this return the function would fall off the end (UB).
  return Internal("Unknown fft type: %d", static_cast<int>(type));
}
absl::StatusOr<stream_executor::blas::BlasSupport*> GetBlas(
se::Stream* stream) {
auto blas = stream->parent()->AsBlas();
@ -268,5 +290,47 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape,
FftTypeToString(fft_type));
}
// Reconstructs an FftThunk from its proto form.
//
// `buffer_allocations` must contain the allocations referenced by the
// input/output slice protos; slice deserialization fails otherwise.
// Returns an error if any slice or shape in `proto` cannot be decoded.
absl::StatusOr<std::unique_ptr<FftThunk>> FftThunk::FromProto(
    ThunkInfo thunk_info, const FftThunkProto& proto,
    absl::Span<const BufferAllocation> buffer_allocations) {
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_buffer,
                      BufferAllocation::Slice::FromProto(proto.input_buffer(),
                                                         buffer_allocations));
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_buffer,
                      BufferAllocation::Slice::FromProto(proto.output_buffer(),
                                                         buffer_allocations));
  TF_ASSIGN_OR_RETURN(Shape input_shape, Shape::FromProto(proto.input_shape()));
  TF_ASSIGN_OR_RETURN(Shape output_shape,
                      Shape::FromProto(proto.output_shape()));
  std::vector<int64_t> fft_length{proto.fft_length().begin(),
                                  proto.fft_length().end()};
  // `thunk_info` was taken by value; move it into the thunk rather than
  // copying it a second time.
  return std::make_unique<FftThunk>(std::move(thunk_info), proto.fft_type(),
                                    std::move(fft_length), input_buffer,
                                    output_buffer, input_shape, output_shape);
}
// Serializes this thunk into a ThunkProto with the fft_thunk field populated.
// Fails if the stored se::fft::Type cannot be mapped to the proto FftType
// enum or if either buffer slice fails to serialize.
absl::StatusOr<ThunkProto> FftThunk::ToProto() const {
  ThunkProto result;
  *result.mutable_thunk_info() = thunk_info().ToProto();

  FftThunkProto* fft_proto = result.mutable_fft_thunk();
  TF_ASSIGN_OR_RETURN(const FftType converted_type, SeTypeToFftType(fft_type_));
  fft_proto->set_fft_type(converted_type);
  fft_proto->mutable_fft_length()->Assign(fft_length_.begin(),
                                          fft_length_.end());
  TF_ASSIGN_OR_RETURN(*fft_proto->mutable_input_buffer(),
                      input_buffer_.ToProto());
  TF_ASSIGN_OR_RETURN(*fft_proto->mutable_output_buffer(),
                      output_buffer_.ToProto());
  *fft_proto->mutable_input_shape() = input_shape_.ToProto();
  *fft_proto->mutable_output_shape() = output_shape_.ToProto();
  return result;
}
} // namespace gpu
} // namespace xla

View File

@ -23,9 +23,11 @@ limitations under the License.
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/backends/gpu/runtime/thunk.pb.h"
#include "xla/service/buffer_assignment.h"
#include "xla/shape.h"
#include "xla/stream_executor/device_memory.h"
@ -82,6 +84,12 @@ class FftThunk : public Thunk {
// Does the FFT for the thunk on "stream".
absl::Status ExecuteOnStream(const ExecuteParams& params) override;
static absl::StatusOr<std::unique_ptr<FftThunk>> FromProto(
ThunkInfo thunk_info, const FftThunkProto& proto,
absl::Span<const BufferAllocation> buffer_allocations);
absl::StatusOr<ThunkProto> ToProto() const override;
private:
const se::fft::Type fft_type_;
const std::vector<int64_t> fft_length_;

View File

@ -0,0 +1,84 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/backends/gpu/runtime/fft_thunk.h"
#include <memory>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/types/span.h"
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/backends/gpu/runtime/thunk.pb.h"
#include "xla/service/buffer_assignment.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/util/proto/parse_text_proto.h"
#include "xla/tsl/util/proto/proto_matchers.h"
namespace xla {
namespace gpu {
namespace {
using ::tsl::proto_testing::EqualsProto;
using ::tsl::proto_testing::ParseTextProtoOrDie;
// Round-trip test: a ThunkProto is deserialized into an FftThunk via
// FromProto and serialized back via ToProto; the result must equal the
// original proto field-for-field.
TEST(FftThunkTest, ProtoRoundTrip) {
  auto proto = ParseTextProtoOrDie<ThunkProto>(R"pb(
    thunk_info { profile_annotation: "test" execution_stream_id: 0 }
    fft_thunk {
      fft_type: FFT
      fft_length: [ 64, 64 ]
      input_buffer { buffer_allocation_index: 0 offset: 0 size: 1024 }
      output_buffer { buffer_allocation_index: 1 offset: 0 size: 1024 }
      input_shape {
        element_type: C64
        dimensions: 1
        dimensions: [ 64, 64 ]
        layout {
          minor_to_major: [ 2, 1, 0 ]
          tail_padding_alignment_in_elements: 1
        }
        is_dynamic_dimension: [ false, false, false ]
      }
      output_shape {
        element_type: C64
        dimensions: 1
        dimensions: [ 64, 64 ]
        layout {
          minor_to_major: [ 2, 1, 0 ]
          tail_padding_alignment_in_elements: 1
        }
        is_dynamic_dimension: [ false, false, false ]
      }
    }
  )pb");

  // Allocations at indices 0 and 1 back the input/output slices referenced
  // by buffer_allocation_index above.
  std::vector<BufferAllocation> buffer_allocations;
  buffer_allocations.emplace_back(/*index=*/0, /*size=*/1024, /*color=*/0);
  buffer_allocations.emplace_back(/*index=*/1, /*size=*/1024, /*color=*/0);

  TF_ASSERT_OK_AND_ASSIGN(Thunk::ThunkInfo thunk_info,
                          Thunk::ThunkInfo::FromProto(proto.thunk_info()));
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<FftThunk> thunk,
      FftThunk::FromProto(thunk_info, proto.fft_thunk(), buffer_allocations));

  // Serialize back and compare against the proto we started from.
  TF_ASSERT_OK_AND_ASSIGN(ThunkProto round_trip_proto, thunk->ToProto());
  EXPECT_THAT(round_trip_proto, EqualsProto(proto));
}
} // namespace
} // namespace gpu
} // namespace xla

View File

@ -213,6 +213,15 @@ message ConvolutionReorderThunkProto {
optional ConvolutionReorderBiasBuffers biases = 4;
}
// Serialized form of FftThunk: the transform kind, the per-dimension
// transform lengths, and the input/output buffer slices with their shapes.
message FftThunkProto {
  // Kind of transform (FFT, IFFT, RFFT, IRFFT).
  FftType fft_type = 1;
  // Transform length along each transformed dimension.
  repeated int64 fft_length = 2;
  // Device buffer slice holding the FFT input.
  xla.buffer_assignment.BufferAllocationSliceProto input_buffer = 3;
  // Device buffer slice that receives the FFT output.
  xla.buffer_assignment.BufferAllocationSliceProto output_buffer = 4;
  // Shape (element type, dimensions, layout) of the input.
  xla.ShapeProto input_shape = 5;
  // Shape of the output.
  xla.ShapeProto output_shape = 6;
}
message ThunkProto {
ThunkInfoProto thunk_info = 1;
@ -242,6 +251,7 @@ message ThunkProto {
NormThunkProto norm_thunk = 24;
ConvolutionThunkProto convolution_thunk = 25;
ConvolutionReorderThunkProto convolution_reorder_thunk = 26;
FftThunkProto fft_thunk = 27;
}
}

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "xla/backends/gpu/runtime/convolution_thunk.h"
#include "xla/backends/gpu/runtime/copy_thunk.h"
#include "xla/backends/gpu/runtime/cudnn_thunk.h"
#include "xla/backends/gpu/runtime/fft_thunk.h"
#include "xla/backends/gpu/runtime/gemm_thunk.h"
#include "xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h"
#include "xla/backends/gpu/runtime/infeed_thunk.h"
@ -159,6 +160,9 @@ absl::StatusOr<std::unique_ptr<Thunk>> DeserializeThunkProto(
std::move(thunk_info), thunk_proto.convolution_reorder_thunk(),
buffer_allocations);
}
case ThunkProto::kFftThunk:
return FftThunk::FromProto(std::move(thunk_info), thunk_proto.fft_thunk(),
buffer_allocations);
default:
std::optional<absl::string_view> unsupported_thunk_type =
GetStoredThunkTypeName(thunk_proto);