From f4ebf9d47dec4b8552ae62f677a23a92b16e9886 Mon Sep 17 00:00:00 2001 From: Karlo Basioli Date: Thu, 30 Oct 2025 13:17:07 -0700 Subject: [PATCH] [XLA][codegen] Migrate triton operations for which shared dialect lowerings are implemented. These were missed in previous commits. Addresses transpose and bitcast. PiperOrigin-RevId: 826158776 --- third_party/xla/xla/backends/gpu/codegen/triton/BUILD | 3 ++- .../xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc | 5 +++-- .../xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc | 3 ++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD index 58ba353dacb..af24f06275d 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/BUILD +++ b/third_party/xla/xla/backends/gpu/codegen/triton/BUILD @@ -137,8 +137,8 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:MathDialect", - "@llvm-project//mlir:NVVMDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@triton//:TritonDialects", ], ) @@ -428,6 +428,7 @@ cc_library( "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", + "@stablehlo//:stablehlo_ops", "@triton//:TritonDialects", ], ) diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc b/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc index ba288726708..de8e540b364 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/dot_algorithms.cc @@ -36,6 +36,7 @@ limitations under the License. 
#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/backends/gpu/codegen/triton/emitter_helpers.h" #include "xla/codegen/emitter_loc_op_builder.h" #include "xla/hlo/ir/hlo_instruction.h" @@ -160,8 +161,8 @@ absl::StatusOr ScaledDot(EmitterLocOpBuilder b, Value rhs_scale; if (rhs_dot_elem_type != ttir::ScaleDotElemType::BF16) { rhs_scale = Bitcast(b, operands.rhs_scale, b.getI8Type()); - rhs_scale = - b.create(rhs_scale, mlir::ArrayRef{1, 0}); + rhs_scale = b.create( + rhs_scale, b.getDenseI64ArrayAttr({1, 0})); } // make type with the same shape as the scale but with i8 type diff --git a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc index c4bff9fac9d..5bb29f00908 100644 --- a/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc +++ b/third_party/xla/xla/backends/gpu/codegen/triton/emitter_helpers.cc @@ -35,6 +35,7 @@ limitations under the License. #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" @@ -501,7 +502,7 @@ absl::StatusOr EmitElementwise(EmitterLocOpBuilder& b, mh::ComparisonDirection::NE), inputs[1], inputs[2]); case HloOpcode::kReducePrecision: - return mh::reducePrecision( + return mh::reducePrecision( b.getLoc(), inputs[0], hlo.exponent_bits(), hlo.mantissa_bits(), &b); default: return absl::InvalidArgumentError(