From 1faa5856a0711a980cdc777e75b4cfe9df00d86c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 7 Nov 2023 04:38:22 -0800 Subject: [PATCH] Integrate LLVM at llvm/llvm-project@c39995a116a7 Updates LLVM usage to match [c39995a116a7](https://github.com/llvm/llvm-project/commit/c39995a116a7) PiperOrigin-RevId: 580134851 --- .../mlir/tosa/tests/tfl-to-tosa-pipeline.mlir | 2 +- third_party/llvm/generated.patch | 12 ++++++++++++ third_party/llvm/workspace.bzl | 4 ++-- .../transforms/vectorization/vectorize_for_cpu.cc | 5 ++++- third_party/xla/xla/mlir_hlo/lhlo/IR/lhlo_ops.td | 10 ++++++---- .../xla/xla/mlir_hlo/tests/Dialect/lhlo/ops.mlir | 3 ++- .../tests/collapse_parallel_loops_to_1d_pass.mlir | 3 ++- third_party/xla/xla/mlir_hlo/tests/tile_loops.mlir | 6 ++++-- third_party/xla/xla/mlir_hlo/tests/unbufferize.mlir | 8 +++++--- third_party/xla/xla/mlir_hlo/transforms/passes.td | 5 +++-- .../xla/xla/mlir_hlo/transforms/unbufferize_pass.cc | 6 +++--- third_party/xla/xla/service/gpu/BUILD | 1 + third_party/xla/xla/service/gpu/fusions/fusions.cc | 4 ++-- .../xla/xla/service/gpu/ir_emission_utils.cc | 6 +++--- .../xla/xla/service/gpu/ir_emitter_unnested.cc | 13 +++++++------ .../mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc | 5 ++++- .../tests/non_identity_layouts.hlotxt | 3 ++- 17 files changed, 63 insertions(+), 33 deletions(-) diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index 84e61d0de2f..847b3ccfda8 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -2662,7 +2662,7 @@ func.func @test_reverse_fail(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> // CHECK-LABEL: test_tfl_custom // CHECK-SAME: %[[ARG_0:.*]]: tensor<1x64x64x32xf32> -// CHECK: %[[VAL_0:.*]] = tosa.custom %[[ARG_0]] {config = "TFL", identifier = "MaxPoolingWithArgmax2D", implementation_attrs = "{{.*}}"} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) +// CHECK: %[[VAL_0:.*]] = tosa.custom %[[ARG_0]] {domain_name = "TFL", implementation_attrs = "{{.*}}", operator_name = "MaxPoolingWithArgmax2D"} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) func.func @test_tfl_custom(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) { // custom op for "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) %0, %1 = "tfl.custom"(%arg0) {custom_option = #tfl, custom_code = "MaxPoolingWithArgmax2D"} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index 509398da979..07eee148761 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1 +1,13 @@ Auto generated patch. Do not edit or delete it, even if empty. +diff -ruN --strip-trailing-cr a/llvm/lib/MC/ELFObjectWriter.cpp b/llvm/lib/MC/ELFObjectWriter.cpp +--- a/llvm/lib/MC/ELFObjectWriter.cpp ++++ b/llvm/lib/MC/ELFObjectWriter.cpp +@@ -843,7 +843,7 @@ + uint32_t ChType, uint64_t Size, + SmallVectorImpl &CompressedContents, Align Alignment) { + uint64_t HdrSize = +- is64Bit() ? sizeof(ELF::Elf64_Chdr) : sizeof(ELF::Elf32_Chdr); ++ is64Bit() ? sizeof(ELF::Elf32_Chdr) : sizeof(ELF::Elf64_Chdr); + if (Size <= HdrSize + CompressedContents.size()) + return false; + // Platform specific header is followed by compressed data. diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index fbf1206ed09..05aa52a6041 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "a682a9cfd006c52559387f80398b720d529595d1" - LLVM_SHA256 = "27d7ccf7c59a91af5ff8d74ee9d9086d5aa7bf5c0cfffdab6dcad5278923175a" + LLVM_COMMIT = "c39995a116a74ebafc63648e8f047d13012c4f87" + LLVM_SHA256 = "cde7016c25257c0789ff5faf226ca3d829eeaa2ab5b22c4388ea35b2b6ee9af4" tf_http_archive( name = name, diff --git a/third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_cpu.cc b/third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_cpu.cc index 5c7496c4193..2089351d6b0 100644 --- a/third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_cpu.cc +++ b/third_party/xla/xla/mlir_hlo/gml_st/transforms/vectorization/vectorize_for_cpu.cc @@ -35,6 +35,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" #include "thlo/IR/thlo_ops.h" namespace mlir { @@ -401,7 +402,9 @@ struct VectorizeForCPUPass } // Hoisting transfer_read/transfer_write. - linalg::hoistRedundantVectorTransfersOnTensor(func); + IRRewriter rewriter(func->getContext()); + func.walk( + [&](scf::ForOp forOp) { hoistLoopInvariantSubsets(rewriter, forOp); }); } }; diff --git a/third_party/xla/xla/mlir_hlo/lhlo/IR/lhlo_ops.td b/third_party/xla/xla/mlir_hlo/lhlo/IR/lhlo_ops.td index f014a4ed3aa..e406b700657 100644 --- a/third_party/xla/xla/mlir_hlo/lhlo/IR/lhlo_ops.td +++ b/third_party/xla/xla/mlir_hlo/lhlo/IR/lhlo_ops.td @@ -1406,8 +1406,9 @@ def FusionOp : LHLO_Op<"fusion", [ SmallVector getOutputBuffers() { SmallVector buffers; - for (auto store : getRegion().front().getOps()) { - buffers.push_back(store.getMemref()); + for (auto store : getRegion().front() + .getOps()) { + buffers.push_back(store.getDest()); } return buffers; } @@ -1422,8 +1423,9 @@ def FusionOp : LHLO_Op<"fusion", [ SmallVector getFusionResults() { SmallVector buffers; - for (auto store : getRegion().front().getOps()) { - buffers.push_back(store.getTensor()); + for (auto store : getRegion().front() + .getOps()) { + buffers.push_back(store.getSource()); } return buffers; } diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/lhlo/ops.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/lhlo/ops.mlir index 8be889f81a3..78b64069531 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/lhlo/ops.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/lhlo/ops.mlir @@ -398,7 +398,8 @@ func.func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %inpu %2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> %3 = bufferization.to_tensor %input3 : memref<10xf32> %4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> - memref.tensor_store %4, %out : memref<10xf32> + bufferization.materialize_in_destination %4 in writable %out + : (tensor<10xf32>, memref<10xf32>) -> () "lmhlo.terminator"() : () -> () } ) : () -> () func.return diff --git a/third_party/xla/xla/mlir_hlo/tests/collapse_parallel_loops_to_1d_pass.mlir b/third_party/xla/xla/mlir_hlo/tests/collapse_parallel_loops_to_1d_pass.mlir index 855f7c2e448..a49b1675cb9 100644 --- a/third_party/xla/xla/mlir_hlo/tests/collapse_parallel_loops_to_1d_pass.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/collapse_parallel_loops_to_1d_pass.mlir @@ -15,6 +15,7 @@ func.func @parallel_2d(%arg0: memref<4x4xf32>, %arg1: memref<4x4xf32>) { scf.yield } %1 = bufferization.to_tensor %0 : memref<4x4xf32> - memref.tensor_store %1, %arg1 : memref<4x4xf32> + bufferization.materialize_in_destination %1 in writable %arg1 + : (tensor<4x4xf32>, memref<4x4xf32>) -> () "lmhlo.terminator"() : () -> () } \ No newline at end of file diff --git a/third_party/xla/xla/mlir_hlo/tests/tile_loops.mlir b/third_party/xla/xla/mlir_hlo/tests/tile_loops.mlir index 0f4e2134073..8cfce78ba48 100644 --- a/third_party/xla/xla/mlir_hlo/tests/tile_loops.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/tile_loops.mlir @@ -19,7 +19,8 @@ func.func @parallel_loop(%arg0: memref<16xf32>, %arg1: memref<16xf32>) { scf.yield } %1 = bufferization.to_tensor %0 : memref<16xf32> - memref.tensor_store %1, %arg1 : memref<16xf32> + bufferization.materialize_in_destination %1 in writable %arg1 + : (tensor<16xf32>, memref<16xf32>) -> () "lmhlo.terminator"() : () -> () } @@ -101,6 +102,7 @@ func.func @complex_access(%arg0: memref<16xf32>, %arg1: memref<4xf32>) { scf.yield } %1 = bufferization.to_tensor %0 : memref<4xf32> - memref.tensor_store %1, %arg1 : memref<4xf32> + bufferization.materialize_in_destination %1 in writable %arg1 + : (tensor<4xf32>, memref<4xf32>) -> () "lmhlo.terminator"() : () -> () } diff --git a/third_party/xla/xla/mlir_hlo/tests/unbufferize.mlir b/third_party/xla/xla/mlir_hlo/tests/unbufferize.mlir index e1bfb4cc9e3..706779585d5 100644 --- a/third_party/xla/xla/mlir_hlo/tests/unbufferize.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/unbufferize.mlir @@ -4,7 +4,8 @@ // CHECK-SAME: (%arg0: tensor<8xf32>) -> (tensor<8xf32> {my.attr}) func.func @unbufferize(%arg0: memref<8xf32>, %arg1: memref<8xf32> {my.attr}) { %0 = bufferization.to_tensor %arg0 : memref<8xf32> - memref.tensor_store %0, %arg1 : memref<8xf32> + bufferization.materialize_in_destination %0 in writable %arg1 + : (tensor<8xf32>, memref<8xf32>) -> () // CHECK-NEXT: return %arg0 : tensor<8xf32> return } @@ -14,7 +15,8 @@ func.func @not_block_arg() { %0 = memref.alloc() : memref<8xf32> // CHECK: bufferization.to_tensor %1 = bufferization.to_tensor %0 : memref<8xf32> - // CHECK: memref.tensor_store - memref.tensor_store %1, %0 : memref<8xf32> + // CHECK: bufferization.materialize_in_destination + bufferization.materialize_in_destination %1 in writable %0 + : (tensor<8xf32>, memref<8xf32>) -> () return } diff --git a/third_party/xla/xla/mlir_hlo/transforms/passes.td b/third_party/xla/xla/mlir_hlo/transforms/passes.td index b532491a737..a75696b9de2 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/passes.td +++ b/third_party/xla/xla/mlir_hlo/transforms/passes.td @@ -152,8 +152,9 @@ def GenericHostToLLVMPass : Pass<"generic-host-to-llvm", "ModuleOp"> { def UnbufferizePass : Pass<"unbufferize", "mlir::func::FuncOp"> { let summary = "Unbufferize partially bufferized functions."; let description = [{ - Removes bufferization.to_tensor and memref.tensor_store ops that are the - result of XLA bufferizing during HLO to MHLO transformation. + Removes bufferization.to_tensor and bufferization.materialize_in_destination + ops that are the result of XLA bufferizing during HLO to MHLO + transformation. }]; let constructor = "hlo::createUnbufferizePass()"; } diff --git a/third_party/xla/xla/mlir_hlo/transforms/unbufferize_pass.cc b/third_party/xla/xla/mlir_hlo/transforms/unbufferize_pass.cc index 9a1531df005..e9570004f86 100644 --- a/third_party/xla/xla/mlir_hlo/transforms/unbufferize_pass.cc +++ b/third_party/xla/xla/mlir_hlo/transforms/unbufferize_pass.cc @@ -69,11 +69,11 @@ void UnbufferizePass::runOnOperation() { }); SmallVector results; SmallVector resultAttrs; - funcOp->walk([&](memref::TensorStoreOp op) { - auto arg = op.getMemref().dyn_cast(); + funcOp->walk([&](bufferization::MaterializeInDestinationOp op) { + auto arg = op.getDest().dyn_cast(); if (!arg) return; argsToErase.set(arg.getArgNumber()); - results.push_back(op.getTensor()); + results.push_back(op.getSource()); resultAttrs.push_back(funcOp.getArgAttrDict(arg.getArgNumber())); rewriter.eraseOp(op); }); diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 3c888208374..89908163e99 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -350,6 +350,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:BufferizationDialect", "@llvm-project//mlir:BuiltinToLLVMIRTranslation", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncExtensions", diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.cc b/third_party/xla/xla/service/gpu/fusions/fusions.cc index f4b9d4053db..a24a5d55c28 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusions.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusions.cc @@ -40,8 +40,8 @@ bool IsSingleInstructionFusion(mlir::lmhlo::FusionOp fusion) { bool seen_instruction = false; for (mlir::Operation& instr : fusion.getRegion().front()) { if (mlir::isa( - &instr)) { + mlir::bufferization::ToTensorOp, + mlir::bufferization::MaterializeInDestinationOp>(&instr)) { continue; } if (seen_instruction) return false; diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.cc b/third_party/xla/xla/service/gpu/ir_emission_utils.cc index a28628c145d..5977ec443a5 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.cc @@ -538,7 +538,7 @@ bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu( } dus_user = *bitcast->user_begin(); } - if (!mlir::isa(dus_user)) { + if (!mlir::isa(dus_user)) { return false; } auto operand = dus.getOperand(); @@ -564,8 +564,8 @@ bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu( q.push(parameter); visited.insert(parameter); // We have already checked above that the DUS only has one user: a - // (possibly bitcasted) TensorStoreOp. So we don't need to visit it during - // the breadth-first search. + // (possibly bitcasted) MaterializeInDestinationOp. So we don't need to + // visit it during the breadth-first search. visited.insert(dus); while (!q.empty()) { auto op = q.front(); diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 5052be1826e..f0ad29ca80a 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -1865,7 +1865,7 @@ Status IrEmitterUnnested::EmitTriangularSolveCustomCall(mlir::Operation* op) { // %0 = tensor_load %external_memref0 // %1 = tensor_load %external_memref1 // ... -// tensor_store %ret, %external_memref2 +// materialize_in_destination %ret, %external_memref2 // } // to // fusion(%external_memref0, %external_memref1) (^bb(%0, %1) { @@ -1880,7 +1880,7 @@ static Status ProcessFusionForConversion(mlir::Region* region, std::vector* operand_shapes, std::vector* output_shapes) { std::vector loads; - std::vector stores; + std::vector stores; region->walk([&](mlir::bufferization::ToTensorOp load) { if (load.getMemref().getParentRegion() != region) { @@ -1888,8 +1888,9 @@ static Status ProcessFusionForConversion(mlir::Region* region, } }); - region->walk([&](mlir::memref::TensorStoreOp store) { - if (store.getMemref().getParentRegion() != region) { + region->walk([&](mlir::bufferization::MaterializeInDestinationOp store) { + if (!isa(store.getDest().getType())) return; + if (store.getDest().getParentRegion() != region) { stores.push_back(store); } }); @@ -1904,10 +1905,10 @@ static Status ProcessFusionForConversion(mlir::Region* region, std::vector returned_values; for (auto store : stores) { - Shape shape = GetShape(store.getMemref()); + Shape shape = GetShape(store.getDest()); output_shapes->push_back(shape); - returned_values.push_back(store.getTensor()); + returned_values.push_back(store.getSource()); store.erase(); } diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc index a0c15693741..ac7c223b874 100644 --- a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc +++ b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/mhlo_to_lhlo_with_xla.cc @@ -406,7 +406,10 @@ tsl::StatusOr LhloDialectEmitter::EmitFusionOp( llvm::SmallVector output; TF_RETURN_IF_ERROR(GetOrCreateView(instr, &output)); TF_RETURN_IF_ERROR(WalkTuplePostOrder(result, [&](Value v) mutable { - region_builder.create(loc, v, output[i++]); + auto materialize_op = + region_builder.create( + loc, v, output[i++]); + materialize_op.setWritable(true); return ::tsl::OkStatus(); })); if (i != output.size()) { diff --git a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/non_identity_layouts.hlotxt b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/non_identity_layouts.hlotxt index 7967070b27f..d74ec4e3434 100644 --- a/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/non_identity_layouts.hlotxt +++ b/third_party/xla/xla/translate/mhlo_to_lhlo_with_xla/tests/non_identity_layouts.hlotxt @@ -20,7 +20,8 @@ ENTRY TestComputation { // CHECK-SAME: result_layout = dense<[0, 1]> // CHECK-SAME: xla_shape = "f32[3,2]{0,1}" // CHECK-SAME: } : tensor<3x2xf32> - // CHECK: memref.tensor_store %[[VAL3:.*]], %{{.*}} : memref<3x2xf32, #[[MAP]]> + // CHECK: bufferization.materialize_in_destination %[[VAL3:.*]] in + // CHECK-SAME: writable %{{.*}} : (tensor<3x2xf32>, memref<3x2xf32, #[[MAP]]>) // CHECK: "lmhlo.terminator"() : () -> () // CHECK: }) : () -> () ROOT fusion = f32[3, 2]{0,1} fusion(f32[3, 2]{1,0} x), kind=kLoop, calls=Fusion