Add (de)serialization for FftThunk

This one is a pretty direct mapping from the struct to the proto.

PiperOrigin-RevId: 819214943

Parent: 0064d2d1bb
Commit: d555ed2c74
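This change gives FftThunk a ToProto/FromProto pair so a GPU FFT thunk can round-trip through ThunkProto. As a quick orientation before the diff, here is a minimal sketch of how the two calls compose; it mirrors the test added below, and the helper name RoundTripFftThunk is illustrative rather than part of the commit.

#include <memory>
#include <utility>

#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/backends/gpu/runtime/fft_thunk.h"
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/backends/gpu/runtime/thunk.pb.h"
#include "xla/service/buffer_assignment.h"
#include "xla/tsl/platform/statusor.h"

namespace xla::gpu {

// Serializes `thunk` into a ThunkProto and rebuilds it against the same
// buffer allocations. Illustrative helper, not part of the commit.
absl::StatusOr<std::unique_ptr<FftThunk>> RoundTripFftThunk(
    const FftThunk& thunk,
    absl::Span<const BufferAllocation> buffer_allocations) {
  TF_ASSIGN_OR_RETURN(ThunkProto proto, thunk.ToProto());
  TF_ASSIGN_OR_RETURN(Thunk::ThunkInfo thunk_info,
                      Thunk::ThunkInfo::FromProto(proto.thunk_info()));
  return FftThunk::FromProto(std::move(thunk_info), proto.fft_thunk(),
                             buffer_allocations);
}

}  // namespace xla::gpu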
third_party/xla/xla/backends/gpu/runtime/BUILD (19 lines changed)
@@ -744,12 +744,14 @@ cc_library(
    hdrs = ["fft_thunk.h"],
    deps = [
        ":thunk",
        ":thunk_proto_cc",
        "//xla:shape_util",
        "//xla:status_macros",
        "//xla:types",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla/service:buffer_assignment",
        "//xla/service:buffer_assignment_proto_cc",
        "//xla/stream_executor:blas",
        "//xla/stream_executor:device_memory",
        "//xla/stream_executor:device_memory_allocator",
@@ -768,6 +770,22 @@ cc_library(
    ],
)

xla_cc_test(
    name = "fft_thunk_test",
    srcs = ["fft_thunk_test.cc"],
    deps = [
        ":fft_thunk",
        ":thunk",
        ":thunk_proto_cc",
        "//xla/service:buffer_assignment",
        "//xla/tsl/platform:statusor",
        "//xla/tsl/util/proto:parse_text_proto",
        "//xla/tsl/util/proto:proto_matchers",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_library(
    name = "gemm_thunk",
    srcs = ["gemm_thunk.cc"],
@@ -2397,6 +2415,7 @@ cc_library(
        ":convolution_thunk",
        ":copy_thunk",
        ":cudnn_thunk",
        ":fft_thunk",
        ":gemm_thunk",
        ":gpublas_lt_matmul_thunk",
        ":infeed_thunk",
@ -18,6 +18,8 @@ limitations under the License.
|
|||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/log/check.h"
|
||||
#include "absl/log/log.h"
|
||||
|
|
@@ -27,6 +29,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/buffer_assignment.pb.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
@@ -81,6 +84,25 @@ std::string FftTypeToString(se::fft::Type type) {
  }
}

absl::StatusOr<FftType> SeTypeToFftType(se::fft::Type type) {
  switch (type) {
    case se::fft::Type::kC2CForward:
    case se::fft::Type::kZ2ZForward:
      return FftType::FFT;
    case se::fft::Type::kC2CInverse:
    case se::fft::Type::kZ2ZInverse:
      return FftType::IFFT;
    case se::fft::Type::kC2R:
    case se::fft::Type::kZ2D:
      return FftType::IRFFT;
    case se::fft::Type::kR2C:
    case se::fft::Type::kD2Z:
      return FftType::RFFT;
    case se::fft::Type::kInvalid:
      return Internal("Invalid fft type");
  }
}

absl::StatusOr<stream_executor::blas::BlasSupport*> GetBlas(
    se::Stream* stream) {
  auto blas = stream->parent()->AsBlas();
@@ -268,5 +290,47 @@ absl::Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape,
                  FftTypeToString(fft_type));
}

absl::StatusOr<std::unique_ptr<FftThunk>> FftThunk::FromProto(
    ThunkInfo thunk_info, const FftThunkProto& proto,
    absl::Span<const BufferAllocation> buffer_allocations) {
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice input_buffer,
                      BufferAllocation::Slice::FromProto(proto.input_buffer(),
                                                         buffer_allocations));
  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_buffer,
                      BufferAllocation::Slice::FromProto(proto.output_buffer(),
                                                         buffer_allocations));

  TF_ASSIGN_OR_RETURN(Shape input_shape, Shape::FromProto(proto.input_shape()));
  TF_ASSIGN_OR_RETURN(Shape output_shape,
                      Shape::FromProto(proto.output_shape()));

  std::vector<int64_t> fft_length{proto.fft_length().begin(),
                                  proto.fft_length().end()};

  return std::make_unique<FftThunk>(thunk_info, proto.fft_type(),
                                    std::move(fft_length), input_buffer,
                                    output_buffer, input_shape, output_shape);
}

absl::StatusOr<ThunkProto> FftThunk::ToProto() const {
  ThunkProto thunk_proto;
  *thunk_proto.mutable_thunk_info() = thunk_info().ToProto();

  FftThunkProto* proto = thunk_proto.mutable_fft_thunk();
  TF_ASSIGN_OR_RETURN(FftType fft_type, SeTypeToFftType(fft_type_));
  proto->set_fft_type(fft_type);

  *proto->mutable_fft_length() = {fft_length_.begin(), fft_length_.end()};

  TF_ASSIGN_OR_RETURN(*proto->mutable_input_buffer(), input_buffer_.ToProto());
  TF_ASSIGN_OR_RETURN(*proto->mutable_output_buffer(),
                      output_buffer_.ToProto());

  *proto->mutable_input_shape() = input_shape_.ToProto();
  *proto->mutable_output_shape() = output_shape_.ToProto();

  return thunk_proto;
}

}  // namespace gpu
}  // namespace xla
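Note that SeTypeToFftType above collapses the precision distinction: kC2CForward and kZ2ZForward both serialize as FFT. On deserialization the constructor receives the FftType together with the shapes, so the single- vs. double-precision se::fft::Type can be recovered from the element type. Below is a sketch of that inverse direction; the function name and signature are assumptions for illustration, not necessarily the helper the constructor actually uses.

// Illustrative inverse of SeTypeToFftType: pick the single- or
// double-precision StreamExecutor FFT kind from the XLA FftType.
// Names here are hypothetical; shown only to make the round trip concrete.
absl::StatusOr<se::fft::Type> FftTypeToSeTypeSketch(FftType type,
                                                    bool double_precision) {
  switch (type) {
    case FftType::FFT:
      return double_precision ? se::fft::Type::kZ2ZForward
                              : se::fft::Type::kC2CForward;
    case FftType::IFFT:
      return double_precision ? se::fft::Type::kZ2ZInverse
                              : se::fft::Type::kC2CInverse;
    case FftType::RFFT:
      return double_precision ? se::fft::Type::kD2Z : se::fft::Type::kR2C;
    case FftType::IRFFT:
      return double_precision ? se::fft::Type::kZ2D : se::fft::Type::kC2R;
    default:
      return Internal("Unsupported FFT type %d", static_cast<int>(type));
  }
}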
third_party/xla/xla/backends/gpu/runtime/fft_thunk.h

@@ -23,9 +23,11 @@ limitations under the License.
#include "absl/base/thread_annotations.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/backends/gpu/runtime/thunk.pb.h"
#include "xla/service/buffer_assignment.h"
#include "xla/shape.h"
#include "xla/stream_executor/device_memory.h"
@@ -82,6 +84,12 @@ class FftThunk : public Thunk {
  // Does the FFT for the thunk on "stream".
  absl::Status ExecuteOnStream(const ExecuteParams& params) override;

  static absl::StatusOr<std::unique_ptr<FftThunk>> FromProto(
      ThunkInfo thunk_info, const FftThunkProto& proto,
      absl::Span<const BufferAllocation> buffer_allocations);

  absl::StatusOr<ThunkProto> ToProto() const override;

 private:
  const se::fft::Type fft_type_;
  const std::vector<int64_t> fft_length_;
third_party/xla/xla/backends/gpu/runtime/fft_thunk_test.cc (new file, 84 lines)
@@ -0,0 +1,84 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/gpu/runtime/fft_thunk.h"

#include <memory>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/types/span.h"
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/backends/gpu/runtime/thunk.pb.h"
#include "xla/service/buffer_assignment.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/util/proto/parse_text_proto.h"
#include "xla/tsl/util/proto/proto_matchers.h"

namespace xla {
namespace gpu {
namespace {

using ::tsl::proto_testing::EqualsProto;
using ::tsl::proto_testing::ParseTextProtoOrDie;

TEST(FftThunkTest, ProtoRoundTrip) {
  auto proto = ParseTextProtoOrDie<ThunkProto>(R"pb(
    thunk_info { profile_annotation: "test" execution_stream_id: 0 }
    fft_thunk {
      fft_type: FFT
      fft_length: [ 64, 64 ]
      input_buffer { buffer_allocation_index: 0 offset: 0 size: 1024 }
      output_buffer { buffer_allocation_index: 1 offset: 0 size: 1024 }
      input_shape {
        element_type: C64
        dimensions: 1
        dimensions: [ 64, 64 ]
        layout {
          minor_to_major: [ 2, 1, 0 ]
          tail_padding_alignment_in_elements: 1
        }
        is_dynamic_dimension: [ false, false, false ]
      }
      output_shape {
        element_type: C64
        dimensions: 1
        dimensions: [ 64, 64 ]
        layout {
          minor_to_major: [ 2, 1, 0 ]
          tail_padding_alignment_in_elements: 1
        }
        is_dynamic_dimension: [ false, false, false ]
      }
    }
  )pb");

  std::vector<BufferAllocation> buffer_allocations;
  buffer_allocations.emplace_back(/*index=*/0, /*size=*/1024, /*color=*/0);
  buffer_allocations.emplace_back(/*index=*/1, /*size=*/1024, /*color=*/0);

  TF_ASSERT_OK_AND_ASSIGN(Thunk::ThunkInfo thunk_info,
                          Thunk::ThunkInfo::FromProto(proto.thunk_info()));
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<FftThunk> thunk,
      FftThunk::FromProto(thunk_info, proto.fft_thunk(), buffer_allocations));
  TF_ASSERT_OK_AND_ASSIGN(ThunkProto round_trip_proto, thunk->ToProto());
  EXPECT_THAT(round_trip_proto, EqualsProto(proto));
}

}  // namespace
}  // namespace gpu
}  // namespace xla
third_party/xla/xla/backends/gpu/runtime/thunk.proto

@@ -213,6 +213,15 @@ message ConvolutionReorderThunkProto {
  optional ConvolutionReorderBiasBuffers biases = 4;
}

message FftThunkProto {
  FftType fft_type = 1;
  repeated int64 fft_length = 2;
  xla.buffer_assignment.BufferAllocationSliceProto input_buffer = 3;
  xla.buffer_assignment.BufferAllocationSliceProto output_buffer = 4;
  xla.ShapeProto input_shape = 5;
  xla.ShapeProto output_shape = 6;
}

message ThunkProto {
  ThunkInfoProto thunk_info = 1;

@@ -242,6 +251,7 @@ message ThunkProto {
    NormThunkProto norm_thunk = 24;
    ConvolutionThunkProto convolution_thunk = 25;
    ConvolutionReorderThunkProto convolution_reorder_thunk = 26;
    FftThunkProto fft_thunk = 27;
  }
}
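For reference, here is a short sketch of populating an FftThunkProto from C++ through the generated protobuf API. The setter names follow standard protobuf codegen for these field types, and the shape and slice values are arbitrary example data, not taken from the commit.

#include <cstdint>

#include "xla/backends/gpu/runtime/thunk.pb.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/xla_data.pb.h"

namespace xla::gpu {

// Builds an example FftThunkProto describing a 2D C64 FFT over a
// [1, 64, 64] operand. Values are illustrative only.
FftThunkProto MakeExampleFftThunkProto() {
  FftThunkProto proto;
  proto.set_fft_type(FftType::FFT);
  proto.add_fft_length(64);
  proto.add_fft_length(64);
  proto.mutable_input_buffer()->set_buffer_allocation_index(0);
  proto.mutable_input_buffer()->set_size(1024);
  proto.mutable_output_buffer()->set_buffer_allocation_index(1);
  proto.mutable_output_buffer()->set_size(1024);
  Shape shape = ShapeUtil::MakeShape(C64, {1, 64, 64});
  *proto.mutable_input_shape() = shape.ToProto();
  *proto.mutable_output_shape() = shape.ToProto();
  return proto;
}

}  // namespace xla::gpu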
third_party/xla/xla/backends/gpu/runtime/thunk_proto_deserialization.cc

@@ -32,6 +32,7 @@ limitations under the License.
#include "xla/backends/gpu/runtime/convolution_thunk.h"
#include "xla/backends/gpu/runtime/copy_thunk.h"
#include "xla/backends/gpu/runtime/cudnn_thunk.h"
#include "xla/backends/gpu/runtime/fft_thunk.h"
#include "xla/backends/gpu/runtime/gemm_thunk.h"
#include "xla/backends/gpu/runtime/gpublas_lt_matmul_thunk.h"
#include "xla/backends/gpu/runtime/infeed_thunk.h"

@@ -159,6 +160,9 @@ absl::StatusOr<std::unique_ptr<Thunk>> DeserializeThunkProto(
          std::move(thunk_info), thunk_proto.convolution_reorder_thunk(),
          buffer_allocations);
    }
    case ThunkProto::kFftThunk:
      return FftThunk::FromProto(std::move(thunk_info), thunk_proto.fft_thunk(),
                                 buffer_allocations);
    default:
      std::optional<absl::string_view> unsupported_thunk_type =
          GetStoredThunkTypeName(thunk_proto);