mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
[XLA:CPU] Fix UnrollExtractLoops to stop false-positive unroll when the vector itself is dependent on for loop.
PiperOrigin-RevId: 825427562
This commit is contained in:
parent
f19413ab39
commit
a8356f63df
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
|
@ -39,6 +40,7 @@ limitations under the License.
|
|||
#include "mlir/Interfaces/DataLayoutInterfaces.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/WalkResult.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "xla/backends/cpu/codegen/tiled/transforms/passes.h"
|
||||
|
||||
|
|
@ -50,6 +52,38 @@ namespace xla::cpu {
|
|||
|
||||
namespace {
|
||||
|
||||
// Check if the given extract operands depends on the given value.
|
||||
bool ExtractDependsOnValue(mlir::vector::ExtractOp extract_op,
|
||||
mlir::Value value) {
|
||||
mlir::BackwardSliceOptions backward_slice_options;
|
||||
backward_slice_options.omitUsesFromAbove = false;
|
||||
backward_slice_options.inclusive = true;
|
||||
|
||||
for (mlir::Value dynamic_index : extract_op.getDynamicPosition()) {
|
||||
// We have to explicitly check the index itself as getBackwardSlice only
|
||||
// starts from the defining operation.
|
||||
if (dynamic_index == value) {
|
||||
return true;
|
||||
}
|
||||
|
||||
llvm::SetVector<mlir::Operation*> backwardSlice;
|
||||
if (mlir::failed(mlir::getBackwardSlice(dynamic_index, &backwardSlice,
|
||||
backward_slice_options))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (mlir::Operation* op : backwardSlice) {
|
||||
for (mlir::Value operand : op->getOperands()) {
|
||||
if (operand == value) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
struct FoldExtractIntoTransferRead
|
||||
: mlir::OpRewritePattern<mlir::vector::ExtractOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
|
@ -197,23 +231,27 @@ struct UnrollExtractLoops : mlir::OpRewritePattern<mlir::scf::ForOp> {
|
|||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
mlir::LogicalResult matchAndRewrite(
|
||||
mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const override {
|
||||
if (op.getRegion().getOps<mlir::vector::ExtractOp>().empty()) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"loop does not contain an extract");
|
||||
}
|
||||
|
||||
llvm::SetVector<mlir::Operation*> slices;
|
||||
mlir::getForwardSlice(op.getInductionVar(), &slices);
|
||||
|
||||
for (auto slice : slices) {
|
||||
if (mlir::isa<mlir::vector::ExtractOp>(slice)) {
|
||||
return mlir::loopUnrollFull(op);
|
||||
mlir::scf::ForOp for_op, mlir::PatternRewriter& rewriter) const override {
|
||||
std::optional<mlir::LogicalResult> unroll_result;
|
||||
// Walk the body of the loop and unroll if we have a dependent extract.
|
||||
for_op.getBody()->walk([&](mlir::vector::ExtractOp extract) {
|
||||
if (!ExtractDependsOnValue(extract, for_op.getInductionVar())) {
|
||||
return mlir::WalkResult::advance();
|
||||
}
|
||||
unroll_result = mlir::loopUnrollFull(for_op);
|
||||
return mlir::WalkResult::interrupt();
|
||||
});
|
||||
|
||||
if (!unroll_result.has_value()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
for_op, "loop does not contain a dependent extract");
|
||||
}
|
||||
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "loop does not contain a dependent extract");
|
||||
if (mlir::failed(*unroll_result)) {
|
||||
return rewriter.notifyMatchFailure(for_op, "failed to unroll loop");
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -102,3 +102,28 @@ func.func @does_not_unroll_independent_vector_extract(%input: vector<8x2xf32>, %
|
|||
// CHECK-LABEL: func.func @does_not_unroll_independent_vector_extract(
|
||||
// CHECK: scf.for
|
||||
// CHECK-COUNT-1: vector.extract
|
||||
|
||||
// -----
|
||||
|
||||
// Ensure that we don't unroll loops that have a vector that depends on the for
|
||||
// loop but the index itself does not.
|
||||
func.func @does_not_unroll_only_dependent_vector(
|
||||
%input: memref<8x2xf32>, %arg_index: index) -> vector<2xf32> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
%c8 = arith.constant 8 : index
|
||||
%c0_f32 = arith.constant 0. : f32
|
||||
%init = vector.broadcast %c0_f32 : f32 to vector<2xf32>
|
||||
%result = scf.for %index = %c0 to %c8 step %c2 iter_args(%carry = %init) -> vector<2xf32> {
|
||||
%input_vector = vector.transfer_read %input[%index, %c0], %c0_f32 : memref<8x2xf32>, vector<2x2xf32>
|
||||
%extract = vector.extract %input_vector[0] : vector<2xf32> from vector<2x2xf32>
|
||||
%add = arith.addf %carry, %extract : vector<2xf32>
|
||||
scf.yield %add : vector<2xf32>
|
||||
}
|
||||
return %result : vector<2xf32>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func.func @does_not_unroll_only_dependent_vector(
|
||||
// CHECK: scf.for
|
||||
// CHECK-COUNT-1: vector.extract
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user