[XLA][XTile] Add xtile lowering passes for triton.

This enables migrating the triton emitter to emit xtile entry, insert & extract ops in the child PR.

The main difference is the memref args in the entry function, for which `MemrefToPtr` & `PtrToMemref` were introduced. These ops closely resemble `UnrealizedConversionCastOp`, but add extra verification and will enable special folding of `memref::TransposeOp`.
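
A minimal sketch of the new casts (shapes and types here are illustrative); the round trip folds back to the original pointer:

```mlir
// Wrap a raw pointer into a memref at the entry function boundary ...
%memref = triton_xla.ptr_to_memref %ptr from !tt.ptr<f32> to memref<256xf32>
// ... and unwrap it again where a pointer is needed. Folding rewrites
// memref_to_ptr(ptr_to_memref(x)) back to x.
%ptr2 = triton_xla.memref_to_ptr %memref from memref<256xf32> to !tt.ptr<f32>
```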

PiperOrigin-RevId: 821593545
Will Froom 2025-10-20 04:46:20 -07:00 committed by TensorFlower Gardener
parent ea72bd7e48
commit beb48d90e2
13 changed files with 489 additions and 4 deletions

View File

@ -239,6 +239,7 @@ cc_library(
"//xla/codegen/tiling:tiled_hlo_computation",
"//xla/codegen/tiling:tiled_hlo_fusion_instruction",
"//xla/codegen/tiling:tiled_hlo_instruction",
"//xla/codegen/xtile/ir:xtile",
"//xla/hlo/ir:hlo",
"//xla/service:hlo_module_config",
"//xla/service/gpu/model:block_level_parameters",
@ -353,8 +354,9 @@ cc_library(
"//xla:status_macros",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/backends/gpu/codegen/triton/ir:triton_xla",
"//xla/codegen:emitter_loc_op_builder",
"//xla/codegen/tiling:tiled_hlo_computation",
"//xla/codegen/xtile/ir:xtile",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_query",
"//xla/hlo/utils:hlo_traversal",
@ -392,6 +394,7 @@ cc_library(
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MathDialect",
"@llvm-project//mlir:NVVMDialect",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Support",
"@triton//:TritonDialects",
@ -443,6 +446,7 @@ cc_library(
"//xla:autotuning_proto_cc",
"//xla/codegen:emitter_loc_op_builder",
"//xla/codegen/tiling:symbolic_tile_analysis",
"//xla/codegen/xtile/ir:xtile",
"//xla/hlo/ir:hlo",
"//xla/hlo/utils:hlo_traversal",
"//xla/service:hlo_module_config",

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "llvm/Support/MathExtras.h"
#include "llvm/TargetParser/Triple.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/Builders.h"
@ -66,6 +67,7 @@ limitations under the License.
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
namespace xla::gpu::triton {
@ -616,4 +618,20 @@ absl::StatusOr<stream_executor::gpu::TmaMetadata> ExtractTmaMetadata(
return tma_metadata;
}
::mlir::triton::PointerType GetPointerType(mlir::MemRefType memref_type) {
int address_space = 0;
mlir::Attribute memory_space_attr = memref_type.getMemorySpace();
if (auto int_memory_space_attr =
mlir::dyn_cast_if_present<mlir::IntegerAttr>(memory_space_attr)) {
address_space = int_memory_space_attr.getInt();
} else if (auto llvm_memory_space_attr = mlir::dyn_cast_if_present<
mlir::LLVM::LLVMAddrSpaceAttrInterface>(memory_space_attr)) {
address_space = llvm_memory_space_attr.getAddressSpace();
}
return ::mlir::triton::PointerType::get(memref_type.getElementType(),
address_space);
}
} // namespace xla::gpu::triton

View File

@ -50,6 +50,7 @@ limitations under the License.
#include "xla/tsl/platform/status.h"
#include "xla/xla.pb.h"
#include "xla/xla_data.pb.h"
#include "triton/Dialect/Triton/IR/Types.h"
namespace xla::gpu::triton {
@ -230,6 +231,10 @@ absl::StatusOr<stream_executor::gpu::TmaMetadata> ExtractTmaMetadata(
absl::StatusOr<stream_executor::ThreadDim> ExtractThreadDims(
mlir::ModuleOp triton_module, mlir::LLVM::LLVMFuncOp func_op);
// Returns the triton pointer type that corresponds to the given memref type,
// i.e. has the same element type and address space.
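// For example, a memref<1024xf32> in the NVVM global memory space maps to
// !tt.ptr<f32> in the matching address space (as exercised by the xtile
// lowering tests).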
::mlir::triton::PointerType GetPointerType(mlir::MemRefType memref_type);
} // namespace xla::gpu::triton
#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_EMITTER_HELPERS_H_

View File

@ -17,3 +17,12 @@ tt.func @xla_triton_extract_insert(%arg0: !tt.ptr<bf16>, %arg1: index) {
[%c0, %arg1][16, 64][1, 1] {noinline = false} : tensor<16x64xbf16>
tt.return
}
// CHECK-LABEL: @fold_ptr_memref_ptr(
// CHECK-SAME: %[[SRC:.*]]: !tt.ptr<f32>
func.func @fold_ptr_memref_ptr(%src: !tt.ptr<f32>) -> !tt.ptr<f32> {
// CHECK: return %[[SRC]] : !tt.ptr<f32>
%src_ptr = triton_xla.ptr_to_memref %src from !tt.ptr<f32> to memref<256xf32>
%dst = triton_xla.memref_to_ptr %src_ptr from memref<256xf32> to !tt.ptr<f32>
func.return %dst : !tt.ptr<f32>
}

View File

@ -254,6 +254,64 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<InsertOpOffsetsSizesStridesFolder>(context);
}
OpFoldResult MemrefToPtrOp::fold(FoldAdaptor adaptor) {
if (auto ptr_to_memref = getOperand().getDefiningOp<PtrToMemrefOp>()) {
// memref_to_ptr(ptr_to_memref(x)) -> x
return ptr_to_memref.getOperand();
}
return {};
}
LogicalResult MemrefToPtrOp::verify() {
mlir::MemRefType src_type = getSrc().getType();
if (src_type.getElementType() != getType().getPointeeType()) {
getOperation()->emitError(
"source element type does not match result pointee type");
return failure();
}
// It is only safe to directly convert a memref to a pointer if the memref
// has no offset.
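// For example, memref<256xf32, strided<[1], offset: 4>> would be rejected by
// the check below (illustrative type).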
llvm::SmallVector<int64_t> strides;
int64_t offset = 0;
if (src_type.getStridesAndOffset(strides, offset).failed()) {
getOperation()->emitError("failed to get strides and offset") << src_type;
return failure();
}
if (offset != 0) {
getOperation()->emitError("memref has non-zero offset");
return failure();
}
return success();
}
LogicalResult PtrToMemrefOp::verify() {
mlir::MemRefType result_type = getType();
if (getSrc().getType().getPointeeType() != result_type.getElementType()) {
getOperation()->emitError(
"source pointee type does not match result element type");
return failure();
}
// It is only safe to directly convert a pointer to a memref if the memref
// has no offset.
llvm::SmallVector<int64_t> strides;
int64_t offset = 0;
if (result_type.getStridesAndOffset(strides, offset).failed()) {
getOperation()->emitError("failed to get strides and offset")
<< result_type;
return failure();
}
if (offset != 0) {
getOperation()->emitError("memref has non-zero offset");
return failure();
}
return success();
}
} // namespace mlir::triton::xla
#define GET_OP_CLASSES

View File

@ -421,6 +421,41 @@ def TTXLA_GetPeerPtrOp : TTXLA_Op<"get_peer_ptr", [Pure]> {
}];
}
def TTXLA_MemrefToPtrOp : TTXLA_Op<"memref_to_ptr", [Pure]> {
let summary = [{
A specialized version of unrealized_conversion_cast that converts a
memref to a pointer.
}];
let arguments = (ins AnyMemRef:$src);
let results = (outs TT_Ptr:$result);
let assemblyFormat = [{
$src `from` type($src) `to` type($result) attr-dict
}];
let hasFolder = 1;
let hasVerifier = 1;
}
def TTXLA_PtrToMemrefOp : TTXLA_Op<"ptr_to_memref", [Pure]> {
let summary = [{
A specialized version of unrealized_conversion_cast that converts a
pointer to a memref.
}];
let arguments = (ins TT_Ptr:$src);
let results = (outs AnyMemRef:$result);
let assemblyFormat = [{
$src `from` type($src) `to` type($result) attr-dict
}];
let hasVerifier = 1;
}
#endif // XLA_BACKENDS_GPU_CODEGEN_TRITON_IR_TRITON_XLA_OPS_TD_

View File

@ -44,6 +44,7 @@ cc_library(
"triton_xla_lower_block_barrier_pass.cc",
"triton_xla_lower_get_tid_pass.cc",
"triton_xla_lower_remote_access_pass.cc",
"triton_xla_lower_xtile_pass.cc",
"triton_xla_squeeze_dims_pass.cc",
"triton_xla_unswitch_loops_pass.cc",
],
@ -56,6 +57,8 @@ cc_library(
"//xla/backends/gpu/codegen/triton/ir:triton_xla",
"//xla/codegen:emitter_loc_op_builder",
"//xla/codegen/emitters/ir:xla",
"//xla/codegen/xtile/ir:xtile",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/llvm_ir:llvm_util",
"//xla/stream_executor/gpu:collective_kernel_metadata",
"//xla/stream_executor/gpu:tma_metadata",
@ -63,6 +66,7 @@ cc_library(
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/strings:string_view",
@ -73,7 +77,10 @@ cc_library(
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InliningUtils",
"@llvm-project//mlir:LLVMCommonConversion",
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:NVVMDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Rewrite",

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h" // IWYU pragma: keep
#include "xla/codegen/xtile/ir/xtile_dialect.h" // IWYU pragma: keep
namespace mlir::triton::xla {
@ -45,6 +46,7 @@ std::unique_ptr<mlir::Pass> CreateTritonXLALowerAtomicsPass();
std::unique_ptr<mlir::Pass> CreateTritonXLALowerBlockBarrierPass();
std::unique_ptr<mlir::Pass> CreateTritonXLAConvertUnsupportedTypesPass();
std::unique_ptr<mlir::Pass> CreateTritonXLALowerRemoteAccessPass();
std::unique_ptr<mlir::Pass> CreateTritonXLALowerXTilePass();
std::unique_ptr<mlir::Pass> CreateStableHLOLowerToTritonPass();
std::unique_ptr<mlir::Pass> CreateTensorLowerToTritonPass();

View File

@ -190,6 +190,20 @@ def TritonXLAUnswitchLoopsPass :
let constructor = "CreateTritonXLAUnswitchLoopsPass()";
}
def TritonXLALowerXTilePass :
Pass<"triton-xla-lower-xtile", "mlir::ModuleOp"> {
let summary = "Lowers xtile ops to Triton ops.";
let dependentDialects = [
"mlir::func::FuncDialect",
"mlir::memref::MemRefDialect",
"mlir::triton::xla::XlaTritonDialect",
"triton::TritonDialect",
];
let constructor = "CreateTritonXLALowerXTilePass()";
}
def StableHLOLowerToTritonPass
: Pass<"stablehlo-lower-to-triton", "mlir::ModuleOp"> {
let summary = "Lowers StableHLO operations to their Triton equivalent.";

View File

@ -0,0 +1,39 @@
// RUN: xla-opt %s -split-input-file -triton-xla-lower-xtile | FileCheck %s
xtile.entry_func @extract_insert_no_layout(%input: memref<1024xf32, #nvvm.memory_space<global>>,
%output: memref<32xf32, #nvvm.memory_space<global>>,
%tile_id: index) {
%tile = xtile.extract %input[%tile_id][1][1] : memref<1024xf32, #nvvm.memory_space<global>> -> tensor<1xf32>
xtile.insert %tile into %output[%tile_id][1][1] : tensor<1xf32> -> memref<32xf32, #nvvm.memory_space<global>>
xtile.return
}
// CHECK: func.func @extract_insert_no_layout(%[[ARG0:.*]]: !tt.ptr<f32>, %[[ARG1:.*]]: !tt.ptr<f32>) {
// CHECK: %[[PID:.*]] = tt.get_program_id x : i32
// CHECK: %[[PID_I64:.*]] = arith.extsi %[[PID]] : i32 to i64
// CHECK: %[[PID_IDX:.*]] = arith.index_cast %[[PID_I64]] : i64 to index
// CHECK: %[[TILE:.*]] = triton_xla.extract from %[[ARG0]] as memref<1024xf32, #triton_xla.layout<[0]>> [%[[PID_IDX]]] [1] [1] : tensor<1xf32>
// CHECK: triton_xla.insert %[[TILE]] into %[[ARG1]] as memref<32xf32, #triton_xla.layout<[0]>> [%[[PID_IDX]]] [1] [1] : tensor<1xf32>
// CHECK: return
// CHECK: }
// -----
!arg_type = memref<1024x32x1x1xbf16, #triton_xla.layout<[2, 3, 0, 1]>, #nvvm.memory_space<global>>
xtile.entry_func @layout_preserved(%input: !arg_type,
%tile_id: index) {
%c_0 = arith.constant 0 : index
%tile = xtile.extract %input[%tile_id, %c_0, %c_0, %c_0][1, 1, 1, 1][1, 1, 1, 1] : !arg_type -> tensor<1x1x1x1xbf16>
xtile.return
}
// CHECK: func.func @layout_preserved(%[[ARG0:.*]]: !tt.ptr<bf16>) {
// CHECK: %[[PID:.*]] = tt.get_program_id x : i32
// CHECK: %[[PID_I64:.*]] = arith.extsi %[[PID]] : i32 to i64
// CHECK: %[[PID_IDX:.*]] = arith.index_cast %[[PID_I64]] : i64 to index
// CHECK: %[[TILE:.*]] = triton_xla.extract from %[[ARG0]]
// CHECK-SAME: as memref<1024x32x1x1xbf16, #triton_xla.layout<[3, 2, 0, 1]>>
// CHECK-SAME: [%[[PID_IDX]], 0, 0, 0]
// CHECK-SAME: [1, 1, 1, 1] [1, 1, 1, 1] : tensor<1x1x1x1xbf16>
// CHECK: return
// CHECK: }

View File

@ -0,0 +1,291 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <memory>
#include <utility>
#include "absl/algorithm/container.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Inliner.h"
#include "mlir/Transforms/InliningUtils.h"
#include "xla/backends/gpu/codegen/triton/emitter_helpers.h"
#include "xla/backends/gpu/codegen/triton/ir/triton_xla_ops.h"
#include "xla/codegen/xtile/ir/xtile_ops.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
namespace mlir::triton::xla {
#define GEN_PASS_DEF_TRITONXLALOWERXTILEPASS
#include "xla/backends/gpu/codegen/triton/transforms/passes.h.inc"
namespace {
namespace ttir = ::mlir::triton;
namespace ma = ::mlir::arith;
// Get the new arg types of the lowered function by translating memrefs to the
// corresponding pointer types.
llvm::SmallVector<mlir::Type> GetPtrArgTypes(mlir::ValueRange args) {
llvm::SmallVector<mlir::Type> arg_types;
arg_types.reserve(args.size());
for (auto arg : args) {
mlir::MemRefType memref_type = mlir::cast<mlir::MemRefType>(arg.getType());
arg_types.push_back(::xla::gpu::triton::GetPointerType(memref_type));
}
return arg_types;
}
// Returns the permutation vector (minor to major) for a MemRefType.
// The permutation is derived from getStridesAndOffset rather than directly
// from the triton_xla.layout attribute because, when memrefs are folded (such
// as in a transpose), they carry a generic strided layout that does not
// directly encode the permutation.
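// For example (illustrative): a row-major memref<16x64xf32> has strides
// [64, 1], so its minor-to-major permutation is [1, 0].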
absl::StatusOr<llvm::SmallVector<int64_t>> getPermutationMinorToMajor(
mlir::MemRefType memref) {
llvm::SmallVector<int64_t> strides;
int64_t offset;
if (memref.getStridesAndOffset(strides, offset).failed()) {
// This can fail if the layout is not strided (e.g., has dynamic strides).
return absl::InvalidArgumentError("Failed to get strides and offset");
}
llvm::SmallVector<int64_t> permutation;
permutation.resize(strides.size());
absl::c_iota(permutation, 0);
absl::c_sort(permutation, [&](int64_t lhs_dim, int64_t rhs_dim) {
int64_t lhs_stride = strides[lhs_dim];
int64_t rhs_stride = strides[rhs_dim];
if (lhs_stride != rhs_stride) {
return lhs_stride < rhs_stride;
}
// If the strides are equal, ensure that the unit dimension ends up as the
// more minor one.
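// (e.g. in a row-major memref<1024x32x1x1xbf16> the three trailing dims all
// have stride 1).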
int64_t lhs_size = memref.getDimSize(lhs_dim);
int64_t rhs_size = memref.getDimSize(rhs_dim);
if (lhs_size != rhs_size) {
return lhs_size < rhs_size;
}
// If all else fails just sort in the canonical order.
return lhs_dim > rhs_dim;
});
// Check that the strides actually represent a permutation; this check can
// fail, for example, for padded buffers.
int64_t size_product = 1;
for (int64_t dim : permutation) {
if (strides[dim] != size_product) {
return absl::InvalidArgumentError("Layout is not a valid permutation");
}
size_product *= memref.getDimSize(dim);
}
return permutation;
}
MemrefToPtrOp CreateMemrefToPtr(mlir::OpBuilder& builder,
mlir::TypedValue<mlir::MemRefType> memref) {
mlir::MemRefType memref_type = memref.getType();
return builder.create<MemrefToPtrOp>(
memref.getLoc(), ::xla::gpu::triton::GetPointerType(memref_type), memref);
}
// Rewrite an xtile entry function to a func.func with the same body, but with
// memref arguments replaced by pointers.
class XTileEntryToTriton
: public mlir::OpRewritePattern<::xla::xtile::EntryFuncOp> {
public:
XTileEntryToTriton(mlir::MLIRContext* context, mlir::ModuleOp& module)
: OpRewritePattern(context), module_(module) {}
mlir::LogicalResult matchAndRewrite(
::xla::xtile::EntryFuncOp entry_op,
mlir::PatternRewriter& rewriter) const override {
mlir::ImplicitLocOpBuilder builder(module_->getLoc(), module_);
builder.setInsertionPointToStart(module_.getBody());
auto new_arg_types = GetPtrArgTypes(entry_op.getBufferArgs());
auto new_func_op = builder.create<mlir::func::FuncOp>(
entry_op.getName(), builder.getFunctionType(new_arg_types, {}));
// Move the old function's body to the new function
rewriter.inlineRegionBefore(
entry_op.getBody(), new_func_op.getFunctionBody(), new_func_op.end());
Block& entry_block = new_func_op.front();
builder.setInsertionPointToStart(&entry_block);
SmallVector<BlockArgument> old_args(entry_block.getArguments());
SmallVector<BlockArgument> new_args(entry_block.addArguments(
new_arg_types,
SmallVector<Location>(new_arg_types.size(), entry_op.getLoc())));
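// The tile id is the trailing index argument of the entry function.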
BlockArgument tile_id_arg = old_args.back();
// TODO(b/389955087): we can decide whether to sign extend by
// understanding if we need 64 bits to encode indices or if 32 bits are
// enough. For now, just use 64 bits to avoid issues.
auto pid = builder.create<ttir::GetProgramIdOp>(ttir::ProgramIDDim::X);
Value pid_i64 = builder.create<ma::ExtSIOp>(builder.getI64Type(), pid);
Value pid_idx =
builder.create<ma::IndexCastOp>(builder.getIndexType(), pid_i64);
rewriter.replaceAllUsesWith(tile_id_arg, pid_idx);
// Handle memref arguments.
for (auto [old_arg, new_arg] : llvm::zip(old_args, new_args)) {
mlir::MemRefType memref_type =
mlir::cast<mlir::MemRefType>(old_arg.getType());
mlir::Value memref_cast =
builder.create<PtrToMemrefOp>(memref_type, new_arg);
// Replace all uses of the old argument with the result of the cast.
rewriter.replaceAllUsesWith(old_arg, memref_cast);
}
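// Drop the original memref and tile id block arguments; they are now unused.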
entry_block.eraseArguments(0, old_args.size());
rewriter.setInsertionPointToEnd(&entry_block);
rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
entry_block.getTerminator());
rewriter.eraseOp(entry_op);
return success();
}
private:
mlir::ModuleOp& module_;
};
// Rewrite an xtile extract to a triton_xla extract.
class XTileExtractToTriton
: public mlir::OpRewritePattern<::xla::xtile::ExtractTileOp> {
public:
using OpRewritePattern::OpRewritePattern;
mlir::LogicalResult matchAndRewrite(
::xla::xtile::ExtractTileOp extract_op,
mlir::PatternRewriter& rewriter) const override {
mlir::MemRefType source_type = extract_op.getSource().getType();
mlir::RankedTensorType result_type = extract_op.getType();
mlir::Value memref_to_ptr =
CreateMemrefToPtr(rewriter, extract_op.getSource());
absl::StatusOr<SmallVector<int64_t>> minor_to_major_or =
getPermutationMinorToMajor(source_type);
if (!minor_to_major_or.ok()) {
return rewriter.notifyMatchFailure(extract_op,
minor_to_major_or.status().ToString());
}
const SmallVector<int64_t>& minor_to_major = *minor_to_major_or;
auto triton_extract_op = rewriter.create<ExtractOp>(
extract_op.getLoc(), result_type, memref_to_ptr,
extract_op.getOffsets(), extract_op.getFullTileShape(),
extract_op.getStrides(), source_type.getShape(), minor_to_major);
rewriter.replaceOp(extract_op, triton_extract_op);
return mlir::success();
}
};
// Rewrite an xtile insert to a triton_xla insert.
class XTileInsertToTriton
: public mlir::OpRewritePattern<::xla::xtile::InsertTileOp> {
public:
using OpRewritePattern::OpRewritePattern;
mlir::LogicalResult matchAndRewrite(
::xla::xtile::InsertTileOp insert_op,
mlir::PatternRewriter& rewriter) const override {
mlir::MemRefType destination_type = insert_op.getDestination().getType();
mlir::Value memref_to_ptr =
CreateMemrefToPtr(rewriter, insert_op.getDestination());
absl::StatusOr<SmallVector<int64_t>> minor_to_major_or =
getPermutationMinorToMajor(destination_type);
if (!minor_to_major_or.ok()) {
return rewriter.notifyMatchFailure(insert_op,
minor_to_major_or.status().ToString());
}
const SmallVector<int64_t>& minor_to_major = *minor_to_major_or;
auto triton_insert_op = rewriter.create<InsertOp>(
insert_op.getLoc(), insert_op.getSource(), memref_to_ptr,
insert_op.getOffsets(), insert_op.getFullTileShape(),
insert_op.getStrides(), destination_type.getShape(), minor_to_major);
rewriter.replaceOp(insert_op, triton_insert_op);
return mlir::success();
}
};
class TritonXLALowerXTilePass
: public impl::TritonXLALowerXTilePassBase<TritonXLALowerXTilePass> {
public:
using TritonXLALowerXTilePassBase::TritonXLALowerXTilePassBase;
void runOnOperation() override {
mlir::ModuleOp module = getOperation();
mlir::MLIRContext* context = &getContext();
mlir::RewritePatternSet patterns(context);
patterns.add<XTileEntryToTriton>(context, module);
patterns.add<XTileExtractToTriton, XTileInsertToTriton>(context);
if (mlir::failed(
mlir::applyPatternsGreedily(module, std::move(patterns)))) {
signalPassFailure();
return;
}
}
};
} // namespace
std::unique_ptr<Pass> CreateTritonXLALowerXTilePass() {
return std::make_unique<TritonXLALowerXTilePass>();
}
} // namespace mlir::triton::xla

View File

@ -699,6 +699,7 @@ lit_test_suite_for_gpus(
# "//xla/backends/gpu/codegen/triton/transforms:passes",
# "//xla/codegen/emitters/ir:xla",
# "//xla/codegen/emitters/transforms:passes",
# "//xla/codegen/xtile/ir:xtile",
# "//xla/stream_executor:device_description",
# "//xla/stream_executor/cuda:cuda_compute_capability",
# "@triton//:AllPassesAndDialects",

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "xla/backends/gpu/codegen/triton/transforms/passes.h"
#include "xla/codegen/emitters/ir/xla_dialect.h"
#include "xla/codegen/emitters/transforms/passes.h"
#include "xla/codegen/xtile/ir/xtile_dialect.h"
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
#include "xla/stream_executor/device_description.h"
#include "third_party/triton/bin/RegisterTritonDialects.h"
@ -86,9 +87,10 @@ int main(int argc, char** argv) {
mlir::LLVM::registerInlinerInterface(registry);
mlir::func::registerInlinerExtension(registry);
registerTritonDialects(registry); // This registers all passes as well.
registry.insert<mlir::func::FuncDialect, mlir::tensor::TensorDialect,
mlir::triton::xla::XlaTritonDialect, xla::XlaDialect,
mlir::stablehlo::StablehloDialect>();
registry
.insert<mlir::func::FuncDialect, mlir::tensor::TensorDialect,
mlir::triton::xla::XlaTritonDialect, xla::XlaDialect,
xla::xtile::XTileDialect, mlir::stablehlo::StablehloDialect>();
mlir::triton::xla::registerTritonXlaTransformsPasses();
xla::emitters::registerTransformsPasses();
xla::gpu::registerGpuFusionTransformsPasses();