mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Move Attribute types from call_frame.cc into attribute_map.cc
This is moving `Scalar`, `Array`, `Dictionary`, `FlatAttribute`, `FlatAttributeMap`, and `AttributeMap` from `CallFrameBuilder` into the `xla::ffi` namespace.
It also moves the code into `attribute_map.{cc|h}`.
All these types are basically aliases from some kind of `std::variant` type. This change is a preparation for making them proper types and add `ToProto` and `FromProto` methods.
PiperOrigin-RevId: 824435281
This commit is contained in:
parent
78a0ca0b60
commit
76a084f181
|
|
@ -60,7 +60,7 @@ limitations under the License.
|
|||
|
||||
namespace xla::cpu {
|
||||
|
||||
using AttributesMap = ffi::CallFrameBuilder::AttributesMap;
|
||||
using AttributesMap = ffi::AttributesMap;
|
||||
|
||||
static absl::StatusOr<AttributesMap> ParseAttributes(
|
||||
absl::string_view backend_config) {
|
||||
|
|
|
|||
|
|
@ -878,7 +878,7 @@ absl::StatusOr<FusionEmissionResult> EmitCustomCall(
|
|||
// For XLA FFI handlers we decode opaque backend config into attributes map
|
||||
// at IR emission time, so that we do not need to parse MLIR at run time.
|
||||
// For FFI handlers backend config must be a compatible MLIR dictionary.
|
||||
CustomCallThunk::AttributesMap attributes;
|
||||
ffi::AttributesMap attributes;
|
||||
|
||||
// For information about this calling convention, see
|
||||
// xla/g3doc/custom_call.md.
|
||||
|
|
|
|||
|
|
@ -706,6 +706,7 @@ cc_library(
|
|||
"//xla:executable_run_options",
|
||||
"//xla:shape_util",
|
||||
"//xla:util",
|
||||
"//xla/ffi:attribute_map",
|
||||
"//xla/ffi:call_frame",
|
||||
"//xla/ffi:execution_context",
|
||||
"//xla/ffi:execution_state",
|
||||
|
|
|
|||
|
|
@ -989,7 +989,7 @@ class CuDnnCmd : public TracedCommandBufferCmd {
|
|||
class CustomCallCmd : public CommandBufferCmd {
|
||||
public:
|
||||
using CustomCallTarget = CustomCallThunk::CustomCallTarget;
|
||||
using AttributesMap = CustomCallThunk::AttributesMap;
|
||||
using AttributesMap = ffi::AttributesMap;
|
||||
|
||||
// This is a legacy custom call API that is discouraged, and will be
|
||||
// deprecated once XLA:FFI mechanism is ready.
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ limitations under the License.
|
|||
#include "xla/backends/gpu/runtime/thunk.h"
|
||||
#include "xla/executable_run_options.h"
|
||||
#include "xla/ffi/api/c_api.h"
|
||||
#include "xla/ffi/attribute_map.h"
|
||||
#include "xla/ffi/call_frame.h"
|
||||
#include "xla/ffi/execution_state.h"
|
||||
#include "xla/ffi/ffi_api.h"
|
||||
|
|
@ -69,7 +70,7 @@ using xla::ffi::CallOptions;
|
|||
static absl::StatusOr<ffi::CallFrame> BuildCallFramePrototype(
|
||||
absl::Span<const std::optional<ShapedSlice>> operands,
|
||||
absl::Span<const std::optional<ShapedSlice>> results,
|
||||
CustomCallThunk::AttributesMap attributes) {
|
||||
ffi::AttributesMap attributes) {
|
||||
CallFrameBuilder builder(
|
||||
/*num_args=*/operands.size(),
|
||||
/*num_rets=*/results.size());
|
||||
|
|
@ -204,8 +205,9 @@ absl::StatusOr<std::unique_ptr<CustomCallThunk>> CustomCallThunk::Create(
|
|||
absl::StatusOr<std::unique_ptr<CustomCallThunk>> CustomCallThunk::Create(
|
||||
ThunkInfo thunk_info, std::string target_name,
|
||||
std::vector<std::optional<ShapedSlice>> operands,
|
||||
std::vector<std::optional<ShapedSlice>> results, AttributesMap attributes,
|
||||
const HloComputation* called_computation, absl::string_view platform_name) {
|
||||
std::vector<std::optional<ShapedSlice>> results,
|
||||
ffi::AttributesMap attributes, const HloComputation* called_computation,
|
||||
absl::string_view platform_name) {
|
||||
TF_ASSIGN_OR_RETURN(ffi::HandlerRegistration registration,
|
||||
ffi::FindHandler(target_name, platform_name));
|
||||
|
||||
|
|
@ -218,8 +220,8 @@ absl::StatusOr<std::unique_ptr<CustomCallThunk>> CustomCallThunk::Create(
|
|||
ThunkInfo thunk_info, std::string target_name,
|
||||
XLA_FFI_Handler_Bundle bundle,
|
||||
std::vector<std::optional<ShapedSlice>> operands,
|
||||
std::vector<std::optional<ShapedSlice>> results, AttributesMap attributes,
|
||||
const HloComputation* called_computation) {
|
||||
std::vector<std::optional<ShapedSlice>> results,
|
||||
ffi::AttributesMap attributes, const HloComputation* called_computation) {
|
||||
auto execution_state = std::make_unique<ffi::ExecutionState>();
|
||||
|
||||
// Initialize FFI handler state if it has an instantiate callback.
|
||||
|
|
@ -267,7 +269,7 @@ CustomCallThunk::CustomCallThunk(
|
|||
XLA_FFI_Handler_Bundle bundle,
|
||||
std::vector<std::optional<ShapedSlice>> operands,
|
||||
std::vector<std::optional<ShapedSlice>> results, CallFrame call_frame,
|
||||
AttributesMap attributes,
|
||||
ffi::AttributesMap attributes,
|
||||
std::unique_ptr<ffi::ExecutionState> execution_state,
|
||||
const HloComputation* called_computation)
|
||||
: Thunk(Thunk::kCustomCall, thunk_info),
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ limitations under the License.
|
|||
#include "xla/backends/gpu/runtime/thunk.h"
|
||||
#include "xla/executable_run_options.h"
|
||||
#include "xla/ffi/api/c_api.h"
|
||||
#include "xla/ffi/attribute_map.h"
|
||||
#include "xla/ffi/call_frame.h"
|
||||
#include "xla/ffi/execution_context.h"
|
||||
#include "xla/ffi/execution_state.h"
|
||||
|
|
@ -59,9 +60,6 @@ class CustomCallThunk : public Thunk {
|
|||
std::function<void(stream_executor::Stream*, void**, const char*, size_t,
|
||||
XlaCustomCallStatus*)>;
|
||||
|
||||
using Attribute = ffi::CallFrameBuilder::Attribute;
|
||||
using AttributesMap = ffi::CallFrameBuilder::AttributesMap;
|
||||
|
||||
// Creates a serializable custom call thunk. The callback is resolved using
|
||||
// the legacy CustomCall registry. For new code please use XLA FFI instead.
|
||||
static absl::StatusOr<std::unique_ptr<CustomCallThunk>> Create(
|
||||
|
|
@ -84,7 +82,8 @@ class CustomCallThunk : public Thunk {
|
|||
static absl::StatusOr<std::unique_ptr<CustomCallThunk>> Create(
|
||||
ThunkInfo thunk_info, std::string target_name,
|
||||
std::vector<std::optional<ShapedSlice>> operands,
|
||||
std::vector<std::optional<ShapedSlice>> results, AttributesMap attributes,
|
||||
std::vector<std::optional<ShapedSlice>> results,
|
||||
xla::ffi::AttributesMap attributes,
|
||||
const HloComputation* called_computation,
|
||||
absl::string_view platform_name);
|
||||
|
||||
|
|
@ -95,7 +94,8 @@ class CustomCallThunk : public Thunk {
|
|||
ThunkInfo thunk_info, std::string target_name,
|
||||
XLA_FFI_Handler_Bundle bundle,
|
||||
std::vector<std::optional<ShapedSlice>> operands,
|
||||
std::vector<std::optional<ShapedSlice>> results, AttributesMap attributes,
|
||||
std::vector<std::optional<ShapedSlice>> results,
|
||||
xla::ffi::AttributesMap attributes,
|
||||
const HloComputation* called_computation);
|
||||
|
||||
absl::Status Prepare(const PrepareParams& params,
|
||||
|
|
@ -130,7 +130,7 @@ class CustomCallThunk : public Thunk {
|
|||
XLA_FFI_Handler_Bundle bundle,
|
||||
std::vector<std::optional<ShapedSlice>> operands,
|
||||
std::vector<std::optional<ShapedSlice>> results,
|
||||
ffi::CallFrame call_frame, AttributesMap attributes,
|
||||
ffi::CallFrame call_frame, xla::ffi::AttributesMap attributes,
|
||||
std::unique_ptr<ffi::ExecutionState> execution_state,
|
||||
const HloComputation* called_computation);
|
||||
|
||||
|
|
@ -160,7 +160,7 @@ class CustomCallThunk : public Thunk {
|
|||
// functions with XLA runtime. It's under construction, and still misses
|
||||
// a lot of features. Long term it will replace legacy custom calls.
|
||||
std::optional<XLA_FFI_Handler_Bundle> bundle_;
|
||||
std::optional<AttributesMap> attributes_;
|
||||
std::optional<xla::ffi::AttributesMap> attributes_;
|
||||
|
||||
// Reference call frame pre-initialized at construction time.
|
||||
std::optional<ffi::CallFrame> call_frame_;
|
||||
|
|
|
|||
|
|
@ -570,7 +570,7 @@ TEST_F(DynamicSliceThunkTest, SlicedMemcpy) {
|
|||
seq.emplace_back(),
|
||||
CustomCallThunk::Create(Thunk::ThunkInfo(), "__xla_test$$memcpy",
|
||||
registration.bundle, operands, results,
|
||||
/*attributes=*/CustomCallThunk::AttributesMap(),
|
||||
/*attributes=*/ffi::AttributesMap(),
|
||||
/*called_computation=*/nullptr));
|
||||
|
||||
// Wrapping dynamic slice thunk around the custom call thunk.
|
||||
|
|
@ -731,7 +731,7 @@ TEST_F(DynamicSliceThunkTest, SlicedOutputMemcpy) {
|
|||
seq.emplace_back(),
|
||||
CustomCallThunk::Create(Thunk::ThunkInfo(), "__xla_test$$memcpy",
|
||||
registration.bundle, operands, results,
|
||||
/*attributes=*/CustomCallThunk::AttributesMap(),
|
||||
/*attributes=*/ffi::AttributesMap(),
|
||||
/*called_computation=*/nullptr));
|
||||
|
||||
// Wrapping dynamic slice thunk around the custom call thunk.
|
||||
|
|
@ -1451,7 +1451,7 @@ TEST_F(DynamicSliceThunkTest, SlicedMemcpyOOB) {
|
|||
seq.emplace_back(),
|
||||
CustomCallThunk::Create(Thunk::ThunkInfo(), "__xla_test$$memcpy",
|
||||
registration.bundle, operands, results,
|
||||
/*attributes=*/CustomCallThunk::AttributesMap(),
|
||||
/*attributes=*/ffi::AttributesMap(),
|
||||
/*called_computation=*/nullptr));
|
||||
|
||||
// Wrapping dynamic slice thunk around the custom call thunk.
|
||||
|
|
|
|||
9
third_party/xla/xla/ffi/BUILD
vendored
9
third_party/xla/xla/ffi/BUILD
vendored
|
|
@ -22,6 +22,7 @@ cc_library(
|
|||
hdrs = ["call_frame.h"],
|
||||
deps = [
|
||||
":api",
|
||||
":attribute_map",
|
||||
"//xla:types",
|
||||
"//xla:util",
|
||||
"//xla:xla_data_proto_cc",
|
||||
|
|
@ -44,6 +45,7 @@ xla_cc_test(
|
|||
name = "call_frame_test",
|
||||
srcs = ["call_frame_test.cc"],
|
||||
deps = [
|
||||
":attribute_map",
|
||||
":call_frame",
|
||||
"//xla:xla_data_proto_cc",
|
||||
"//xla/ffi/api:c_api",
|
||||
|
|
@ -196,7 +198,9 @@ cc_library(
|
|||
srcs = ["attribute_map.cc"],
|
||||
hdrs = ["attribute_map.h"],
|
||||
deps = [
|
||||
":call_frame",
|
||||
"//xla/tsl/platform:errors",
|
||||
"//xla/tsl/platform:statusor",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
|
@ -204,8 +208,6 @@ cc_library(
|
|||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@local_tsl//tsl/platform:errors",
|
||||
"@local_tsl//tsl/platform:statusor",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -216,6 +218,7 @@ xla_cc_test(
|
|||
features = ["-use_header_modules"],
|
||||
shuffle_tests = False,
|
||||
deps = [
|
||||
":attribute_map",
|
||||
":call_frame",
|
||||
":execution_context",
|
||||
":execution_state",
|
||||
|
|
|
|||
1
third_party/xla/xla/ffi/api/BUILD
vendored
1
third_party/xla/xla/ffi/api/BUILD
vendored
|
|
@ -85,6 +85,7 @@ xla_cc_test(
|
|||
"//xla:executable_run_options",
|
||||
"//xla:shape_util",
|
||||
"//xla:xla_data_proto_cc",
|
||||
"//xla/ffi:attribute_map",
|
||||
"//xla/ffi:call_frame",
|
||||
"//xla/ffi:execution_context",
|
||||
"//xla/ffi:execution_state",
|
||||
|
|
|
|||
33
third_party/xla/xla/ffi/api/ffi_test.cc
vendored
33
third_party/xla/xla/ffi/api/ffi_test.cc
vendored
|
|
@ -38,6 +38,7 @@ limitations under the License.
|
|||
#include "absl/synchronization/blocking_counter.h"
|
||||
#include "xla/executable_run_options.h"
|
||||
#include "xla/ffi/api/c_api.h"
|
||||
#include "xla/ffi/attribute_map.h"
|
||||
#include "xla/ffi/call_frame.h"
|
||||
#include "xla/ffi/execution_context.h"
|
||||
#include "xla/ffi/execution_state.h"
|
||||
|
|
@ -117,6 +118,7 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(
|
|||
|
||||
namespace xla::ffi {
|
||||
|
||||
using ::absl_testing::StatusIs;
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
TEST(FfiTest, DataTypeEnumValue) {
|
||||
|
|
@ -577,9 +579,8 @@ TEST(FfiTest, MissingBufferArgument) {
|
|||
[](auto) { return Error::Success(); });
|
||||
auto status = Call(*handler, call_frame);
|
||||
|
||||
EXPECT_THAT(status,
|
||||
absl_testing::StatusIs(absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("Wrong number of arguments")));
|
||||
EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("Wrong number of arguments")));
|
||||
}
|
||||
|
||||
TEST(FfiTest, WrongRankBufferArgument) {
|
||||
|
|
@ -595,9 +596,8 @@ TEST(FfiTest, WrongRankBufferArgument) {
|
|||
auto status = Call(*handler, call_frame);
|
||||
|
||||
EXPECT_THAT(status,
|
||||
absl_testing::StatusIs(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("Wrong buffer rank: expected 1 but got 2")));
|
||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("Wrong buffer rank: expected 1 but got 2")));
|
||||
}
|
||||
|
||||
TEST(FfiTest, WrongTypeBufferArgument) {
|
||||
|
|
@ -612,10 +612,10 @@ TEST(FfiTest, WrongTypeBufferArgument) {
|
|||
[](auto) { return Error::Success(); });
|
||||
auto status = Call(*handler, call_frame);
|
||||
|
||||
EXPECT_THAT(status,
|
||||
absl_testing::StatusIs(
|
||||
absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("Wrong buffer dtype: expected F32 but got S32")));
|
||||
EXPECT_THAT(
|
||||
status,
|
||||
StatusIs(absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("Wrong buffer dtype: expected F32 but got S32")));
|
||||
}
|
||||
|
||||
TEST(FfiTest, WrongNumberOfArguments) {
|
||||
|
|
@ -631,9 +631,8 @@ TEST(FfiTest, WrongNumberOfArguments) {
|
|||
Ffi::Bind().Attr<int>("foo").To([](int foo) { return Error::Success(); });
|
||||
auto status = Call(*handler, call_frame);
|
||||
|
||||
EXPECT_THAT(status,
|
||||
absl_testing::StatusIs(absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("Wrong number of attributes")));
|
||||
EXPECT_THAT(status, StatusIs(absl::StatusCode::kInvalidArgument,
|
||||
HasSubstr("Wrong number of attributes")));
|
||||
EXPECT_THAT(status.message(), HasSubstr("foo"));
|
||||
EXPECT_THAT(status.message(), HasSubstr("bar"));
|
||||
}
|
||||
|
|
@ -1072,10 +1071,10 @@ TEST(FfiTest, AttrsAsDictionary) {
|
|||
}
|
||||
|
||||
TEST(FfiTest, DictionaryAttr) {
|
||||
CallFrameBuilder::AttributesMap dict0;
|
||||
AttributesMap dict0;
|
||||
dict0.try_emplace("i32", 42);
|
||||
|
||||
CallFrameBuilder::AttributesMap dict1;
|
||||
AttributesMap dict1;
|
||||
dict1.try_emplace("f32", 42.0f);
|
||||
|
||||
CallFrameBuilder::AttributesBuilder attrs;
|
||||
|
|
@ -1119,7 +1118,7 @@ TEST(FfiTest, DictionaryAttr) {
|
|||
}
|
||||
|
||||
TEST(FfiTest, StructAttr) {
|
||||
CallFrameBuilder::AttributesMap dict;
|
||||
AttributesMap dict;
|
||||
dict.try_emplace("i32", 42);
|
||||
dict.try_emplace("f32", 42.0f);
|
||||
|
||||
|
|
@ -1232,7 +1231,7 @@ TEST(FfiTest, EnumAttr) {
|
|||
}
|
||||
|
||||
TEST(FfiTest, WrongEnumAttrType) {
|
||||
CallFrameBuilder::AttributesMap dict;
|
||||
AttributesMap dict;
|
||||
dict.try_emplace("i32", 42);
|
||||
|
||||
CallFrameBuilder::AttributesBuilder attrs;
|
||||
|
|
|
|||
71
third_party/xla/xla/ffi/attribute_map.cc
vendored
71
third_party/xla/xla/ffi/attribute_map.cc
vendored
|
|
@ -24,28 +24,26 @@ limitations under the License.
|
|||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/ADT/TypeSwitch.h" // IWYU pragma: keep
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "xla/ffi/call_frame.h"
|
||||
#include "tsl/platform/errors.h"
|
||||
#include "tsl/platform/statusor.h"
|
||||
#include "xla/tsl/platform/errors.h"
|
||||
#include "xla/tsl/platform/statusor.h"
|
||||
|
||||
namespace xla::ffi {
|
||||
|
||||
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertBoolAttr(
|
||||
absl::string_view name, mlir::BoolAttr boolean) {
|
||||
static absl::StatusOr<Attribute> ConvertBoolAttr(absl::string_view name,
|
||||
mlir::BoolAttr boolean) {
|
||||
return static_cast<bool>(boolean.getValue());
|
||||
}
|
||||
|
||||
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertStringAttr(
|
||||
absl::string_view name, mlir::StringAttr str) {
|
||||
static absl::StatusOr<Attribute> ConvertStringAttr(absl::string_view name,
|
||||
mlir::StringAttr str) {
|
||||
return str.getValue().str();
|
||||
}
|
||||
|
||||
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertIntegerAttr(
|
||||
absl::string_view name, mlir::IntegerAttr integer) {
|
||||
static absl::StatusOr<Attribute> ConvertIntegerAttr(absl::string_view name,
|
||||
mlir::IntegerAttr integer) {
|
||||
if (integer.getType().isUnsignedInteger()) {
|
||||
switch (integer.getType().getIntOrFloatBitWidth()) {
|
||||
case 8:
|
||||
|
|
@ -77,8 +75,8 @@ static absl::StatusOr<CallFrameBuilder::Attribute> ConvertIntegerAttr(
|
|||
}
|
||||
}
|
||||
|
||||
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertFloatAttr(
|
||||
absl::string_view name, mlir::FloatAttr fp) {
|
||||
static absl::StatusOr<Attribute> ConvertFloatAttr(absl::string_view name,
|
||||
mlir::FloatAttr fp) {
|
||||
switch (fp.getType().getIntOrFloatBitWidth()) {
|
||||
case 32:
|
||||
return static_cast<float>(fp.getValue().convertToFloat());
|
||||
|
|
@ -90,24 +88,28 @@ static absl::StatusOr<CallFrameBuilder::Attribute> ConvertFloatAttr(
|
|||
}
|
||||
}
|
||||
|
||||
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertArrayAttr(
|
||||
absl::string_view name, mlir::DenseArrayAttr arr) {
|
||||
static absl::StatusOr<Attribute> ConvertArrayAttr(absl::string_view name,
|
||||
mlir::DenseArrayAttr arr) {
|
||||
if (auto dense = mlir::dyn_cast<mlir::DenseI8ArrayAttr>(arr)) {
|
||||
return dense.asArrayRef().vec();
|
||||
} else if (auto dense = mlir::dyn_cast<mlir::DenseI16ArrayAttr>(arr)) {
|
||||
return dense.asArrayRef().vec();
|
||||
} else if (auto dense = mlir::dyn_cast<mlir::DenseI32ArrayAttr>(arr)) {
|
||||
return dense.asArrayRef().vec();
|
||||
} else if (auto dense = mlir::dyn_cast<mlir::DenseI64ArrayAttr>(arr)) {
|
||||
return dense.asArrayRef().vec();
|
||||
} else if (auto dense = mlir::dyn_cast<mlir::DenseF32ArrayAttr>(arr)) {
|
||||
return dense.asArrayRef().vec();
|
||||
} else if (auto dense = mlir::dyn_cast<mlir::DenseF64ArrayAttr>(arr)) {
|
||||
return dense.asArrayRef().vec();
|
||||
} else {
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("Unsupported array element type for attribute: ", name));
|
||||
}
|
||||
if (auto dense = mlir::dyn_cast<mlir::DenseI16ArrayAttr>(arr)) {
|
||||
return dense.asArrayRef().vec();
|
||||
}
|
||||
if (auto dense = mlir::dyn_cast<mlir::DenseI32ArrayAttr>(arr)) {
|
||||
return dense.asArrayRef().vec();
|
||||
}
|
||||
if (auto dense = mlir::dyn_cast<mlir::DenseI64ArrayAttr>(arr)) {
|
||||
return dense.asArrayRef().vec();
|
||||
}
|
||||
if (auto dense = mlir::dyn_cast<mlir::DenseF32ArrayAttr>(arr)) {
|
||||
return dense.asArrayRef().vec();
|
||||
}
|
||||
if (auto dense = mlir::dyn_cast<mlir::DenseF64ArrayAttr>(arr)) {
|
||||
return dense.asArrayRef().vec();
|
||||
}
|
||||
return absl::InvalidArgumentError(
|
||||
absl::StrCat("Unsupported array element type for attribute: ", name));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -117,7 +119,7 @@ static std::vector<T> CopyDenseElementsToVec(
|
|||
return std::vector<T>(it.begin(), it.end());
|
||||
}
|
||||
|
||||
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertDenseElementsAttr(
|
||||
static absl::StatusOr<Attribute> ConvertDenseElementsAttr(
|
||||
absl::string_view name, mlir::DenseIntOrFPElementsAttr arr) {
|
||||
auto type = arr.getElementType();
|
||||
if (type.isInteger()) {
|
||||
|
|
@ -156,16 +158,15 @@ static absl::StatusOr<CallFrameBuilder::Attribute> ConvertDenseElementsAttr(
|
|||
absl::StrCat("Unsupported array element type for attribute: ", name));
|
||||
}
|
||||
|
||||
static absl::StatusOr<CallFrameBuilder::Attribute> ConvertDictionaryAttr(
|
||||
static absl::StatusOr<Attribute> ConvertDictionaryAttr(
|
||||
absl::string_view name, mlir::DictionaryAttr dict) {
|
||||
TF_ASSIGN_OR_RETURN(auto attrs, BuildAttributesMap(dict));
|
||||
return CallFrameBuilder::Dictionary{
|
||||
std::make_shared<CallFrameBuilder::AttributesMap>(std::move(attrs))};
|
||||
return AttributesDictionary{
|
||||
std::make_shared<AttributesMap>(std::move(attrs))};
|
||||
}
|
||||
|
||||
absl::StatusOr<CallFrameBuilder::AttributesMap> BuildAttributesMap(
|
||||
mlir::DictionaryAttr dict) {
|
||||
CallFrameBuilder::AttributesMap attributes;
|
||||
absl::StatusOr<AttributesMap> BuildAttributesMap(mlir::DictionaryAttr dict) {
|
||||
AttributesMap attributes;
|
||||
for (auto& kv : dict) {
|
||||
absl::string_view name = kv.getName().strref();
|
||||
mlir::Attribute value = kv.getValue();
|
||||
|
|
|
|||
55
third_party/xla/xla/ffi/attribute_map.h
vendored
55
third_party/xla/xla/ffi/attribute_map.h
vendored
|
|
@ -16,16 +16,65 @@ limitations under the License.
|
|||
#ifndef XLA_FFI_ATTRIBUTE_MAP_H_
|
||||
#define XLA_FFI_ATTRIBUTE_MAP_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "xla/ffi/call_frame.h"
|
||||
|
||||
namespace xla::ffi {
|
||||
namespace internal {
|
||||
// A little bit of template metaprogramming to append type to std::variant.
|
||||
template <typename V, class T>
|
||||
struct AppendType;
|
||||
|
||||
template <typename... Ts, class T>
|
||||
struct AppendType<std::variant<Ts...>, T> {
|
||||
using Type = std::variant<Ts..., T>;
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
// A single scalar value.
|
||||
using Scalar = std::variant<bool, int8_t, int16_t, int32_t, int64_t, uint8_t,
|
||||
uint16_t, uint32_t, uint64_t, float, double>;
|
||||
|
||||
// An array of elements of the same Scalar type.
|
||||
using Array = std::variant<std::vector<int8_t>, std::vector<int16_t>,
|
||||
std::vector<int32_t>, std::vector<int64_t>,
|
||||
std::vector<uint8_t>, std::vector<uint16_t>,
|
||||
std::vector<uint32_t>, std::vector<uint64_t>,
|
||||
std::vector<float>, std::vector<double>>;
|
||||
|
||||
// Attributes that do not support nested dictionaries.
|
||||
using FlatAttribute = std::variant<Scalar, Array, std::string>;
|
||||
|
||||
// A map that maps from an arbitrary name (string key) to a flat attribute.
|
||||
using FlatAttributesMap = absl::flat_hash_map<std::string, FlatAttribute>;
|
||||
|
||||
// Forward declaration of the recursive type.
|
||||
struct AttributesDictionary;
|
||||
|
||||
// Attributes that support arbitrary nesting.
|
||||
using Attribute =
|
||||
internal::AppendType<FlatAttribute, AttributesDictionary>::Type;
|
||||
|
||||
// AttributesMap is a map from an arbitrary name (string key) to an attribute.
|
||||
using AttributesMap = absl::flat_hash_map<std::string, Attribute>;
|
||||
|
||||
// Dictionary is just a wrapper around `AttributesMap`. We need an indirection
|
||||
// through `std::shared_ptr` to be able to define recursive `std::variant`. We
|
||||
// use shared pointer to keep `AttributesMap` copyable.
|
||||
struct AttributesDictionary {
|
||||
std::shared_ptr<AttributesMap> attrs;
|
||||
};
|
||||
|
||||
// Converts MLIR dictionary attribute attached to a custom call operation to a
|
||||
// custom call handler attributes that are forwarded to the FFI handler.
|
||||
absl::StatusOr<CallFrameBuilder::AttributesMap> BuildAttributesMap(
|
||||
mlir::DictionaryAttr dict);
|
||||
absl::StatusOr<AttributesMap> BuildAttributesMap(mlir::DictionaryAttr dict);
|
||||
|
||||
} // namespace xla::ffi
|
||||
|
||||
|
|
|
|||
22
third_party/xla/xla/ffi/call_frame.cc
vendored
22
third_party/xla/xla/ffi/call_frame.cc
vendored
|
|
@ -34,6 +34,7 @@ limitations under the License.
|
|||
#include "xla/ffi/api/api.h"
|
||||
#include "xla/ffi/api/c_api.h"
|
||||
#include "xla/ffi/api/c_api_internal.h" // IWYU pragma: keep
|
||||
#include "xla/ffi/attribute_map.h"
|
||||
#include "xla/stream_executor/device_memory.h"
|
||||
#include "xla/tsl/platform/errors.h"
|
||||
#include "xla/util.h"
|
||||
|
|
@ -51,7 +52,7 @@ struct CallFrameBuilder::Buffer {
|
|||
absl::InlinedVector<int64_t, 4> dims;
|
||||
};
|
||||
|
||||
CallFrameBuilder::AttributesMap CallFrameBuilder::AttributesBuilder::Build() {
|
||||
AttributesMap CallFrameBuilder::AttributesBuilder::Build() {
|
||||
return std::move(attrs_);
|
||||
}
|
||||
|
||||
|
|
@ -65,8 +66,9 @@ void CallFrameBuilder::AttributesBuilder::Insert(std::string name,
|
|||
|
||||
void CallFrameBuilder::AttributesBuilder::Insert(std::string name,
|
||||
AttributesMap attrs) {
|
||||
attrs_.try_emplace(std::move(name),
|
||||
Dictionary{std::make_shared<AttributesMap>(attrs)});
|
||||
attrs_.try_emplace(
|
||||
std::move(name),
|
||||
AttributesDictionary{std::make_shared<AttributesMap>(attrs)});
|
||||
}
|
||||
|
||||
void CallFrameBuilder::AttributesBuilder::Append(AttributesMap attrs) {
|
||||
|
|
@ -160,13 +162,13 @@ struct CallFrame::Dictionary {
|
|||
};
|
||||
|
||||
struct CallFrame::Array {
|
||||
CallFrameBuilder::Array value; // XLA_FFI_Array::data
|
||||
xla::ffi::Array value; // XLA_FFI_Array::data
|
||||
|
||||
XLA_FFI_Array array = {};
|
||||
};
|
||||
|
||||
struct CallFrame::Scalar {
|
||||
CallFrameBuilder::Scalar value; // XLA_FFI_Scalar::value
|
||||
xla::ffi::Scalar value; // XLA_FFI_Scalar::value
|
||||
|
||||
XLA_FFI_Scalar scalar = {};
|
||||
};
|
||||
|
|
@ -413,11 +415,11 @@ std::unique_ptr<CallFrame::Results> CallFrame::FixUpRets(
|
|||
// An std::visit overload set for converting CallFrameBuilder::Attribute to
|
||||
// CallFrame::Attribute.
|
||||
struct CallFrame::ConvertAttribute {
|
||||
CallFrame::Attribute operator()(const CallFrameBuilder::Array& array) {
|
||||
CallFrame::Attribute operator()(const xla::ffi::Array& array) {
|
||||
return CallFrame::Array{array};
|
||||
}
|
||||
|
||||
CallFrame::Attribute operator()(const CallFrameBuilder::Scalar& scalar) {
|
||||
CallFrame::Attribute operator()(const xla::ffi::Scalar& scalar) {
|
||||
return CallFrame::Scalar{scalar};
|
||||
}
|
||||
|
||||
|
|
@ -425,8 +427,8 @@ struct CallFrame::ConvertAttribute {
|
|||
return CallFrame::String{str};
|
||||
}
|
||||
|
||||
CallFrame::Attribute operator()(const CallFrameBuilder::Dictionary& dict) {
|
||||
return CallFrame::Dictionary{CreateAttrs(*dict.attrs)};
|
||||
CallFrame::Attribute operator()(const xla::ffi::AttributesDictionary& dict) {
|
||||
return Dictionary{CreateAttrs(*dict.attrs)};
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -498,7 +500,7 @@ struct CallFrame::AttributeStorage {
|
|||
};
|
||||
|
||||
std::unique_ptr<CallFrame::Attributes> CallFrame::CreateAttrs(
|
||||
const CallFrameBuilder::AttributesMap& battrs) {
|
||||
const xla::ffi::AttributesMap& battrs) {
|
||||
auto attrs = std::make_unique<Attributes>();
|
||||
|
||||
// Convert call frame builder attributes to a collection of named attributes.
|
||||
|
|
|
|||
39
third_party/xla/xla/ffi/call_frame.h
vendored
39
third_party/xla/xla/ffi/call_frame.h
vendored
|
|
@ -29,6 +29,7 @@ limitations under the License.
|
|||
#include "absl/status/statusor.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "xla/ffi/api/c_api.h"
|
||||
#include "xla/ffi/attribute_map.h"
|
||||
#include "xla/stream_executor/device_memory.h"
|
||||
#include "xla/types.h" // IWYU pragma: keep
|
||||
#include "xla/xla_data.pb.h"
|
||||
|
|
@ -45,15 +46,6 @@ namespace xla::ffi {
|
|||
class CallFrame; // forward declare
|
||||
|
||||
class CallFrameBuilder {
|
||||
// A little bit of template metaprogramming to append type to std::variant.
|
||||
template <typename V, class T>
|
||||
struct AppendType;
|
||||
|
||||
template <typename... Ts, class T>
|
||||
struct AppendType<std::variant<Ts...>, T> {
|
||||
using Type = std::variant<Ts..., T>;
|
||||
};
|
||||
|
||||
public:
|
||||
CallFrameBuilder(size_t num_args, size_t num_rets);
|
||||
~CallFrameBuilder();
|
||||
|
|
@ -61,32 +53,6 @@ class CallFrameBuilder {
|
|||
CallFrameBuilder(CallFrameBuilder&&);
|
||||
CallFrameBuilder& operator=(CallFrameBuilder&&);
|
||||
|
||||
using Scalar = std::variant<bool, int8_t, int16_t, int32_t, int64_t, uint8_t,
|
||||
uint16_t, uint32_t, uint64_t, float, double>;
|
||||
using Array = std::variant<std::vector<int8_t>, std::vector<int16_t>,
|
||||
std::vector<int32_t>, std::vector<int64_t>,
|
||||
std::vector<uint8_t>, std::vector<uint16_t>,
|
||||
std::vector<uint32_t>, std::vector<uint64_t>,
|
||||
std::vector<float>, std::vector<double>>;
|
||||
|
||||
// Declare implementation detail structs for call frame builder storage.
|
||||
struct Dictionary;
|
||||
|
||||
// Attributes that do not support nested dictionaries.
|
||||
using FlatAttribute = std::variant<Scalar, Array, std::string>;
|
||||
using FlatAttributesMap = absl::flat_hash_map<std::string, FlatAttribute>;
|
||||
|
||||
// Attributes that support arbitrary nesting.
|
||||
using Attribute = typename AppendType<FlatAttribute, Dictionary>::Type;
|
||||
using AttributesMap = absl::flat_hash_map<std::string, Attribute>;
|
||||
|
||||
// Dictionary is just a wrapper around AttributesMap. We need an indirection
|
||||
// through `std::shared_ptr` to be able to define recursive `std::variant`. We
|
||||
// use shared pointer to keep `AttributesMap` copyable.
|
||||
struct Dictionary {
|
||||
std::shared_ptr<AttributesMap> attrs;
|
||||
};
|
||||
|
||||
// A helper class to build call frame attributes.
|
||||
class AttributesBuilder {
|
||||
public:
|
||||
|
|
@ -224,8 +190,7 @@ class CallFrame {
|
|||
//===----- Call frame attributes ----------------------------------------===//
|
||||
|
||||
// Creates call frame attributes from the call frame builder attributes.
|
||||
static std::unique_ptr<Attributes> CreateAttrs(
|
||||
const CallFrameBuilder::AttributesMap& attrs);
|
||||
static std::unique_ptr<Attributes> CreateAttrs(const AttributesMap& attrs);
|
||||
|
||||
// Fixes up call frame attributes by initializing XLA FFI structs with valid
|
||||
// pointers into storage objects.
|
||||
|
|
|
|||
3
third_party/xla/xla/ffi/call_frame_test.cc
vendored
3
third_party/xla/xla/ffi/call_frame_test.cc
vendored
|
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||
#include <gtest/gtest.h>
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "xla/ffi/api/c_api.h"
|
||||
#include "xla/ffi/attribute_map.h"
|
||||
#include "xla/stream_executor/device_memory.h"
|
||||
#include "xla/tsl/lib/core/status_test_util.h"
|
||||
#include "xla/tsl/platform/test.h"
|
||||
|
|
@ -131,7 +132,7 @@ void BM_AddBufferArg(benchmark::State& state) {
|
|||
void BM_AddAttributes(benchmark::State& state) {
|
||||
size_t num_attrs = state.range(0);
|
||||
|
||||
CallFrameBuilder::AttributesMap attrs;
|
||||
AttributesMap attrs;
|
||||
for (size_t i = 0; i < num_attrs; ++i) {
|
||||
attrs.try_emplace(absl::StrCat("attr_", i), 42);
|
||||
}
|
||||
|
|
|
|||
7
third_party/xla/xla/ffi/ffi_test.cc
vendored
7
third_party/xla/xla/ffi/ffi_test.cc
vendored
|
|
@ -35,6 +35,7 @@ limitations under the License.
|
|||
#include "absl/types/span.h"
|
||||
#include "xla/executable_run_options.h"
|
||||
#include "xla/ffi/api/c_api.h"
|
||||
#include "xla/ffi/attribute_map.h"
|
||||
#include "xla/ffi/call_frame.h"
|
||||
#include "xla/ffi/execution_context.h"
|
||||
#include "xla/ffi/execution_state.h"
|
||||
|
|
@ -412,10 +413,10 @@ TEST(FfiTest, AttrsAsDictionary) {
|
|||
}
|
||||
|
||||
TEST(FfiTest, DictionaryAttr) {
|
||||
CallFrameBuilder::AttributesMap dict0;
|
||||
AttributesMap dict0;
|
||||
dict0.try_emplace("i32", 42);
|
||||
|
||||
CallFrameBuilder::AttributesMap dict1;
|
||||
AttributesMap dict1;
|
||||
dict1.try_emplace("f32", 42.0f);
|
||||
|
||||
CallFrameBuilder::AttributesBuilder attrs;
|
||||
|
|
@ -458,7 +459,7 @@ TEST(FfiTest, DictionaryAttr) {
|
|||
}
|
||||
|
||||
TEST(FfiTest, StructAttr) {
|
||||
CallFrameBuilder::AttributesMap dict;
|
||||
AttributesMap dict;
|
||||
dict.try_emplace("i32", 42);
|
||||
dict.try_emplace("f32", 42.0f);
|
||||
|
||||
|
|
|
|||
|
|
@ -1199,7 +1199,7 @@ absl::Status IrEmitterUnnested::EmitCustomCallThunk(
|
|||
// attributes map at IR emission time, so that we do not need to
|
||||
// parse MLIR at run time. For FFI handlers backend config must be
|
||||
// a compatible MLIR dictionary.
|
||||
CustomCallThunk::AttributesMap attributes;
|
||||
ffi::AttributesMap attributes;
|
||||
|
||||
auto backend_config = instr->backend_config<GpuBackendConfig>();
|
||||
if (!backend_config.ok()) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user