[XLA:GPU] CustomCallThunk: enable use of lambdas with captures

Add CustomCallThunk::OwnedHandlerBundle, a bag of `unique_ptr<ffi::Ffi>` that
enable using lambdas with captures in CustomCallThunk. Lambda captures must
outlive the created thunk.

The functionality is similar to what is possible with "old-style" callbacks,
but doesn't depend on them, and adds support for other handlers available via
XLA_FFI_Handler_Bundle.

PiperOrigin-RevId: 826043689
This commit is contained in:
Marcin Radomski 2025-10-30 08:30:30 -07:00 committed by TensorFlower Gardener
parent 4461afa7ef
commit a0921d9997
4 changed files with 338 additions and 49 deletions

View File

@ -701,6 +701,7 @@ cc_library(
"//xla:executable_run_options",
"//xla:shape_util",
"//xla:util",
"//xla/ffi",
"//xla/ffi:attribute_map",
"//xla/ffi:call_frame",
"//xla/ffi:execution_context",
@ -720,6 +721,7 @@ cc_library(
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
@ -748,6 +750,7 @@ xla_test(
"//xla/service:executable",
"//xla/service:platform_util",
"//xla/service/gpu:buffer_allocations",
"//xla/service/gpu:resource_requests",
"//xla/stream_executor:platform",
"//xla/stream_executor:platform_manager",
"//xla/stream_executor:stream",

View File

@ -22,11 +22,12 @@ limitations under the License.
#include <optional>
#include <string>
#include <utility>
#include <variant>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/base/nullability.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
@ -41,9 +42,11 @@ limitations under the License.
#include "xla/ffi/attribute_map.h"
#include "xla/ffi/call_frame.h"
#include "xla/ffi/execution_state.h"
#include "xla/ffi/ffi.h"
#include "xla/ffi/ffi_api.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/primitive_util.h"
#include "xla/runtime/object_pool.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/custom_call_status.h"
#include "xla/service/custom_call_status_internal.h"
@ -250,6 +253,44 @@ absl::StatusOr<std::unique_ptr<CustomCallThunk>> CustomCallThunk::Create(
std::move(attributes), std::move(execution_state), called_computation));
}
absl::StatusOr<std::unique_ptr<CustomCallThunk>> CustomCallThunk::Create(
ThunkInfo thunk_info, std::string target_name, OwnedHandlerBundle bundle,
std::vector<std::optional<ShapedSlice>> operands,
std::vector<std::optional<ShapedSlice>> results,
xla::ffi::AttributesMap attributes,
const HloComputation* called_computation) {
if (!bundle.execute) {
return absl::InvalidArgumentError(
"Execute handler is required for a CustomCallThunk");
}
auto execution_state = std::make_unique<ffi::ExecutionState>();
// Initialize FFI handler state if it has an instantiate callback.
if (bundle.instantiate) {
// At FFI handler instantiation time, we don't have any arguments or
// results or access to the underlying device (stream, etc.)
CallFrameBuilder builder(/*num_args=*/0, /*num_rets=*/0);
CallFrameBuilder::AttributesBuilder attrs;
attrs.Append(attributes);
builder.AddAttributes(attrs.Build());
CallFrame call_frame = builder.Build();
CallOptions options;
options.execution_state = execution_state.get();
TF_RETURN_IF_ERROR(Call(*bundle.instantiate, call_frame, options,
xla::ffi::ExecutionStage::kInstantiate));
}
TF_ASSIGN_OR_RETURN(CallFrame call_frame,
BuildCallFramePrototype(operands, results, attributes));
return absl::WrapUnique(new CustomCallThunk(
thunk_info, std::move(target_name), std::move(bundle),
std::move(operands), std::move(results), std::move(call_frame),
std::move(attributes), std::move(execution_state), called_computation));
}
CustomCallThunk::CustomCallThunk(
ThunkInfo thunk_info, std::string target_name,
std::vector<std::optional<ShapedSlice>> operands,
@ -266,7 +307,7 @@ CustomCallThunk::CustomCallThunk(
CustomCallThunk::CustomCallThunk(
ThunkInfo thunk_info, std::string target_name,
XLA_FFI_Handler_Bundle bundle,
std::variant<XLA_FFI_Handler_Bundle, OwnedHandlerBundle> bundle,
std::vector<std::optional<ShapedSlice>> operands,
std::vector<std::optional<ShapedSlice>> results, CallFrame call_frame,
ffi::AttributesMap attributes,
@ -317,18 +358,9 @@ absl::Status CustomCallThunk::ExecuteCustomCall(const ExecuteParams& params) {
return absl::OkStatus();
}
absl::Status CustomCallThunk::ExecuteFfiHandler(
RunId run_id, XLA_FFI_Handler* handler, XLA_FFI_ExecutionStage stage,
se::Stream* stream, const ffi::ExecutionContext* execution_context,
const BufferAllocations* buffer_allocations) {
if (handler == nullptr) {
return absl::InternalError("FFI execute handler is not set");
}
if (stage != XLA_FFI_ExecutionStage_PREPARE &&
!(buffer_allocations && stream)) {
return absl::InternalError("buffer allocations and stream are required");
}
absl::StatusOr<ObjectPool<CallFrame>::BorrowedObject>
CustomCallThunk::BuildCallFrame(
const BufferAllocations* absl_nullable buffer_allocations) {
auto device_memory = [&](BufferAllocation::Slice slice) {
return buffer_allocations ? buffer_allocations->GetDeviceAddress(slice)
: se::DeviceMemoryBase{};
@ -360,58 +392,142 @@ absl::Status CustomCallThunk::ExecuteFfiHandler(
// device memory addresses.
TF_ASSIGN_OR_RETURN(auto call_frame, call_frames_->GetOrCreate());
TF_RETURN_IF_ERROR(call_frame->UpdateWithBuffers(arguments, results));
return call_frame;
}
CallOptions CustomCallThunk::BuildCallOptions(
RunId run_id, se::Stream* absl_nullable stream,
const BufferAllocations* absl_nullable buffer_allocations,
const ffi::ExecutionContext* absl_nonnull execution_context) {
int32_t device_ordinal = -1;
se::DeviceMemoryAllocator* allocator = nullptr;
if (stage != XLA_FFI_ExecutionStage_PREPARE) {
if (buffer_allocations != nullptr) {
device_ordinal = buffer_allocations->device_ordinal();
allocator = buffer_allocations->memory_allocator();
}
CallOptions options = {run_id,
device_ordinal,
CallOptions::GpuOptions{stream, allocator},
called_computation_,
execution_context,
execution_state_.get()};
return CallOptions{run_id,
device_ordinal,
CallOptions::GpuOptions{stream, allocator},
called_computation_,
execution_context,
execution_state_.get()};
}
absl::Status CustomCallThunk::ExecuteFfiHandler(
RunId run_id, XLA_FFI_Handler* handler, XLA_FFI_ExecutionStage stage,
se::Stream* stream, const ffi::ExecutionContext* execution_context,
const BufferAllocations* buffer_allocations) {
if (handler == nullptr) {
return absl::InternalError("FFI execute handler is not set");
}
if (stage != XLA_FFI_ExecutionStage_PREPARE &&
!(buffer_allocations && stream)) {
return absl::InternalError("buffer allocations and stream are required");
}
TF_ASSIGN_OR_RETURN(auto call_frame, BuildCallFrame(buffer_allocations));
CallOptions options =
BuildCallOptions(run_id, stream, buffer_allocations, execution_context);
return Call(handler, *call_frame, options, stage);
}
absl::Status CustomCallThunk::ExecuteFfiHandler(
RunId run_id, xla::ffi::Ffi& handler, xla::ffi::ExecutionStage stage,
se::Stream* stream, const ffi::ExecutionContext* execution_context,
const BufferAllocations* buffer_allocations) {
if (stage != xla::ffi::ExecutionStage::kPrepare &&
!(buffer_allocations && stream)) {
return absl::InternalError("buffer allocations and stream are required");
}
TF_ASSIGN_OR_RETURN(auto call_frame, BuildCallFrame(buffer_allocations));
CallOptions options =
BuildCallOptions(run_id, stream, buffer_allocations, execution_context);
return Call(handler, *call_frame, options, stage);
}
absl::Status CustomCallThunk::Prepare(
const PrepareParams& params, ResourceRequestsInterface& resource_requests) {
if (!bundle_ || !bundle_->prepare) {
return absl::OkStatus();
if (bundle_.has_value()) {
const RunId run_id =
params.collective_params ? params.collective_params->run_id : RunId{-1};
if (const auto* c_bundle =
std::get_if<XLA_FFI_Handler_Bundle>(&bundle_.value());
c_bundle && c_bundle->prepare) {
return ExecuteFfiHandler(run_id, c_bundle->prepare,
XLA_FFI_ExecutionStage_PREPARE,
/*stream=*/nullptr,
/*execution_context=*/nullptr,
/*buffer_allocations=*/nullptr);
}
if (const auto* owned_bundle =
std::get_if<OwnedHandlerBundle>(&bundle_.value());
owned_bundle && owned_bundle->prepare) {
return ExecuteFfiHandler(run_id, *owned_bundle->prepare,
xla::ffi::ExecutionStage::kPrepare,
/*stream=*/nullptr,
/*execution_context=*/nullptr,
/*buffer_allocations=*/nullptr);
}
}
return ExecuteFfiHandler(
params.collective_params ? params.collective_params->run_id : RunId{-1},
bundle_->prepare, XLA_FFI_ExecutionStage_PREPARE,
/*stream=*/nullptr,
/*execution_context=*/nullptr,
/*buffer_allocations=*/nullptr);
return absl::OkStatus();
}
absl::Status CustomCallThunk::Initialize(const InitializeParams& params) {
if (!bundle_ || !bundle_->initialize) {
return absl::OkStatus();
}
if (bundle_.has_value()) {
const RunId run_id =
params.collective_params ? params.collective_params->run_id : RunId{-1};
return ExecuteFfiHandler(
params.collective_params ? params.collective_params->run_id : RunId{-1},
bundle_->initialize, XLA_FFI_ExecutionStage_INITIALIZE, params.stream,
params.ffi_execution_context, params.buffer_allocations);
if (const auto* c_bundle =
std::get_if<XLA_FFI_Handler_Bundle>(&bundle_.value());
c_bundle && c_bundle->initialize) {
return ExecuteFfiHandler(run_id, *c_bundle->initialize,
XLA_FFI_ExecutionStage_INITIALIZE, params.stream,
params.ffi_execution_context,
params.buffer_allocations);
}
if (const auto* owned_bundle =
std::get_if<OwnedHandlerBundle>(&bundle_.value());
owned_bundle && owned_bundle->initialize) {
return ExecuteFfiHandler(run_id, *owned_bundle->initialize,
xla::ffi::ExecutionStage::kInitialize,
params.stream, params.ffi_execution_context,
params.buffer_allocations);
}
}
return absl::OkStatus();
}
absl::Status CustomCallThunk::ExecuteOnStream(const ExecuteParams& params) {
TF_ASSIGN_OR_RETURN(
se::Stream * stream,
GetStreamForExecution(Thunk::execution_stream_id(), params));
if (bundle_.has_value()) {
return ExecuteFfiHandler(
params.collective_params ? params.collective_params->run_id : RunId{-1},
bundle_->execute, XLA_FFI_ExecutionStage_EXECUTE, stream,
params.ffi_execution_context, params.buffer_allocations);
const RunId run_id =
params.collective_params ? params.collective_params->run_id : RunId{-1};
if (const auto* c_bundle =
std::get_if<XLA_FFI_Handler_Bundle>(&bundle_.value());
c_bundle) {
return ExecuteFfiHandler(
run_id, c_bundle->execute, XLA_FFI_ExecutionStage_EXECUTE, stream,
params.ffi_execution_context, params.buffer_allocations);
}
if (const auto* owned_bundle =
std::get_if<OwnedHandlerBundle>(&bundle_.value());
owned_bundle) {
if (!owned_bundle->execute) {
return absl::InternalError("FFI execute handler is not set");
}
return ExecuteFfiHandler(
run_id, *owned_bundle->execute, xla::ffi::ExecutionStage::kExecute,
stream, params.ffi_execution_context, params.buffer_allocations);
}
}
return ExecuteCustomCall(params);
}

View File

@ -21,8 +21,10 @@ limitations under the License.
#include <memory>
#include <optional>
#include <string>
#include <variant>
#include <vector>
#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
@ -33,6 +35,8 @@ limitations under the License.
#include "xla/ffi/call_frame.h"
#include "xla/ffi/execution_context.h"
#include "xla/ffi/execution_state.h"
#include "xla/ffi/ffi.h"
#include "xla/ffi/ffi_api.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/runtime/object_pool.h"
#include "xla/service/custom_call_status.h"
@ -56,6 +60,17 @@ namespace gpu {
// compiler is allowed to create.
class CustomCallThunk : public Thunk {
public:
// An owning equivalent of XLA_FFI_Handler_Bundle that allows using lambdas
// with captures.
//
// The members can be initialized with xla::ffi::Ffi::Bind().To(...).
struct OwnedHandlerBundle {
std::unique_ptr<xla::ffi::Ffi> initialize;
std::unique_ptr<xla::ffi::Ffi> instantiate;
std::unique_ptr<xla::ffi::Ffi> prepare;
std::unique_ptr<xla::ffi::Ffi> execute;
};
using CustomCallTarget =
std::function<void(stream_executor::Stream*, void**, const char*, size_t,
XlaCustomCallStatus*)>;
@ -98,6 +113,16 @@ class CustomCallThunk : public Thunk {
xla::ffi::AttributesMap attributes,
const HloComputation* called_computation);
// Creates a custom call thunk from a bundle of handlers created with
// xla::ffi::Bind(). Any pointer or reference lambda captures must be valid
// for the lifetime of the thunk.
static absl::StatusOr<std::unique_ptr<CustomCallThunk>> Create(
ThunkInfo thunk_info, std::string target_name, OwnedHandlerBundle bundle,
std::vector<std::optional<ShapedSlice>> operands,
std::vector<std::optional<ShapedSlice>> results,
xla::ffi::AttributesMap attributes,
const HloComputation* called_computation);
absl::Status Prepare(const PrepareParams& params,
ResourceRequestsInterface& resource_requests) override;
absl::Status Initialize(const InitializeParams& params) override;
@ -105,7 +130,14 @@ class CustomCallThunk : public Thunk {
const std::string& target_name() const { return target_name_; }
CustomCallTarget call_target() const { return call_target_; }
std::optional<XLA_FFI_Handler_Bundle> bundle() const { return bundle_; }
std::optional<XLA_FFI_Handler_Bundle> bundle() const {
if (!bundle_.has_value()) {
return std::nullopt;
}
const XLA_FFI_Handler_Bundle* c_bundle =
std::get_if<XLA_FFI_Handler_Bundle>(&bundle_.value());
return c_bundle ? std::make_optional(*c_bundle) : std::nullopt;
}
std::optional<ffi::CallFrame> call_frame() const {
return call_frame_ ? std::make_optional(call_frame_->Copy()) : std::nullopt;
}
@ -126,22 +158,37 @@ class CustomCallThunk : public Thunk {
std::string opaque, CustomCallTarget call_target,
const std::optional<CustomCallApiVersion>& api_version);
CustomCallThunk(ThunkInfo thunk_info, std::string target_name,
XLA_FFI_Handler_Bundle bundle,
std::vector<std::optional<ShapedSlice>> operands,
std::vector<std::optional<ShapedSlice>> results,
ffi::CallFrame call_frame, xla::ffi::AttributesMap attributes,
std::unique_ptr<ffi::ExecutionState> execution_state,
const HloComputation* called_computation);
CustomCallThunk(
ThunkInfo thunk_info, std::string target_name,
std::variant<XLA_FFI_Handler_Bundle, OwnedHandlerBundle> bundle,
std::vector<std::optional<ShapedSlice>> operands,
std::vector<std::optional<ShapedSlice>> results,
ffi::CallFrame call_frame, xla::ffi::AttributesMap attributes,
std::unique_ptr<ffi::ExecutionState> execution_state,
const HloComputation* called_computation);
absl::Status ExecuteCustomCall(const ExecuteParams& params);
absl::StatusOr<ObjectPool<xla::ffi::CallFrame>::BorrowedObject>
BuildCallFrame(const BufferAllocations* absl_nullable buffer_allocations);
xla::ffi::CallOptions BuildCallOptions(
RunId run_id, se::Stream* absl_nullable stream,
const BufferAllocations* absl_nullable buffer_allocations,
const ffi::ExecutionContext* absl_nonnull execution_context);
absl::Status ExecuteFfiHandler(RunId run_id, XLA_FFI_Handler* handler,
XLA_FFI_ExecutionStage stage,
se::Stream* stream,
const ffi::ExecutionContext* execution_context,
const BufferAllocations* buffer_allocations);
absl::Status ExecuteFfiHandler(RunId run_id, xla::ffi::Ffi& handler,
xla::ffi::ExecutionStage stage,
se::Stream* stream,
const ffi::ExecutionContext* execution_context,
const BufferAllocations* buffer_allocations);
// API version of the custom call. If not set, it means the custom call thunk
// was initialized from a non-registered function pointer and can't be
// serialized to a proto.
@ -159,7 +206,8 @@ class CustomCallThunk : public Thunk {
// XLA FFI provides a right type safe mechanism for registering external
// functions with XLA runtime. It's under construction, and still misses
// a lot of features. Long term it will replace legacy custom calls.
std::optional<XLA_FFI_Handler_Bundle> bundle_;
std::optional<std::variant<XLA_FFI_Handler_Bundle, OwnedHandlerBundle>>
bundle_;
std::optional<xla::ffi::AttributesMap> attributes_;
// Reference call frame pre-initialized at construction time.

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@ -34,6 +35,7 @@ limitations under the License.
#include "xla/service/custom_call_status.h"
#include "xla/service/custom_call_target_registry.h"
#include "xla/service/gpu/buffer_allocations.h"
#include "xla/service/gpu/resource_requests.h"
#include "xla/service/platform_util.h"
#include "xla/service/service_executable_run_options.h"
#include "xla/stream_executor/platform.h"
@ -44,6 +46,7 @@ limitations under the License.
namespace xla::gpu {
namespace {
using absl_testing::IsOk;
using absl_testing::StatusIs;
using ::testing::HasSubstr;
@ -198,5 +201,124 @@ TEST(CustomCallThunkTest, ResolvesLegacyCustomCall) {
HasSubstr("Legacy Custom call was executed!")));
}
TEST(CustomCallThunkTest, CustomCallWithOwnedHandlers) {
TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<se::Stream> stream,
executor->CreateStream());
int instantiate_calls = 0;
int prepare_calls = 0;
int initialize_calls = 0;
int execute_calls = 0;
CustomCallThunk::OwnedHandlerBundle bundle;
bundle.instantiate =
ffi::Ffi::Bind<ffi::ExecutionStage::kInstantiate>().To([&]() {
++instantiate_calls;
return absl::OkStatus();
});
bundle.prepare = ffi::Ffi::Bind<ffi::ExecutionStage::kPrepare>().To([&]() {
++prepare_calls;
return absl::OkStatus();
});
bundle.initialize =
ffi::Ffi::Bind<ffi::ExecutionStage::kInitialize>().To([&]() {
++initialize_calls;
return absl::OkStatus();
});
bundle.execute = ffi::Ffi::Bind<ffi::ExecutionStage::kExecute>().To([&]() {
++execute_calls;
return absl::OkStatus();
});
se::StreamExecutorMemoryAllocator allocator(executor);
Thunk::PrepareParams prepare_params = Thunk::PrepareParams{};
ResourceRequests resource_requests;
BufferAllocations buffer_allocations({}, 0, &allocator);
Thunk::InitializeParams initialize_params;
initialize_params.stream = stream.get();
initialize_params.buffer_allocations = &buffer_allocations;
Thunk::ExecuteParams execute_params = Thunk::ExecuteParams::Create(
ServiceExecutableRunOptions(), BufferAllocations({}, 0, &allocator),
stream.get(), stream.get(), nullptr, nullptr);
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<CustomCallThunk> thunk,
CustomCallThunk::Create(Thunk::ThunkInfo(), "target_name",
std::move(bundle),
/*operands=*/{},
/*results=*/{}, /*attributes=*/{},
/*called_computation=*/nullptr));
EXPECT_EQ(instantiate_calls, 1);
EXPECT_EQ(prepare_calls, 0);
EXPECT_EQ(initialize_calls, 0);
EXPECT_EQ(execute_calls, 0);
EXPECT_THAT(thunk->Prepare(prepare_params, resource_requests), IsOk());
EXPECT_EQ(instantiate_calls, 1);
EXPECT_EQ(prepare_calls, 1);
EXPECT_EQ(initialize_calls, 0);
EXPECT_EQ(execute_calls, 0);
EXPECT_THAT(thunk->Initialize(initialize_params), IsOk());
EXPECT_EQ(instantiate_calls, 1);
EXPECT_EQ(prepare_calls, 1);
EXPECT_EQ(initialize_calls, 1);
EXPECT_EQ(execute_calls, 0);
EXPECT_THAT(thunk->ExecuteOnStream(execute_params), IsOk());
EXPECT_EQ(initialize_calls, 1);
EXPECT_EQ(instantiate_calls, 1);
EXPECT_EQ(prepare_calls, 1);
EXPECT_EQ(execute_calls, 1);
}
TEST(CustomCallThunkTest, CustomCallWithOwnedHandlersWithoutOptionalOnes) {
TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<se::Stream> stream,
executor->CreateStream());
int execute_calls = 0;
CustomCallThunk::OwnedHandlerBundle bundle;
bundle.execute = ffi::Ffi::Bind().To([&]() {
++execute_calls;
return absl::OkStatus();
});
se::StreamExecutorMemoryAllocator allocator(executor);
Thunk::PrepareParams prepare_params = Thunk::PrepareParams{};
ResourceRequests resource_requests;
Thunk::InitializeParams initialize_params = Thunk::InitializeParams{};
Thunk::ExecuteParams execute_params = Thunk::ExecuteParams::Create(
ServiceExecutableRunOptions(), BufferAllocations({}, 0, &allocator),
stream.get(), stream.get(), nullptr, nullptr);
// Optional handlers are null and shouldn't be invoked.
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<CustomCallThunk> thunk,
CustomCallThunk::Create(Thunk::ThunkInfo(), "target_name",
std::move(bundle),
/*operands=*/{},
/*results=*/{}, /*attributes=*/{},
/*called_computation=*/nullptr));
EXPECT_THAT(thunk->Prepare(prepare_params, resource_requests), IsOk());
EXPECT_THAT(thunk->Initialize(initialize_params), IsOk());
EXPECT_THAT(thunk->ExecuteOnStream(execute_params), IsOk());
EXPECT_EQ(execute_calls, 1);
}
TEST(CustomCallThunkTest, CustomCallWithOwnedHandlersWithoutExecute) {
TF_ASSERT_OK_AND_ASSIGN(se::StreamExecutor * executor, GpuExecutor());
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<se::Stream> stream,
executor->CreateStream());
CustomCallThunk::OwnedHandlerBundle bundle; // all handlers null
se::StreamExecutorMemoryAllocator allocator(executor);
Thunk::ExecuteParams execute_params = Thunk::ExecuteParams::Create(
ServiceExecutableRunOptions(), BufferAllocations({}, 0, &allocator),
stream.get(), stream.get(), nullptr, nullptr);
EXPECT_THAT(CustomCallThunk::Create(Thunk::ThunkInfo(), "target_name",
std::move(bundle),
/*operands=*/{},
/*results=*/{}, /*attributes=*/{},
/*called_computation=*/nullptr),
StatusIs(absl::StatusCode::kInvalidArgument));
}
} // namespace
} // namespace xla::gpu