[XLA][XTile] Add xtile lowering passes for triton.

This enables migrating the Triton emitter to emit xtile entry, insert & extract ops in the child PR. The main difference is the memref args in the entry function, for which `MemrefToPtr` & `PtrToMemref` were introduced. These closely resemble `UnrealizedConversionCastOp`, but add extra verification and will enable special folding of `memref::TransposeOp`.

PiperOrigin-RevId: 821593545
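For orientation (not part of the diff): a minimal sketch of how the two casts compose, assuming the op syntax introduced below; the function and value names are made up for illustration. The round trip memref_to_ptr(ptr_to_memref(x)) folds back to the original pointer, as the new fold and its test below show.

// Hypothetical illustration: a pointer entry argument is rebuilt as a memref
// for the function body, then turned back into a pointer at a tile access.
func.func @sketch(%arg0: !tt.ptr<f32>) {
  %view = triton_xla.ptr_to_memref %arg0 from !tt.ptr<f32> to memref<256xf32>
  %ptr = triton_xla.memref_to_ptr %view from memref<256xf32> to !tt.ptr<f32>
  // %ptr folds to %arg0 via MemrefToPtrOp::fold.
  func.return
}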
Parent: ea72bd7e48
Commit: beb48d90e2
@@ -239,6 +239,7 @@ cc_library(
        "//xla/codegen/tiling:tiled_hlo_computation",
        "//xla/codegen/tiling:tiled_hlo_fusion_instruction",
        "//xla/codegen/tiling:tiled_hlo_instruction",
        "//xla/codegen/xtile/ir:xtile",
        "//xla/hlo/ir:hlo",
        "//xla/service:hlo_module_config",
        "//xla/service/gpu/model:block_level_parameters",

@@ -353,8 +354,9 @@ cc_library(
        "//xla:status_macros",
        "//xla:util",
        "//xla:xla_data_proto_cc",
        "//xla/backends/gpu/codegen/triton/ir:triton_xla",
        "//xla/codegen:emitter_loc_op_builder",
        "//xla/codegen/tiling:tiled_hlo_computation",
        "//xla/codegen/xtile/ir:xtile",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_query",
        "//xla/hlo/utils:hlo_traversal",

@@ -392,6 +394,7 @@ cc_library(
        "@llvm-project//mlir:FunctionInterfaces",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MathDialect",
        "@llvm-project//mlir:NVVMDialect",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:Support",
        "@triton//:TritonDialects",

@@ -443,6 +446,7 @@ cc_library(
        "//xla:autotuning_proto_cc",
        "//xla/codegen:emitter_loc_op_builder",
        "//xla/codegen/tiling:symbolic_tile_analysis",
        "//xla/codegen/xtile/ir:xtile",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/utils:hlo_traversal",
        "//xla/service:hlo_module_config",
@@ -33,6 +33,7 @@ limitations under the License.
#include "llvm/Support/MathExtras.h"
#include "llvm/TargetParser/Triple.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/Builders.h"

@@ -66,6 +67,7 @@ limitations under the License.
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"

namespace xla::gpu::triton {

@@ -616,4 +618,20 @@ absl::StatusOr<stream_executor::gpu::TmaMetadata> ExtractTmaMetadata(
  return tma_metadata;
}

::mlir::triton::PointerType GetPointerType(mlir::MemRefType memref_type) {
  int address_space = 0;

  mlir::Attribute memory_space_attr = memref_type.getMemorySpace();
  if (auto int_memory_space_attr =
          mlir::dyn_cast_if_present<mlir::IntegerAttr>(memory_space_attr)) {
    address_space = int_memory_space_attr.getInt();
  } else if (auto llvm_memory_space_attr = mlir::dyn_cast_if_present<
                 mlir::LLVM::LLVMAddrSpaceAttrInterface>(memory_space_attr)) {
    address_space = llvm_memory_space_attr.getAddressSpace();
  }

  return ::mlir::triton::PointerType::get(memref_type.getElementType(),
                                          address_space);
}

}  // namespace xla::gpu::triton
@@ -50,6 +50,7 @@ limitations under the License.
#include "xla/tsl/platform/status.h"
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"
#include "triton/Dialect/Triton/IR/Types.h"

namespace xla::gpu::triton {

@@ -230,6 +231,10 @@ absl::StatusOr<stream_executor::gpu::TmaMetadata> ExtractTmaMetadata(
absl::StatusOr<stream_executor::ThreadDim> ExtractThreadDims(
    mlir::ModuleOp triton_module, mlir::LLVM::LLVMFuncOp func_op);

// Returns the triton pointer type that corresponds to the given memref type,
// i.e. has the same element type and address space.
::mlir::triton::PointerType GetPointerType(mlir::MemRefType memref_type);

}  // namespace xla::gpu::triton

#endif  // XLA_BACKENDS_GPU_CODEGEN_TRITON_EMITTER_HELPERS_H_
@@ -17,3 +17,12 @@ tt.func @xla_triton_extract_insert(%arg0: !tt.ptr<bf16>, %arg1: index) {
      [%c0, %arg1][16, 64][1, 1] {noinline = false} : tensor<16x64xbf16>
  tt.return
}

// CHECK-LABEL: @fold_ptr_memref_ptr(
// CHECK-SAME:  %[[SRC:.*]]: !tt.ptr<f32>
func.func @fold_ptr_memref_ptr(%src: !tt.ptr<f32>) -> !tt.ptr<f32> {
  // CHECK: return %[[SRC]] : !tt.ptr<f32>
  %src_ptr = triton_xla.ptr_to_memref %src from !tt.ptr<f32> to memref<256xf32>
  %dst = triton_xla.memref_to_ptr %src_ptr from memref<256xf32> to !tt.ptr<f32>
  func.return %dst : !tt.ptr<f32>
}
@@ -254,6 +254,64 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
  results.add<InsertOpOffsetsSizesStridesFolder>(context);
}

OpFoldResult MemrefToPtrOp::fold(FoldAdaptor adaptor) {
  if (auto ptr_to_memref = getOperand().getDefiningOp<PtrToMemrefOp>()) {
    // memref_to_ptr(ptr_to_memref(x)) -> x
    return ptr_to_memref.getOperand();
  }

  return {};
}

LogicalResult MemrefToPtrOp::verify() {
  mlir::MemRefType src_type = getSrc().getType();
  if (src_type.getElementType() != getType().getPointeeType()) {
    getOperation()->emitError(
        "source element type does not match result pointee type");
    return failure();
  }

  // It is only safe to directly convert a memref to a pointer if the memref
  // has no offset.
  llvm::SmallVector<int64_t> strides;
  int64_t offset = 0;
  if (src_type.getStridesAndOffset(strides, offset).failed()) {
    getOperation()->emitError("failed to get strides and offset") << src_type;
    return failure();
  }
  if (offset != 0) {
    getOperation()->emitError("memref has non-zero offset");
    return failure();
  }

  return success();
}

LogicalResult PtrToMemrefOp::verify() {
  mlir::MemRefType result_type = getType();
  if (getSrc().getType().getPointeeType() != result_type.getElementType()) {
    getOperation()->emitError(
        "source pointee type does not match result element type");
    return failure();
  }

  // It is only safe to directly convert a pointer to a memref if the memref
  // has no offset.
  llvm::SmallVector<int64_t> strides;
  int64_t offset = 0;
  if (result_type.getStridesAndOffset(strides, offset).failed()) {
    getOperation()->emitError("failed to get strides and offset")
        << result_type;
    return failure();
  }
  if (offset != 0) {
    getOperation()->emitError("memref has non-zero offset");
    return failure();
  }

  return success();
}

}  // namespace mlir::triton::xla

#define GET_OP_CLASSES
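An illustrative aside (not part of the diff): IR the new verifiers are expected to reject, assuming the textual form used in the test above, because the strided memref layout carries a non-zero offset; the value names are made up.

// Hypothetical: rejected by PtrToMemrefOp::verify ("memref has non-zero offset")
// since the strided layout below has offset: 4.
%bad = triton_xla.ptr_to_memref %p from !tt.ptr<f32> to memref<16xf32, strided<[1], offset: 4>>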
@@ -421,6 +421,41 @@ def TTXLA_GetPeerPtrOp : TTXLA_Op<"get_peer_ptr", [Pure]> {
  }];
}

def TTXLA_MemrefToPtrOp : TTXLA_Op<"memref_to_ptr", [Pure]> {
  let summary = [{
    A specialized version of unrealized_conversion_cast that converts a
    memref to a pointer.
  }];

  let arguments = (ins AnyMemRef:$src);

  let results = (outs TT_Ptr:$result);

  let assemblyFormat = [{
    $src `from` type($src) `to` type($result) attr-dict
  }];

  let hasFolder = 1;
  let hasVerifier = 1;
}

def TTXLA_PtrToMemrefOp : TTXLA_Op<"ptr_to_memref", [Pure]> {
  let summary = [{
    A specialized version of unrealized_conversion_cast that converts a
    pointer to a memref.
  }];

  let arguments = (ins TT_Ptr:$src);

  let results = (outs AnyMemRef:$result);

  let assemblyFormat = [{
    $src `from` type($src) `to` type($result) attr-dict
  }];

  let hasVerifier = 1;
}

#endif  // XLA_BACKENDS_GPU_CODEGEN_TRITON_IR_TRITON_XLA_OPS_TD_
@@ -44,6 +44,7 @@ cc_library(
        "triton_xla_lower_block_barrier_pass.cc",
        "triton_xla_lower_get_tid_pass.cc",
        "triton_xla_lower_remote_access_pass.cc",
        "triton_xla_lower_xtile_pass.cc",
        "triton_xla_squeeze_dims_pass.cc",
        "triton_xla_unswitch_loops_pass.cc",
    ],

@@ -56,6 +57,8 @@ cc_library(
        "//xla/backends/gpu/codegen/triton/ir:triton_xla",
        "//xla/codegen:emitter_loc_op_builder",
        "//xla/codegen/emitters/ir:xla",
        "//xla/codegen/xtile/ir:xtile",
        "//xla/service/gpu:ir_emission_utils",
        "//xla/service/llvm_ir:llvm_util",
        "//xla/stream_executor/gpu:collective_kernel_metadata",
        "//xla/stream_executor/gpu:tma_metadata",

@@ -63,6 +66,7 @@ cc_library(
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/log:check",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_absl//absl/strings:string_view",

@@ -73,7 +77,10 @@ cc_library(
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:FunctionInterfaces",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:InliningUtils",
        "@llvm-project//mlir:LLVMCommonConversion",
        "@llvm-project//mlir:LLVMDialect",
        "@llvm-project//mlir:MemRefDialect",
        "@llvm-project//mlir:NVVMDialect",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Rewrite",
@@ -22,6 +22,7 @@ limitations under the License.
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h"  // IWYU pragma: keep
#include "xla/codegen/xtile/ir/xtile_dialect.h"  // IWYU pragma: keep

namespace mlir::triton::xla {

@@ -45,6 +46,7 @@ std::unique_ptr<mlir::Pass> CreateTritonXLALowerAtomicsPass();
std::unique_ptr<mlir::Pass> CreateTritonXLALowerBlockBarrierPass();
std::unique_ptr<mlir::Pass> CreateTritonXLAConvertUnsupportedTypesPass();
std::unique_ptr<mlir::Pass> CreateTritonXLALowerRemoteAccessPass();
std::unique_ptr<mlir::Pass> CreateTritonXLALowerXTilePass();
std::unique_ptr<mlir::Pass> CreateStableHLOLowerToTritonPass();
std::unique_ptr<mlir::Pass> CreateTensorLowerToTritonPass();
@@ -190,6 +190,20 @@ def TritonXLAUnswitchLoopsPass :
  let constructor = "CreateTritonXLAUnswitchLoopsPass()";
}

def TritonXLALowerXTilePass :
    Pass<"triton-xla-lower-xtile", "mlir::ModuleOp"> {
  let summary = "Lowers xtile ops to Triton ops.";

  let dependentDialects = [
    "mlir::func::FuncDialect",
    "mlir::memref::MemRefDialect",
    "mlir::triton::xla::XlaTritonDialect",
    "triton::TritonDialect",
  ];

  let constructor = "CreateTritonXLALowerXTilePass()";
}

def StableHLOLowerToTritonPass
    : Pass<"stablehlo-lower-to-triton", "mlir::ModuleOp"> {
  let summary = "Lowers StableHLO operations to their Triton equivalent.";
third_party/xla/xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_lower_xtile.mlir (new file, 39 lines)
@@ -0,0 +1,39 @@
// RUN: xla-opt %s -split-input-file -triton-xla-lower-xtile | FileCheck %s

xtile.entry_func @extract_insert_no_layout(%input: memref<1024xf32, #nvvm.memory_space<global>>,
                                           %output: memref<32xf32, #nvvm.memory_space<global>>,
                                           %tile_id: index) {
  %tile = xtile.extract %input[%tile_id][1][1] : memref<1024xf32, #nvvm.memory_space<global>> -> tensor<1xf32>
  xtile.insert %tile into %output[%tile_id][1][1] : tensor<1xf32> -> memref<32xf32, #nvvm.memory_space<global>>
  xtile.return
}

// CHECK: func.func @extract_insert_no_layout(%[[ARG0:.*]]: !tt.ptr<f32>, %[[ARG1:.*]]: !tt.ptr<f32>) {
// CHECK:   %[[PID:.*]] = tt.get_program_id x : i32
// CHECK:   %[[PID_I64:.*]] = arith.extsi %[[PID]] : i32 to i64
// CHECK:   %[[PID_IDX:.*]] = arith.index_cast %[[PID_I64]] : i64 to index
// CHECK:   %[[TILE:.*]] = triton_xla.extract from %[[ARG0]] as memref<1024xf32, #triton_xla.layout<[0]>> [%[[PID_IDX]]] [1] [1] : tensor<1xf32>
// CHECK:   triton_xla.insert %[[TILE]] into %[[ARG1]] as memref<32xf32, #triton_xla.layout<[0]>> [%[[PID_IDX]]] [1] [1] : tensor<1xf32>
// CHECK:   return
// CHECK: }

// -----

!arg_type = memref<1024x32x1x1xbf16, #triton_xla.layout<[2, 3, 0, 1]>, #nvvm.memory_space<global>>
xtile.entry_func @layout_preserved(%input: !arg_type,
                                   %tile_id: index) {
  %c_0 = arith.constant 0 : index
  %tile = xtile.extract %input[%tile_id, %c_0, %c_0, %c_0][1, 1, 1, 1][1, 1, 1, 1] : !arg_type -> tensor<1x1x1x1xbf16>
  xtile.return
}

// CHECK: func.func @layout_preserved(%[[ARG0:.*]]: !tt.ptr<bf16>) {
// CHECK:   %[[PID:.*]] = tt.get_program_id x : i32
// CHECK:   %[[PID_I64:.*]] = arith.extsi %[[PID]] : i32 to i64
// CHECK:   %[[PID_IDX:.*]] = arith.index_cast %[[PID_I64]] : i64 to index
// CHECK:   %[[TILE:.*]] = triton_xla.extract from %[[ARG0]]
// CHECK-SAME: as memref<1024x32x1x1xbf16, #triton_xla.layout<[3, 2, 0, 1]>>
// CHECK-SAME: [%[[PID_IDX]], 0, 0, 0]
// CHECK-SAME: [1, 1, 1, 1] [1, 1, 1, 1] : tensor<1x1x1x1xbf16>
// CHECK:   return
// CHECK: }
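A worked note on the second test case (not part of the diff, assuming the layout attribute lists dimensions from most minor to most major): for shape 1024x32x1x1 with input layout [2, 3, 0, 1], the recovered per-dimension strides are [1, 1024, 1, 1]. Sorting dimensions by ascending stride, breaking ties by smaller size and then by higher dimension index, gives the minor-to-major permutation [3, 2, 0, 1], which is why the CHECK lines expect a reordered but equivalent layout relative to the input.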
third_party/xla/xla/backends/gpu/codegen/triton/transforms/triton_xla_lower_xtile_pass.cc (new file, 291 lines)
@@ -0,0 +1,291 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <memory>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Inliner.h"
#include "mlir/Transforms/InliningUtils.h"
#include "xla/backends/gpu/codegen/triton/emitter_helpers.h"
#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h"
#include "xla/codegen/xtile/ir/xtile_ops.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

namespace mlir::triton::xla {

#define GEN_PASS_DEF_TRITONXLALOWERXTILEPASS
#include "xla/backends/gpu/codegen/triton/transforms/passes.h.inc"

namespace {

namespace ttir = ::mlir::triton;
namespace ma = ::mlir::arith;

// Get the new arg types of the lowered function by translating memrefs to the
// corresponding pointer types.
llvm::SmallVector<mlir::Type> GetPtrArgTypes(mlir::ValueRange args) {
  llvm::SmallVector<mlir::Type> arg_types;
  arg_types.reserve(args.size());
  for (auto arg : args) {
    mlir::MemRefType memref_type = mlir::cast<mlir::MemRefType>(arg.getType());
    arg_types.push_back(::xla::gpu::triton::GetPointerType(memref_type));
  }
  return arg_types;
}

// Function to get the permutation vector from a MemRefType.
// The motivation for extracting it from getStridesAndOffset vs directly from
// triton_xla.layout is that when we fold memrefs (such as in a transpose) it
// will have a generic strided layout that does not directly encode the
// permutation.
absl::StatusOr<llvm::SmallVector<int64_t>> getPermutationMinorToMajor(
    mlir::MemRefType memref) {
  llvm::SmallVector<int64_t> strides;
  int64_t offset;
  if (memref.getStridesAndOffset(strides, offset).failed()) {
    // This can fail if the layout is not strided (e.g., has dynamic strides).
    return absl::InvalidArgumentError("Failed to get strides and offset");
  }

  llvm::SmallVector<int64_t> permutation;
  permutation.resize(strides.size());
  absl::c_iota(permutation, 0);

  absl::c_sort(permutation, [&](int64_t lhs_dim, int64_t rhs_dim) {
    int64_t lhs_stride = strides[lhs_dim];
    int64_t rhs_stride = strides[rhs_dim];
    if (lhs_stride != rhs_stride) {
      return lhs_stride < rhs_stride;
    }

    // If the strides are the same, we need to ensure that the unit dimension
    // is the more minor.
    int64_t lhs_size = memref.getDimSize(lhs_dim);
    int64_t rhs_size = memref.getDimSize(rhs_dim);
    if (lhs_size != rhs_size) {
      return lhs_size < rhs_size;
    }

    // If all else fails, just sort in the canonical order.
    return lhs_dim > rhs_dim;
  });

  // Check that the strides actually represent a permutation; this can fail,
  // for example, with padded buffers.
  int64_t size_product = 1;
  for (int64_t dim : permutation) {
    if (strides[dim] != size_product) {
      return absl::InvalidArgumentError("Layout is not a valid permutation");
    }
    size_product *= memref.getDimSize(dim);
  }

  return permutation;
}

MemrefToPtrOp CreateMemrefToPtr(mlir::OpBuilder& builder,
                                mlir::TypedValue<mlir::MemRefType> memref) {
  mlir::MemRefType memref_type = memref.getType();
  return builder.create<MemrefToPtrOp>(
      memref.getLoc(), ::xla::gpu::triton::GetPointerType(memref_type),
      memref);
}

// Rewrite an xtile entry to a func.func with the same body, but with memref
// arguments replaced by pointers.
class XTileEntryToTriton
    : public mlir::OpRewritePattern<::xla::xtile::EntryFuncOp> {
 public:
  XTileEntryToTriton(mlir::MLIRContext* context, mlir::ModuleOp& module)
      : OpRewritePattern(context), module_(module) {}

  mlir::LogicalResult matchAndRewrite(
      ::xla::xtile::EntryFuncOp entry_op,
      mlir::PatternRewriter& rewriter) const override {
    mlir::ImplicitLocOpBuilder builder(module_->getLoc(), module_);
    builder.setInsertionPointToStart(module_.getBody());

    auto new_arg_types = GetPtrArgTypes(entry_op.getBufferArgs());
    auto new_func_op = builder.create<mlir::func::FuncOp>(
        entry_op.getName(), builder.getFunctionType(new_arg_types, {}));

    // Move the old function's body to the new function.
    rewriter.inlineRegionBefore(
        entry_op.getBody(), new_func_op.getFunctionBody(), new_func_op.end());

    Block& entry_block = new_func_op.front();
    builder.setInsertionPointToStart(&entry_block);

    SmallVector<BlockArgument> old_args(entry_block.getArguments());
    SmallVector<BlockArgument> new_args(entry_block.addArguments(
        new_arg_types,
        SmallVector<Location>(new_arg_types.size(), entry_op.getLoc())));

    BlockArgument tile_id_arg = old_args.back();

    // TODO(b/389955087): we can decide whether to sign extend by
    // understanding if we need 64 bits to encode indices or if 32 bits are
    // enough. For now, just use 64 bits to avoid issues.
    auto pid = builder.create<ttir::GetProgramIdOp>(ttir::ProgramIDDim::X);
    Value pid_i64 = builder.create<ma::ExtSIOp>(builder.getI64Type(), pid);
    Value pid_idx =
        builder.create<ma::IndexCastOp>(builder.getIndexType(), pid_i64);
    rewriter.replaceAllUsesWith(tile_id_arg, pid_idx);

    // Handle memref arguments.
    for (auto [old_arg, new_arg] : llvm::zip(old_args, new_args)) {
      mlir::MemRefType memref_type =
          mlir::cast<mlir::MemRefType>(old_arg.getType());

      mlir::Value memref_cast =
          builder.create<PtrToMemrefOp>(memref_type, new_arg);

      // Replace all uses of the old argument with the result of the cast.
      rewriter.replaceAllUsesWith(old_arg, memref_cast);
    }

    entry_block.eraseArguments(0, old_args.size());

    rewriter.setInsertionPointToEnd(&entry_block);

    rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
        entry_block.getTerminator());

    rewriter.eraseOp(entry_op);
    return success();
  }

 private:
  mlir::ModuleOp& module_;
};

// Rewrite an xtile extract to a triton_xla extract.
class XTileExtractToTriton
    : public mlir::OpRewritePattern<::xla::xtile::ExtractTileOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      ::xla::xtile::ExtractTileOp extract_op,
      mlir::PatternRewriter& rewriter) const override {
    mlir::MemRefType source_type = extract_op.getSource().getType();
    mlir::RankedTensorType result_type = extract_op.getType();

    mlir::Value memref_to_ptr =
        CreateMemrefToPtr(rewriter, extract_op.getSource());

    absl::StatusOr<SmallVector<int64_t>> minor_to_major_or =
        getPermutationMinorToMajor(source_type);
    if (!minor_to_major_or.ok()) {
      return rewriter.notifyMatchFailure(
          extract_op, minor_to_major_or.status().ToString());
    }
    const SmallVector<int64_t>& minor_to_major = *minor_to_major_or;
    auto triton_extract_op = rewriter.create<ExtractOp>(
        extract_op.getLoc(), result_type, memref_to_ptr,
        extract_op.getOffsets(), extract_op.getFullTileShape(),
        extract_op.getStrides(), source_type.getShape(), minor_to_major);

    rewriter.replaceOp(extract_op, triton_extract_op);

    return mlir::success();
  }
};

// Rewrite an xtile insert to a triton_xla insert.
class XTileInsertToTriton
    : public mlir::OpRewritePattern<::xla::xtile::InsertTileOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      ::xla::xtile::InsertTileOp insert_op,
      mlir::PatternRewriter& rewriter) const override {
    mlir::MemRefType destination_type = insert_op.getDestination().getType();

    mlir::Value memref_to_ptr =
        CreateMemrefToPtr(rewriter, insert_op.getDestination());

    absl::StatusOr<SmallVector<int64_t>> minor_to_major_or =
        getPermutationMinorToMajor(destination_type);
    if (!minor_to_major_or.ok()) {
      return rewriter.notifyMatchFailure(
          insert_op, minor_to_major_or.status().ToString());
    }
    const SmallVector<int64_t>& minor_to_major = *minor_to_major_or;
    auto triton_insert_op = rewriter.create<InsertOp>(
        insert_op.getLoc(), insert_op.getSource(), memref_to_ptr,
        insert_op.getOffsets(), insert_op.getFullTileShape(),
        insert_op.getStrides(), destination_type.getShape(), minor_to_major);

    rewriter.replaceOp(insert_op, triton_insert_op);

    return mlir::success();
  }
};

class TritonXLALowerXTilePass
    : public impl::TritonXLALowerXTilePassBase<TritonXLALowerXTilePass> {
 public:
  using TritonXLALowerXTilePassBase::TritonXLALowerXTilePassBase;

  void runOnOperation() override {
    mlir::ModuleOp module = getOperation();
    mlir::MLIRContext* context = &getContext();

    mlir::RewritePatternSet patterns(context);

    patterns.add<XTileEntryToTriton>(context, module);
    patterns.add<XTileExtractToTriton, XTileInsertToTriton>(context);
    if (mlir::failed(
            mlir::applyPatternsGreedily(module, std::move(patterns)))) {
      signalPassFailure();
      return;
    }
  }
};

}  // namespace

std::unique_ptr<Pass> CreateTritonXLALowerXTilePass() {
  return std::make_unique<TritonXLALowerXTilePass>();
}

}  // namespace mlir::triton::xla
third_party/xla/xla/service/gpu/tests/BUILD (1 line changed)
@@ -699,6 +699,7 @@ lit_test_suite_for_gpus(
        # "//xla/backends/gpu/codegen/triton/transforms:passes",
        # "//xla/codegen/emitters/ir:xla",
        # "//xla/codegen/emitters/transforms:passes",
        # "//xla/codegen/xtile/ir:xtile",
        # "//xla/stream_executor:device_description",
        # "//xla/stream_executor/cuda:cuda_compute_capability",
        # "@triton//:AllPassesAndDialects",
@@ -34,6 +34,7 @@ limitations under the License.
#include "xla/backends/gpu/codegen/triton/transforms/passes.h"
#include "xla/codegen/emitters/ir/xla_dialect.h"
#include "xla/codegen/emitters/transforms/passes.h"
#include "xla/codegen/xtile/ir/xtile_dialect.h"
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
#include "xla/stream_executor/device_description.h"
#include "third_party/triton/bin/RegisterTritonDialects.h"

@@ -86,9 +87,10 @@ int main(int argc, char** argv) {
  mlir::LLVM::registerInlinerInterface(registry);
  mlir::func::registerInlinerExtension(registry);
  registerTritonDialects(registry);  // This registers all passes as well.
  registry.insert<mlir::func::FuncDialect, mlir::tensor::TensorDialect,
                  mlir::triton::xla::XlaTritonDialect, xla::XlaDialect,
                  mlir::stablehlo::StablehloDialect>();
  registry
      .insert<mlir::func::FuncDialect, mlir::tensor::TensorDialect,
              mlir::triton::xla::XlaTritonDialect, xla::XlaDialect,
              xla::xtile::XTileDialect, mlir::stablehlo::StablehloDialect>();
  mlir::triton::xla::registerTritonXlaTransformsPasses();
  xla::emitters::registerTransformsPasses();
  xla::gpu::registerGpuFusionTransformsPasses();