mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[XLA:CPU] Fix UnrollExtractLoops to stop false-positive unroll when the vector itself is dependent on for loop.
PiperOrigin-RevId: 825427562
This commit is contained in:
parent
f19413ab39
commit
a8356f63df
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <optional>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
|
@ -39,6 +40,7 @@ limitations under the License.
|
||||||
#include "mlir/Interfaces/DataLayoutInterfaces.h"
|
#include "mlir/Interfaces/DataLayoutInterfaces.h"
|
||||||
#include "mlir/Pass/Pass.h"
|
#include "mlir/Pass/Pass.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "mlir/Support/WalkResult.h"
|
||||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
#include "xla/backends/cpu/codegen/tiled/transforms/passes.h"
|
#include "xla/backends/cpu/codegen/tiled/transforms/passes.h"
|
||||||
|
|
||||||
|
|
@ -50,6 +52,38 @@ namespace xla::cpu {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
// Check if the given extract operands depends on the given value.
|
||||||
|
bool ExtractDependsOnValue(mlir::vector::ExtractOp extract_op,
|
||||||
|
mlir::Value value) {
|
||||||
|
mlir::BackwardSliceOptions backward_slice_options;
|
||||||
|
backward_slice_options.omitUsesFromAbove = false;
|
||||||
|
backward_slice_options.inclusive = true;
|
||||||
|
|
||||||
|
for (mlir::Value dynamic_index : extract_op.getDynamicPosition()) {
|
||||||
|
// We have to explicitly check the index itself as getBackwardSlice only
|
||||||
|
// starts from the defining operation.
|
||||||
|
if (dynamic_index == value) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
llvm::SetVector<mlir::Operation*> backwardSlice;
|
||||||
|
if (mlir::failed(mlir::getBackwardSlice(dynamic_index, &backwardSlice,
|
||||||
|
backward_slice_options))) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (mlir::Operation* op : backwardSlice) {
|
||||||
|
for (mlir::Value operand : op->getOperands()) {
|
||||||
|
if (operand == value) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
struct FoldExtractIntoTransferRead
|
struct FoldExtractIntoTransferRead
|
||||||
: mlir::OpRewritePattern<mlir::vector::ExtractOp> {
|
: mlir::OpRewritePattern<mlir::vector::ExtractOp> {
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
@ -197,23 +231,27 @@ struct UnrollExtractLoops : mlir::OpRewritePattern<mlir::scf::ForOp> {
|
||||||
using OpRewritePattern::OpRewritePattern;
|
using OpRewritePattern::OpRewritePattern;
|
||||||
|
|
||||||
mlir::LogicalResult matchAndRewrite(
|
mlir::LogicalResult matchAndRewrite(
|
||||||
mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const override {
|
mlir::scf::ForOp for_op, mlir::PatternRewriter& rewriter) const override {
|
||||||
if (op.getRegion().getOps<mlir::vector::ExtractOp>().empty()) {
|
std::optional<mlir::LogicalResult> unroll_result;
|
||||||
return rewriter.notifyMatchFailure(op,
|
// Walk the body of the loop and unroll if we have a dependent extract.
|
||||||
"loop does not contain an extract");
|
for_op.getBody()->walk([&](mlir::vector::ExtractOp extract) {
|
||||||
}
|
if (!ExtractDependsOnValue(extract, for_op.getInductionVar())) {
|
||||||
|
return mlir::WalkResult::advance();
|
||||||
llvm::SetVector<mlir::Operation*> slices;
|
|
||||||
mlir::getForwardSlice(op.getInductionVar(), &slices);
|
|
||||||
|
|
||||||
for (auto slice : slices) {
|
|
||||||
if (mlir::isa<mlir::vector::ExtractOp>(slice)) {
|
|
||||||
return mlir::loopUnrollFull(op);
|
|
||||||
}
|
}
|
||||||
|
unroll_result = mlir::loopUnrollFull(for_op);
|
||||||
|
return mlir::WalkResult::interrupt();
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!unroll_result.has_value()) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
for_op, "loop does not contain a dependent extract");
|
||||||
}
|
}
|
||||||
|
|
||||||
return rewriter.notifyMatchFailure(
|
if (mlir::failed(*unroll_result)) {
|
||||||
op, "loop does not contain a dependent extract");
|
return rewriter.notifyMatchFailure(for_op, "failed to unroll loop");
|
||||||
|
}
|
||||||
|
|
||||||
|
return mlir::success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -102,3 +102,28 @@ func.func @does_not_unroll_independent_vector_extract(%input: vector<8x2xf32>, %
|
||||||
// CHECK-LABEL: func.func @does_not_unroll_independent_vector_extract(
|
// CHECK-LABEL: func.func @does_not_unroll_independent_vector_extract(
|
||||||
// CHECK: scf.for
|
// CHECK: scf.for
|
||||||
// CHECK-COUNT-1: vector.extract
|
// CHECK-COUNT-1: vector.extract
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Ensure that we don't unroll loops that have a vector that depends on the for
|
||||||
|
// loop but the index itself does not.
|
||||||
|
func.func @does_not_unroll_only_dependent_vector(
|
||||||
|
%input: memref<8x2xf32>, %arg_index: index) -> vector<2xf32> {
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
%c2 = arith.constant 2 : index
|
||||||
|
%c8 = arith.constant 8 : index
|
||||||
|
%c0_f32 = arith.constant 0. : f32
|
||||||
|
%init = vector.broadcast %c0_f32 : f32 to vector<2xf32>
|
||||||
|
%result = scf.for %index = %c0 to %c8 step %c2 iter_args(%carry = %init) -> vector<2xf32> {
|
||||||
|
%input_vector = vector.transfer_read %input[%index, %c0], %c0_f32 : memref<8x2xf32>, vector<2x2xf32>
|
||||||
|
%extract = vector.extract %input_vector[0] : vector<2xf32> from vector<2x2xf32>
|
||||||
|
%add = arith.addf %carry, %extract : vector<2xf32>
|
||||||
|
scf.yield %add : vector<2xf32>
|
||||||
|
}
|
||||||
|
return %result : vector<2xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @does_not_unroll_only_dependent_vector(
|
||||||
|
// CHECK: scf.for
|
||||||
|
// CHECK-COUNT-1: vector.extract
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user