Move Attribute types from call_frame.cc into attribute_map.cc

This is moving `Scalar`, `Array`, `Dictionary`, `FlatAttribute`, `FlatAttributeMap`, and `AttributeMap` from `CallFrameBuilder` into the `xla::ffi` namespace.

It also moves the code into `attribute_map.{cc|h}`.

All these types are basically aliases from some kind of `std::variant` type. This change is a preparation for making them proper types and add `ToProto` and `FromProto` methods.

PiperOrigin-RevId: 824435281
This commit is contained in:
Henning Becker 2025-10-27 03:13:34 -07:00 committed by TensorFlower Gardener
parent 78a0ca0b60
commit 76a084f181
17 changed files with 154 additions and 129 deletions

View File

@ -60,7 +60,7 @@ limitations under the License.
namespace xla::cpu {
using AttributesMap = ffi::CallFrameBuilder::AttributesMap;
using AttributesMap = ffi::AttributesMap;
static absl::StatusOr<AttributesMap> ParseAttributes(
absl::string_view backend_config) {

View File

@ -878,7 +878,7 @@ absl::StatusOr<FusionEmissionResult> EmitCustomCall(
// For XLA FFI handlers we decode opaque backend config into attributes map
// at IR emission time, so that we do not need to parse MLIR at run time.
// For FFI handlers backend config must be a compatible MLIR dictionary.
CustomCallThunk::AttributesMap attributes;
ffi::AttributesMap attributes;
// For information about this calling convention, see
// xla/g3doc/custom_call.md.

View File

@ -706,6 +706,7 @@ cc_library(
"//xla:executable_run_options",
"//xla:shape_util",
"//xla:util",
"//xla/ffi:attribute_map",
"//xla/ffi:call_frame",
"//xla/ffi:execution_context",
"//xla/ffi:execution_state",

View File

@ -989,7 +989,7 @@ class CuDnnCmd : public TracedCommandBufferCmd {
class CustomCallCmd : public CommandBufferCmd {
public:
using CustomCallTarget = CustomCallThunk::CustomCallTarget;
using AttributesMap = CustomCallThunk::AttributesMap;
using AttributesMap = ffi::AttributesMap;
// This is a legacy custom call API that is discouraged, and will be
// deprecated once XLA:FFI mechanism is ready.

View File

@ -38,6 +38,7 @@ limitations under the License.
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/executable_run_options.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/attribute_map.h"
#include "xla/ffi/call_frame.h"
#include "xla/ffi/execution_state.h"
#include "xla/ffi/ffi_api.h"
@ -69,7 +70,7 @@ using xla::ffi::CallOptions;
static absl::StatusOr<ffi::CallFrame> BuildCallFramePrototype(
absl::Span<const std::optional<ShapedSlice>> operands,
absl::Span<const std::optional<ShapedSlice>> results,
CustomCallThunk::AttributesMap attributes) {
ffi::AttributesMap attributes) {
CallFrameBuilder builder(
/*num_args=*/operands.size(),
/*num_rets=*/results.size());
@ -204,8 +205,9 @@ absl::StatusOr<std::unique_ptr<CustomCallThunk>> CustomCallThunk::Create(
absl::StatusOr<std::unique_ptr<CustomCallThunk>> CustomCallThunk::Create(
ThunkInfo thunk_info, std::string target_name,
std::vector<std::optional<ShapedSlice>> operands,
std::vector<std::optional<ShapedSlice>> results, AttributesMap attributes,
const HloComputation* called_computation, absl::string_view platform_name) {
std::vector<std::optional<ShapedSlice>> results,
ffi::AttributesMap attributes, const HloComputation* called_computation,
absl::string_view platform_name) {
TF_ASSIGN_OR_RETURN(ffi::HandlerRegistration registration,
ffi::FindHandler(target_name, platform_name));
@ -218,8 +220,8 @@ absl::StatusOr<std::unique_ptr<CustomCallThunk>> CustomCallThunk::Create(
ThunkInfo thunk_info, std::string target_name,
XLA_FFI_Handler_Bundle bundle,
std::vector<std::optional<ShapedSlice>> operands,
std::vector<std::optional<ShapedSlice>> results, AttributesMap attributes,
const HloComputation* called_computation) {
std::vector<std::optional<ShapedSlice>> results,
ffi::AttributesMap attributes, const HloComputation* called_computation) {
auto execution_state = std::make_unique<ffi::ExecutionState>();
// Initialize FFI handler state if it has an instantiate callback.
@ -267,7 +269,7 @@ CustomCallThunk::CustomCallThunk(
XLA_FFI_Handler_Bundle bundle,
std::vector<std::optional<ShapedSlice>> operands,
std::vector<std::optional<ShapedSlice>> results, CallFrame call_frame,
AttributesMap attributes,
ffi::AttributesMap attributes,
std::unique_ptr<ffi::ExecutionState> execution_state,
const HloComputation* called_computation)
: Thunk(Thunk::kCustomCall, thunk_info),

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "xla/backends/gpu/runtime/thunk.h"
#include "xla/executable_run_options.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/attribute_map.h"
#include "xla/ffi/call_frame.h"
#include "xla/ffi/execution_context.h"
#include "xla/ffi/execution_state.h"
@ -59,9 +60,6 @@ class CustomCallThunk : public Thunk {
std::function<void(stream_executor::Stream*, void**, const char*, size_t,
XlaCustomCallStatus*)>;
using Attribute = ffi::CallFrameBuilder::Attribute;
using AttributesMap = ffi::CallFrameBuilder::AttributesMap;
// Creates a serializable custom call thunk. The callback is resolved using
// the legacy CustomCall registry. For new code please use XLA FFI instead.
static absl::StatusOr<std::unique_ptr<CustomCallThunk>> Create(
@ -84,7 +82,8 @@ class CustomCallThunk : public Thunk {
static absl::StatusOr<std::unique_ptr<CustomCallThunk>> Create(
ThunkInfo thunk_info, std::string target_name,
std::vector<std::optional<ShapedSlice>> operands,
std::vector<std::optional<ShapedSlice>> results, AttributesMap attributes,
std::vector<std::optional<ShapedSlice>> results,
xla::ffi::AttributesMap attributes,
const HloComputation* called_computation,
absl::string_view platform_name);
@ -95,7 +94,8 @@ class CustomCallThunk : public Thunk {
ThunkInfo thunk_info, std::string target_name,
XLA_FFI_Handler_Bundle bundle,
std::vector<std::optional<ShapedSlice>> operands,
std::vector<std::optional<ShapedSlice>> results, AttributesMap attributes,
std::vector<std::optional<ShapedSlice>> results,
xla::ffi::AttributesMap attributes,
const HloComputation* called_computation);
absl::Status Prepare(const PrepareParams& params,
@ -130,7 +130,7 @@ class CustomCallThunk : public Thunk {
XLA_FFI_Handler_Bundle bundle,
std::vector<std::optional<ShapedSlice>> operands,
std::vector<std::optional<ShapedSlice>> results,
ffi::CallFrame call_frame, AttributesMap attributes,
ffi::CallFrame call_frame, xla::ffi::AttributesMap attributes,
std::unique_ptr<ffi::ExecutionState> execution_state,
const HloComputation* called_computation);
@ -160,7 +160,7 @@ class CustomCallThunk : public Thunk {
// functions with XLA runtime. It's under construction, and still misses
// a lot of features. Long term it will replace legacy custom calls.
std::optional<XLA_FFI_Handler_Bundle> bundle_;
std::optional<AttributesMap> attributes_;
std::optional<xla::ffi::AttributesMap> attributes_;
// Reference call frame pre-initialized at construction time.
std::optional<ffi::CallFrame> call_frame_;

View File

@ -570,7 +570,7 @@ TEST_F(DynamicSliceThunkTest, SlicedMemcpy) {
seq.emplace_back(),
CustomCallThunk::Create(Thunk::ThunkInfo(), "__xla_test$$memcpy",
registration.bundle, operands, results,
/*attributes=*/CustomCallThunk::AttributesMap(),
/*attributes=*/ffi::AttributesMap(),
/*called_computation=*/nullptr));
// Wrapping dynamic slice thunk around the custom call thunk.
@ -731,7 +731,7 @@ TEST_F(DynamicSliceThunkTest, SlicedOutputMemcpy) {
seq.emplace_back(),
CustomCallThunk::Create(Thunk::ThunkInfo(), "__xla_test$$memcpy",
registration.bundle, operands, results,
/*attributes=*/CustomCallThunk::AttributesMap(),
/*attributes=*/ffi::AttributesMap(),
/*called_computation=*/nullptr));
// Wrapping dynamic slice thunk around the custom call thunk.
@ -1451,7 +1451,7 @@ TEST_F(DynamicSliceThunkTest, SlicedMemcpyOOB) {
seq.emplace_back(),
CustomCallThunk::Create(Thunk::ThunkInfo(), "__xla_test$$memcpy",
registration.bundle, operands, results,
/*attributes=*/CustomCallThunk::AttributesMap(),
/*attributes=*/ffi::AttributesMap(),
/*called_computation=*/nullptr));
// Wrapping dynamic slice thunk around the custom call thunk.

View File

@ -22,6 +22,7 @@ cc_library(
hdrs = ["call_frame.h"],
deps = [
":api",
":attribute_map",
"//xla:types",
"//xla:util",
"//xla:xla_data_proto_cc",
@ -44,6 +45,7 @@ xla_cc_test(
name = "call_frame_test",
srcs = ["call_frame_test.cc"],
deps = [
":attribute_map",
":call_frame",
"//xla:xla_data_proto_cc",
"//xla/ffi/api:c_api",
@ -196,7 +198,9 @@ cc_library(
srcs = ["attribute_map.cc"],
hdrs = ["attribute_map.h"],
deps = [
":call_frame",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
@ -204,8 +208,6 @@ cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:statusor",
],
)
@ -216,6 +218,7 @@ xla_cc_test(
features = ["-use_header_modules"],
shuffle_tests = False,
deps = [
":attribute_map",
":call_frame",
":execution_context",
":execution_state",

View File

@ -85,6 +85,7 @@ xla_cc_test(
"//xla:executable_run_options",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
"//xla/ffi:attribute_map",
"//xla/ffi:call_frame",
"//xla/ffi:execution_context",
"//xla/ffi:execution_state",

View File

@ -38,6 +38,7 @@ limitations under the License.
#include "absl/synchronization/blocking_counter.h"
#include "xla/executable_run_options.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/attribute_map.h"
#include "xla/ffi/call_frame.h"
#include "xla/ffi/execution_context.h"
#include "xla/ffi/execution_state.h"
@ -117,6 +118,7 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(
namespace xla::ffi {
using ::absl_testing::StatusIs;
using ::testing::HasSubstr;
TEST(FfiTest, DataTypeEnumValue) {
@ -577,9 +579,8 @@ TEST(FfiTest, MissingBufferArgument) {
[](auto) { return Error::Success(); });
auto status = Call(*handler, call_frame);
EXPECT_THAT(status,
absl_testing::StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("Wrong number of arguments")));
EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("Wrong number of arguments")));
}
TEST(FfiTest, WrongRankBufferArgument) {
@ -595,9 +596,8 @@ TEST(FfiTest, WrongRankBufferArgument) {
auto status = Call(*handler, call_frame);
EXPECT_THAT(status,
absl_testing::StatusIs(
absl::StatusCode::kInvalidArgument,
HasSubstr("Wrong buffer rank: expected 1 but got 2")));
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("Wrong buffer rank: expected 1 but got 2")));
}
TEST(FfiTest, WrongTypeBufferArgument) {
@ -612,10 +612,10 @@ TEST(FfiTest, WrongTypeBufferArgument) {
[](auto) { return Error::Success(); });
auto status = Call(*handler, call_frame);
EXPECT_THAT(status,
absl_testing::StatusIs(
absl::StatusCode::kInvalidArgument,
HasSubstr("Wrong buffer dtype: expected F32 but got S32")));
EXPECT_THAT(
status,
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("Wrong buffer dtype: expected F32 but got S32")));
}
TEST(FfiTest, WrongNumberOfArguments) {
@ -631,9 +631,8 @@ TEST(FfiTest, WrongNumberOfArguments) {
Ffi::Bind().Attr<int>("foo").To([](int foo) { return Error::Success(); });
auto status = Call(*handler, call_frame);
EXPECT_THAT(status,
absl_testing::StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("Wrong number of attributes")));
EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("Wrong number of attributes")));
EXPECT_THAT(status.message(), HasSubstr("foo"));
EXPECT_THAT(status.message(), HasSubstr("bar"));
}
@ -1072,10 +1071,10 @@ TEST(FfiTest, AttrsAsDictionary) {
}
TEST(FfiTest, DictionaryAttr) {
CallFrameBuilder::AttributesMap dict0;
AttributesMap dict0;
dict0.try_emplace("i32", 42);
CallFrameBuilder::AttributesMap dict1;
AttributesMap dict1;
dict1.try_emplace("f32", 42.0f);
CallFrameBuilder::AttributesBuilder attrs;
@ -1119,7 +1118,7 @@ TEST(FfiTest, DictionaryAttr) {
}
TEST(FfiTest, StructAttr) {
CallFrameBuilder::AttributesMap dict;
AttributesMap dict;
dict.try_emplace("i32", 42);
dict.try_emplace("f32", 42.0f);
@ -1232,7 +1231,7 @@ TEST(FfiTest, EnumAttr) {
}
TEST(FfiTest, WrongEnumAttrType) {
CallFrameBuilder::AttributesMap dict;
AttributesMap dict;
dict.try_emplace("i32", 42);
CallFrameBuilder::AttributesBuilder attrs;

View File

@ -24,28 +24,26 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Support/LLVM.h"
#include "xla/ffi/call_frame.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
namespace xla::ffi {
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertBoolAttr(
absl::string_view name, mlir::BoolAttr boolean) {
static absl::StatusOr<Attribute> ConvertBoolAttr(absl::string_view name,
mlir::BoolAttr boolean) {
return static_cast<bool>(boolean.getValue());
}
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertStringAttr(
absl::string_view name, mlir::StringAttr str) {
static absl::StatusOr<Attribute> ConvertStringAttr(absl::string_view name,
mlir::StringAttr str) {
return str.getValue().str();
}
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertIntegerAttr(
absl::string_view name, mlir::IntegerAttr integer) {
static absl::StatusOr<Attribute> ConvertIntegerAttr(absl::string_view name,
mlir::IntegerAttr integer) {
if (integer.getType().isUnsignedInteger()) {
switch (integer.getType().getIntOrFloatBitWidth()) {
case 8:
@ -77,8 +75,8 @@ static absl::StatusOr<CallFrameBuilder::Attribute> ConvertIntegerAttr(
}
}
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertFloatAttr(
absl::string_view name, mlir::FloatAttr fp) {
static absl::StatusOr<Attribute> ConvertFloatAttr(absl::string_view name,
mlir::FloatAttr fp) {
switch (fp.getType().getIntOrFloatBitWidth()) {
case 32:
return static_cast<float>(fp.getValue().convertToFloat());
@ -90,24 +88,28 @@ static absl::StatusOr<CallFrameBuilder::Attribute> ConvertFloatAttr(
}
}
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertArrayAttr(
absl::string_view name, mlir::DenseArrayAttr arr) {
static absl::StatusOr<Attribute> ConvertArrayAttr(absl::string_view name,
mlir::DenseArrayAttr arr) {
if (auto dense = mlir::dyn_cast<mlir::DenseI8ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
} else if (auto dense = mlir::dyn_cast<mlir::DenseI16ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
} else if (auto dense = mlir::dyn_cast<mlir::DenseI32ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
} else if (auto dense = mlir::dyn_cast<mlir::DenseI64ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
} else if (auto dense = mlir::dyn_cast<mlir::DenseF32ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
} else if (auto dense = mlir::dyn_cast<mlir::DenseF64ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
} else {
return absl::InvalidArgumentError(
absl::StrCat("Unsupported array element type for attribute: ", name));
}
if (auto dense = mlir::dyn_cast<mlir::DenseI16ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
}
if (auto dense = mlir::dyn_cast<mlir::DenseI32ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
}
if (auto dense = mlir::dyn_cast<mlir::DenseI64ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
}
if (auto dense = mlir::dyn_cast<mlir::DenseF32ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
}
if (auto dense = mlir::dyn_cast<mlir::DenseF64ArrayAttr>(arr)) {
return dense.asArrayRef().vec();
}
return absl::InvalidArgumentError(
absl::StrCat("Unsupported array element type for attribute: ", name));
}
template <typename T>
@ -117,7 +119,7 @@ static std::vector<T> CopyDenseElementsToVec(
return std::vector<T>(it.begin(), it.end());
}
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertDenseElementsAttr(
static absl::StatusOr<Attribute> ConvertDenseElementsAttr(
absl::string_view name, mlir::DenseIntOrFPElementsAttr arr) {
auto type = arr.getElementType();
if (type.isInteger()) {
@ -156,16 +158,15 @@ static absl::StatusOr<CallFrameBuilder::Attribute> ConvertDenseElementsAttr(
absl::StrCat("Unsupported array element type for attribute: ", name));
}
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertDictionaryAttr(
static absl::StatusOr<Attribute> ConvertDictionaryAttr(
absl::string_view name, mlir::DictionaryAttr dict) {
TF_ASSIGN_OR_RETURN(auto attrs, BuildAttributesMap(dict));
return CallFrameBuilder::Dictionary{
std::make_shared<CallFrameBuilder::AttributesMap>(std::move(attrs))};
return AttributesDictionary{
std::make_shared<AttributesMap>(std::move(attrs))};
}
absl::StatusOr<CallFrameBuilder::AttributesMap> BuildAttributesMap(
mlir::DictionaryAttr dict) {
CallFrameBuilder::AttributesMap attributes;
absl::StatusOr<AttributesMap> BuildAttributesMap(mlir::DictionaryAttr dict) {
AttributesMap attributes;
for (auto& kv : dict) {
absl::string_view name = kv.getName().strref();
mlir::Attribute value = kv.getValue();

View File

@ -16,16 +16,65 @@ limitations under the License.
#ifndef XLA_FFI_ATTRIBUTE_MAP_H_
#define XLA_FFI_ATTRIBUTE_MAP_H_
#include <cstdint>
#include <memory>
#include <string>
#include <variant>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "xla/ffi/call_frame.h"
namespace xla::ffi {
namespace internal {
// A little bit of template metaprogramming to append type to std::variant.
template <typename V, class T>
struct AppendType;
template <typename... Ts, class T>
struct AppendType<std::variant<Ts...>, T> {
using Type = std::variant<Ts..., T>;
};
} // namespace internal
// A single scalar value.
using Scalar = std::variant<bool, int8_t, int16_t, int32_t, int64_t, uint8_t,
uint16_t, uint32_t, uint64_t, float, double>;
// An array of elements of the same Scalar type.
using Array = std::variant<std::vector<int8_t>, std::vector<int16_t>,
std::vector<int32_t>, std::vector<int64_t>,
std::vector<uint8_t>, std::vector<uint16_t>,
std::vector<uint32_t>, std::vector<uint64_t>,
std::vector<float>, std::vector<double>>;
// Attributes that do not support nested dictionaries.
using FlatAttribute = std::variant<Scalar, Array, std::string>;
// A map that maps from an arbitrary name (string key) to a flat attribute.
using FlatAttributesMap = absl::flat_hash_map<std::string, FlatAttribute>;
// Forward declaration of the recursive type.
struct AttributesDictionary;
// Attributes that support arbitrary nesting.
using Attribute =
internal::AppendType<FlatAttribute, AttributesDictionary>::Type;
// AttributesMap is a map from an arbitrary name (string key) to an attribute.
using AttributesMap = absl::flat_hash_map<std::string, Attribute>;
// Dictionary is just a wrapper around `AttributesMap`. We need an indirection
// through `std::shared_ptr` to be able to define recursive `std::variant`. We
// use shared pointer to keep `AttributesMap` copyable.
struct AttributesDictionary {
std::shared_ptr<AttributesMap> attrs;
};
// Converts MLIR dictionary attribute attached to a custom call operation to a
// custom call handler attributes that are forwarded to the FFI handler.
absl::StatusOr<CallFrameBuilder::AttributesMap> BuildAttributesMap(
mlir::DictionaryAttr dict);
absl::StatusOr<AttributesMap> BuildAttributesMap(mlir::DictionaryAttr dict);
} // namespace xla::ffi

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "xla/ffi/api/api.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/c_api_internal.h" // IWYU pragma: keep
#include "xla/ffi/attribute_map.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/platform/errors.h"
#include "xla/util.h"
@ -51,7 +52,7 @@ struct CallFrameBuilder::Buffer {
absl::InlinedVector<int64_t, 4> dims;
};
CallFrameBuilder::AttributesMap CallFrameBuilder::AttributesBuilder::Build() {
AttributesMap CallFrameBuilder::AttributesBuilder::Build() {
return std::move(attrs_);
}
@ -65,8 +66,9 @@ void CallFrameBuilder::AttributesBuilder::Insert(std::string name,
void CallFrameBuilder::AttributesBuilder::Insert(std::string name,
AttributesMap attrs) {
attrs_.try_emplace(std::move(name),
Dictionary{std::make_shared<AttributesMap>(attrs)});
attrs_.try_emplace(
std::move(name),
AttributesDictionary{std::make_shared<AttributesMap>(attrs)});
}
void CallFrameBuilder::AttributesBuilder::Append(AttributesMap attrs) {
@ -160,13 +162,13 @@ struct CallFrame::Dictionary {
};
struct CallFrame::Array {
CallFrameBuilder::Array value; // XLA_FFI_Array::data
xla::ffi::Array value; // XLA_FFI_Array::data
XLA_FFI_Array array = {};
};
struct CallFrame::Scalar {
CallFrameBuilder::Scalar value; // XLA_FFI_Scalar::value
xla::ffi::Scalar value; // XLA_FFI_Scalar::value
XLA_FFI_Scalar scalar = {};
};
@ -413,11 +415,11 @@ std::unique_ptr<CallFrame::Results> CallFrame::FixUpRets(
// An std::visit overload set for converting CallFrameBuilder::Attribute to
// CallFrame::Attribute.
struct CallFrame::ConvertAttribute {
CallFrame::Attribute operator()(const CallFrameBuilder::Array& array) {
CallFrame::Attribute operator()(const xla::ffi::Array& array) {
return CallFrame::Array{array};
}
CallFrame::Attribute operator()(const CallFrameBuilder::Scalar& scalar) {
CallFrame::Attribute operator()(const xla::ffi::Scalar& scalar) {
return CallFrame::Scalar{scalar};
}
@ -425,8 +427,8 @@ struct CallFrame::ConvertAttribute {
return CallFrame::String{str};
}
CallFrame::Attribute operator()(const CallFrameBuilder::Dictionary& dict) {
return CallFrame::Dictionary{CreateAttrs(*dict.attrs)};
CallFrame::Attribute operator()(const xla::ffi::AttributesDictionary& dict) {
return Dictionary{CreateAttrs(*dict.attrs)};
}
};
@ -498,7 +500,7 @@ struct CallFrame::AttributeStorage {
};
std::unique_ptr<CallFrame::Attributes> CallFrame::CreateAttrs(
const CallFrameBuilder::AttributesMap& battrs) {
const xla::ffi::AttributesMap& battrs) {
auto attrs = std::make_unique<Attributes>();
// Convert call frame builder attributes to a collection of named attributes.

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/attribute_map.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/types.h" // IWYU pragma: keep
#include "xla/xla_data.pb.h"
@ -45,15 +46,6 @@ namespace xla::ffi {
class CallFrame; // forward declare
class CallFrameBuilder {
// A little bit of template metaprogramming to append type to std::variant.
template <typename V, class T>
struct AppendType;
template <typename... Ts, class T>
struct AppendType<std::variant<Ts...>, T> {
using Type = std::variant<Ts..., T>;
};
public:
CallFrameBuilder(size_t num_args, size_t num_rets);
~CallFrameBuilder();
@ -61,32 +53,6 @@ class CallFrameBuilder {
CallFrameBuilder(CallFrameBuilder&&);
CallFrameBuilder& operator=(CallFrameBuilder&&);
using Scalar = std::variant<bool, int8_t, int16_t, int32_t, int64_t, uint8_t,
uint16_t, uint32_t, uint64_t, float, double>;
using Array = std::variant<std::vector<int8_t>, std::vector<int16_t>,
std::vector<int32_t>, std::vector<int64_t>,
std::vector<uint8_t>, std::vector<uint16_t>,
std::vector<uint32_t>, std::vector<uint64_t>,
std::vector<float>, std::vector<double>>;
// Declare implementation detail structs for call frame builder storage.
struct Dictionary;
// Attributes that do not support nested dictionaries.
using FlatAttribute = std::variant<Scalar, Array, std::string>;
using FlatAttributesMap = absl::flat_hash_map<std::string, FlatAttribute>;
// Attributes that support arbitrary nesting.
using Attribute = typename AppendType<FlatAttribute, Dictionary>::Type;
using AttributesMap = absl::flat_hash_map<std::string, Attribute>;
// Dictionary is just a wrapper around AttributesMap. We need an indirection
// through `std::shared_ptr` to be able to define recursive `std::variant`. We
// use shared pointer to keep `AttributesMap` copyable.
struct Dictionary {
std::shared_ptr<AttributesMap> attrs;
};
// A helper class to build call frame attributes.
class AttributesBuilder {
public:
@ -224,8 +190,7 @@ class CallFrame {
//===----- Call frame attributes ----------------------------------------===//
// Creates call frame attributes from the call frame builder attributes.
static std::unique_ptr<Attributes> CreateAttrs(
const CallFrameBuilder::AttributesMap& attrs);
static std::unique_ptr<Attributes> CreateAttrs(const AttributesMap& attrs);
// Fixes up call frame attributes by initializing XLA FFI structs with valid
// pointers into storage objects.

View File

@ -24,6 +24,7 @@ limitations under the License.
#include <gtest/gtest.h>
#include "absl/strings/str_cat.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/attribute_map.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/tsl/platform/test.h"
@ -131,7 +132,7 @@ void BM_AddBufferArg(benchmark::State& state) {
void BM_AddAttributes(benchmark::State& state) {
size_t num_attrs = state.range(0);
CallFrameBuilder::AttributesMap attrs;
AttributesMap attrs;
for (size_t i = 0; i < num_attrs; ++i) {
attrs.try_emplace(absl::StrCat("attr_", i), 42);
}

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "xla/executable_run_options.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/attribute_map.h"
#include "xla/ffi/call_frame.h"
#include "xla/ffi/execution_context.h"
#include "xla/ffi/execution_state.h"
@ -412,10 +413,10 @@ TEST(FfiTest, AttrsAsDictionary) {
}
TEST(FfiTest, DictionaryAttr) {
CallFrameBuilder::AttributesMap dict0;
AttributesMap dict0;
dict0.try_emplace("i32", 42);
CallFrameBuilder::AttributesMap dict1;
AttributesMap dict1;
dict1.try_emplace("f32", 42.0f);
CallFrameBuilder::AttributesBuilder attrs;
@ -458,7 +459,7 @@ TEST(FfiTest, DictionaryAttr) {
}
TEST(FfiTest, StructAttr) {
CallFrameBuilder::AttributesMap dict;
AttributesMap dict;
dict.try_emplace("i32", 42);
dict.try_emplace("f32", 42.0f);

View File

@ -1199,7 +1199,7 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
// attributes map at IR emission time, so that we do not need to
// parse MLIR at run time. For FFI handlers backend config must be
// a compatible MLIR dictionary.
CustomCallThunk::AttributesMap attributes;
ffi::AttributesMap attributes;
auto backend_config = instr->backend_config<GpuBackendConfig>();
if (!backend_config.ok()) {