Mirror of https://github.com/zebrajr/tensorflow.git, synced 2025-12-06 00:19:58 +01:00
Merge branch 'master' into aarch64_build_patch
This commit is contained in: commit 10f7d23b9c

Changed files shown below include .bazelrc (8 changed lines).
@@ -103,9 +103,6 @@ build --define framework_shared_object=true
build --java_toolchain=@tf_toolchains//toolchains/java:tf_java_toolchain
build --host_java_toolchain=@tf_toolchains//toolchains/java:tf_java_toolchain

# Do not enable the mlir generated GPU kernels by default.
build --//tensorflow/core/kernels/mlir_generated:enable_gpu=false

build --define=use_fast_cpp_protos=true
build --define=allow_oversize_protos=true

@@ -231,9 +228,6 @@ build:cuda --repo_env TF_NEED_CUDA=1
build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain
build:cuda --@local_config_cuda//:enable_cuda

# This option overrides the line above for cuda builds.
build:cuda --//tensorflow/core/kernels/mlir_generated:enable_gpu=true

# This config refers to building CUDA op kernels with clang.
build:cuda_clang --config=cuda
build:cuda_clang --repo_env TF_CUDA_CLANG=1

@@ -254,6 +248,8 @@ build:tensorrt --repo_env TF_NEED_TENSORRT=1
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm --repo_env TF_NEED_ROCM=1
# Generated kernels are not yet supported on ROCm.
build:rocm --//tensorflow/core/kernels/mlir_generated:enable_gpu=false

# Options extracted from configure script
build:numa --define=with_numa_support=true
@@ -137,6 +137,8 @@ class ApproximateTanhLowering
loc, CmpFPredicate::ULT, input,
rewriter.create<ConstantOp>(
loc, rewriter.getF32FloatAttr(-7.90531110763549805f)));
Value input_is_nan =
rewriter.create<CmpFOp>(loc, CmpFPredicate::UNE, input, input);
approx = rewriter.create<SelectOp>(
loc, too_large_input,
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(1.0)),

@@ -145,6 +147,7 @@ class ApproximateTanhLowering
loc, too_small_input,
rewriter.create<ConstantOp>(loc, rewriter.getF32FloatAttr(-1.0)),
approx);
approx = rewriter.create<SelectOp>(loc, input_is_nan, input, approx);

return approx;
}
@@ -54,9 +54,11 @@ func @tanh_f32(%arg0 : f32) -> f32 {
// CHECK-DAG: %[[TMP23:.*]] = select %[[TMP22]], %[[ARG]], %[[TMP20]] : f32
// CHECK-DAG: %[[TMP24:.*]] = cmpf ugt, %[[ARG]], %[[C11]] : f32
// CHECK-DAG: %[[TMP25:.*]] = cmpf ult, %[[ARG]], %[[C12]] : f32
// CHECK-DAG: %[[IS_NAN:.*]] = cmpf une, %[[ARG]], %[[ARG]] : f32
// CHECK-DAG: %[[TMP26:.*]] = select %[[TMP24]], %[[C13]], %[[TMP23]] : f32
// CHECK-DAG: %[[TMP27:.*]] = select %[[TMP25]], %[[C14]], %[[TMP26]] : f32
// CHECK: return %[[TMP27]] : f32
// CHECK-DAG: %[[RESULT:.*]] = select %[[IS_NAN]], %[[ARG]], %[[TMP27]] : f32
// CHECK: return %[[RESULT]] : f32
%res = math.tanh %arg0 : f32
return %res : f32
}
@@ -21,7 +21,7 @@ EOF
# TODO(b/178456916): Leverage existing op compat definitions/specs in the
# MLIR conversion pipeline in a better way.
# TODO(b/180352158): Validate generated TF op names.
grep 'patterns.insert<Legalize' $1 | awk -F'<Legalize|>' '{printf " \"%s\",\n", $2}'
grep 'patterns.add<Legalize' $1 | awk -F'<Legalize|>' '{printf " \"%s\",\n", $2}'

cat <<EOF
# Rules at tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc
@@ -4571,7 +4571,7 @@ subsequent operation and then be optimized away, however.)
);
}

def TFL_RFFT2dOp : TFL_Op<"RFFT2D", [NoSideEffect, NoQuantizableResult]> {
def TFL_RFFT2dOp : TFL_Op<"rfft2d", [NoSideEffect, NoQuantizableResult]> {
let summary = "2D real-valued fast Fourier transform.";

let description = [{
@@ -16,6 +16,7 @@ limitations under the License.

#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"

@@ -36,6 +37,14 @@ void ConvertTFLQuantOpsToMlirQuantOps(FuncOp func) {
q.getLoc(), q.output().getType(), q.input());
q.output().replaceAllUsesWith(qcast);
q.erase();
} else if (auto q = llvm::dyn_cast<ConstOp>(op)) {
auto value = q.value();
auto type = q.getResult().getType();
if (ConstantOp::isBuildableWith(value, type)) {
auto c = b.create<ConstantOp>(q.getLoc(), q.value());
q.output().replaceAllUsesWith(c);
q.erase();
}
}
});
}
@@ -518,6 +518,12 @@ ElementsAttr QuantizeLegacy(Attribute real_value, Type tensor_type) {
return APInt(8, value, /*isSigned=*/true);
});
return DenseElementsAttr::get(new_dense_type, quantized_attr);
} else if (width == 8) {
// This can be a state tensor, or an actual constant tensor with
// asymmetric range. For a state tensor, assigining correct quantization
// parameters is sufficient, and for constants with asymmetric range it's
// not correctly quantized by legacy quantizer so call the new Quantize.
return Quantize(real_value, tensor_type);
} else if (width == 16) {
if (auto uniform_type = q_type.dyn_cast<UniformQuantizedType>()) {
auto quantized_values =
@@ -1989,14 +1989,14 @@ func @rfft2d(%arg0: tensor<10x20x10x30xf32>, %arg1: tensor<2xi32>) -> tensor<10x
%0 = "tf.RFFT2D"(%arg0, %arg1) : (tensor<10x20x10x30xf32>, tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f32>>
return %0 : tensor<10x20x10x30xcomplex<f32>>
// CHECK-LABEL: rfft2d
// CHECK: "tfl.RFFT2D"(%arg0, %arg1) : (tensor<10x20x10x30xf32>, tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f32>>
// CHECK: "tfl.rfft2d"(%arg0, %arg1) : (tensor<10x20x10x30xf32>, tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f32>>
}

func @rfft2d_invalid(%arg0: tensor<10x20x10x30xf64>, %arg1: tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f64>> {
%0 = "tf.RFFT2D"(%arg0, %arg1) : (tensor<10x20x10x30xf64>, tensor<2xi32>) -> tensor<10x20x10x30xcomplex<f64>>
return %0 : tensor<10x20x10x30xcomplex<f64>>
// CHECK-LABEL: rfft2d_invalid
// CHECK-NOT: "tfl.RFFT2D"
// CHECK-NOT: "tfl.rfft2d"
}
@@ -75,6 +75,26 @@ func @QuantizeConv2DPerChannel(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32,
// CHECK-NEXT: return %[[conv]]
}

// CHECK-LABEL: QuantizeConv2DPerChannelConst
func @QuantizeConv2DPerChannelConst(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32, 1.5>>,
%arg1: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32:3, {1.0,2.0,3.0}>>) -> tensor<1x112x112x32xf32> {
%bias = "tfl.pseudo_const"() {value = dense<1.000000e+00> : tensor<32xf32>} : () -> tensor<32xf32>
%input = "tfl.dequantize"(%arg0) : (tensor<1x224x224x3x!quant.uniform<u8:f32, 1.5>>) -> tensor<1x224x224x3xf32>
%weight = "tfl.dequantize"(%arg1) : (tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32:3, {1.0,2.0,3.0}>>) -> tensor<32x3x3x3xf32>
%conv = "tfl.conv_2d"(%input, %weight, %bias) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32,
fused_activation_function = "NONE", padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32}
: (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32>
return %conv : tensor<1x112x112x32xf32>

// CHECK-NEXT: %[[cst:.*]] = constant dense<1.000000e+00> : tensor<32xf32>
// CHECK-NEXT: %[[qbias:.*]] = "tfl.quantize"(%[[cst]]) {qtype = tensor<32x!quant.uniform<i32:f32:0, {1.500000e+00,3.000000e+00,4.500000e+00}>>, volatile}
// CHECK-NEXT: %[[bias:.*]] = "tfl.dequantize"(%[[qbias]])
// CHECK-NEXT: %[[in:.*]] = "tfl.dequantize"(%arg0)
// CHECK-NEXT: %[[w:.*]] = "tfl.dequantize"(%arg1)
// CHECK-NEXT: %[[conv:.*]] = "tfl.conv_2d"(%[[in]], %[[w]], %[[bias]])
// CHECK-NEXT: return %[[conv]]
}

// CHECK-LABEL: QuantizeConv2DPerChannels
func @QuantizeConv2DPerChannels(%arg0: tensor<1x224x224x3x!quant.uniform<u8:f32:3, {1.0,2.0,3.0}>>,
%arg1: tensor<32x3x3x3x!quant.uniform<u8<1:255>:f32:3, {1.0,2.0,3.0}>>) -> tensor<1x112x112x32xf32> {
@@ -1,5 +1,6 @@
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize | FileCheck %s
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-quantize -tfl-numeric-verify -tfl-log-if-failed | FileCheck --check-prefix=DEBUG %s
// RUN: tf-opt %s -tfl-quantize -tfl-legacy-quantize | FileCheck --check-prefix=LEGACY %s

// CHECK-LABEL: QuantizeFloatConst
func @QuantizeFloatConst() -> tensor<2x2x!quant.uniform<u8:f32, 7.8431372549019615E-4:128>> {

@@ -356,3 +357,14 @@ func @NotQuantizeCustomTfOp(%arg0: tensor<128x128x!quant.uniform<u8:f32, 0.1:127
// CHECK-NEXT: "tfl.yield"
// CHECK-NEXT: }) {device = ""} : (tensor<128x128xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xi32>) -> tensor<128x128xf32>
}


// Checks that legacy path correctly handles asymmetric quantized values.
// LEGACY-LABEL: CheckLegacyQuantizeAdd
func @CheckLegacyQuantizeAdd() -> tensor<1x2x!quant.uniform<i8:f32, 0.0078431372549019607:-128>> {
%cst = constant dense<[[1.000000e+00, 2.000000e+00]]> : tensor<1x2xf32>
%0 = "tfl.quantize"(%cst) {qtype = tensor<1x2x!quant.uniform<i8:f32, 0.0078431372549019607:-128>>, volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<i8:f32, 0.0078431372549019607:-128>>
return %0 : tensor<1x2x!quant.uniform<i8:f32, 0.0078431372549019607:-128>>

// LEGACY: "tfl.pseudo_qconst"() {qtype = tensor<1x2x!quant.uniform<i8:f32, 0.0078431372549019607:-128>>, value = dense<{{\[\[}}-1, 127]]> : tensor<1x2xi8>}
}
@@ -374,7 +374,10 @@ void PrepareQuantizePass::runOnFunction() {
patterns_1.insert<PrepareLstmOutputScale<UnidirectionalSequenceLSTMOp>>(
ctx);
}
(void)applyPatternsAndFoldGreedily(func, std::move(patterns_1));
(void)applyPatternsAndFoldGreedily(
func, std::move(patterns_1),
// TODO(fengliuai): Fix the logic to work without this flag
/*useTopDownTraversal=*/false);

// During the legalization, unsigned quantized type is used, so we have to
// convert all of them to signed.

@@ -399,7 +402,8 @@ void PrepareQuantizePass::runOnFunction() {
ctx, quant_specs_);
patterns_2.insert<ConvertSvdfStatsToQDQs>(ctx, quant_specs_);
}
(void)applyPatternsAndFoldGreedily(func, std::move(patterns_2),
(void)applyPatternsAndFoldGreedily(
func, std::move(patterns_2),
// TODO(fengliuai): Fix the logic to work without this flag
/*useTopDownTraversal=*/false);
@@ -66,6 +66,13 @@ static llvm::cl::opt<bool> enable_log_if_failed(
"tolerance. Valid when `-tfl-numeric-verify` is set."),
llvm::cl::init(false));

// NOLINTNEXTLINE
static llvm::cl::opt<bool> enable_legacy_quantize(
"tfl-legacy-quantize", llvm::cl::value_desc("bool"),
llvm::cl::desc("Use legacy quantize mode in test. Valid when"
"`-tfl-legacy-quantize` is set."),
llvm::cl::init(false));

namespace mlir {
namespace TFL {

@@ -87,40 +94,31 @@ struct TFLFullQuantization
static bool AllowHybridResult() { return false; }
};

struct LegacyQuantizeConstPattern : public OpRewritePattern<QuantizeOp> {
// This pattern should be applied before existing quantize pattern in
// `quantize_patterns.td`, so the benefit is set to some value larger than 1.
explicit LegacyQuantizeConstPattern(MLIRContext* context)
: OpRewritePattern<QuantizeOp>(context, /*benefit=*/10) {}
LogicalResult matchAndRewrite(QuantizeOp op,
PatternRewriter& rewriter) const override {
DenseFPElementsAttr attr;
if (matchPattern(op.input(), m_Constant(&attr))) {
auto qtype = op.qtypeAttr();
if (auto quantized_attr = quant::QuantizeLegacy(attr, qtype.getValue())) {
rewriter.replaceOpWithNewOp<QConstOp>(op, qtype, quantized_attr);
return success();
}
}
return failure();
}
};

struct QuantizeConstPattern : public OpRewritePattern<QuantizeOp> {
explicit QuantizeConstPattern(MLIRContext* context)
: OpRewritePattern<QuantizeOp>(context) {}
explicit QuantizeConstPattern(MLIRContext* context, bool legacy_float_scale)
: OpRewritePattern<QuantizeOp>(context),
legacy_float_scale(legacy_float_scale) {}
LogicalResult matchAndRewrite(QuantizeOp op,
PatternRewriter& rewriter) const override {
DenseFPElementsAttr attr;
if (matchPattern(op.input(), m_Constant(&attr))) {
auto qtype = op.qtypeAttr();
if (auto quantized_attr = quant::Quantize(attr, qtype.getValue())) {
Attribute quantized_attr;
if (legacy_float_scale) {
quantized_attr = quant::QuantizeLegacy(attr, qtype.getValue());
} else {
quantized_attr = quant::Quantize(attr, qtype.getValue());
}
if (quantized_attr) {
rewriter.replaceOpWithNewOp<QConstOp>(op, qtype, quantized_attr);
return success();
}
}
return failure();
}

private:
bool legacy_float_scale;
};

// Applies quantization on the model in TFL dialect.

@@ -155,10 +153,8 @@ void QuantizePass::runOnFunction() {
// Constant quantization is a lossy transformation, so they are applied only
// after all the other patterns have been aplied.
OwningRewritePatternList patterns_2(&getContext());
if (legacy_float_scale) {
patterns_2.insert<LegacyQuantizeConstPattern>(ctx);
}
patterns_2.insert<QuantizeConstPattern>(ctx);
patterns_2.insert<QuantizeConstPattern>(
ctx, legacy_float_scale || enable_legacy_quantize);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns_2));
}
}  // namespace
@@ -537,6 +537,11 @@ void TPUExtractOutsideCompilation::runOnOperation() {
return signalPassFailure();
}
});
// No constant should have an "_xla_outside_compilation" attribute left.
// TODO(kfranko): We likely should revisit where is the best place for this
// logic to live (canonicalization pattern?).
module.walk(
[&](TF::ConstOp op) { op->removeAttr("_xla_outside_compilation"); });
}

}  // namespace
@@ -3503,15 +3503,9 @@ class SavedModelSignatureDefImporterLite {
static StatusOr<mlir::OwningModuleRef> Convert(
SavedModelMLIRImportInput& input, absl::Span<std::string> exported_names,
mlir::MLIRContext* context, bool import_restore = true) {
LoadImporterDialects(*context);
SavedModelSignatureDefImporterLite importer(input, exported_names, context,
import_restore);
TF_ASSIGN_OR_RETURN(auto module, importer.ConvertSignatures());

SortSavedModelModule(*module);
MarkSavedModelFunctionVisibility(*module);

return module;
return importer.ConvertSignatures();
}

private:

@@ -3789,6 +3783,8 @@ SavedModelSignatureDefImporterLite::ParseInputArrays(

StatusOr<mlir::OwningModuleRef>
SavedModelSignatureDefImporterLite::ConvertSignatures() {
LoadImporterDialects(*module_->getContext());

const auto& signatures = input_.meta_graph_def().signature_def();
PopulateTfVersions(module_.get(),
input_.meta_graph_def().graph_def().versions());
@@ -2704,9 +2704,10 @@ XlaOp XlaBuilder::AllGather(XlaOp operand, int64 all_gather_dimension,
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));

TF_ASSIGN_OR_RETURN(Shape inferred_shape,
ShapeInference::InferAllGatherShape(
*operand_shape, all_gather_dimension, shard_count));
TF_ASSIGN_OR_RETURN(
Shape inferred_shape,
ShapeInference::InferAllGatherShape({operand_shape},
all_gather_dimension, shard_count));
if (layout) {
*inferred_shape.mutable_layout() = *layout;
instr.set_constrain_layout(true);

@@ -4291,6 +4292,20 @@ XlaOp ReduceWindowWithGeneralPadding(
padding);
}

XlaOp ReduceWindowWithGeneralPadding(
absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
absl::Span<const int64> base_dilations,
absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding) {
CHECK(!operands.empty());
return operands[0].builder()->ReduceWindowWithGeneralPadding(
operands, init_values, computation, window_dimensions, window_strides,
base_dilations, window_dilations, padding);
}

XlaOp AllGather(const XlaOp operand, int64 all_gather_dimension,
int64 shard_count,
absl::Span<const ReplicaGroup> replica_groups,
@@ -1273,6 +1273,15 @@ class XlaBuilder {
absl::Span<const int64> base_dilations,
absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding);
friend XlaOp ReduceWindowWithGeneralPadding(
absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
absl::Span<const int64> base_dilations,
absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding);

friend XlaOp CrossReplicaSum(XlaOp operand,
absl::Span<const ReplicaGroup> replica_groups);
friend XlaOp AllGather(XlaOp operand, int64 all_gather_dimension,

@@ -2161,6 +2170,14 @@ XlaOp ReduceWindowWithGeneralPadding(
absl::Span<const int64> base_dilations,
absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding);
XlaOp ReduceWindowWithGeneralPadding(
absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
absl::Span<const int64> base_dilations,
absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding);

// Returns the sum of the operand value within each subgroup of replicas. All
// replicas supply one input to the sum and all replicas receive the resulting
@@ -52,7 +52,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
platform->GetExecutor(config));
auto device_state = absl::make_unique<LocalDeviceState>(
executor, client, LocalDeviceState::kSynchronous, asynchronous,
/*allow_event_reuse=*/false);
/*allow_event_reuse=*/false, /*use_callback_stream=*/false);
auto device = absl::make_unique<CpuDevice>(i, std::move(device_state));
devices.push_back(std::move(device));
}
@@ -202,7 +202,7 @@ StatusOr<std::vector<std::unique_ptr<LocalDeviceState>>> BuildLocalDeviceStates(
addressable_devices.push_back(absl::make_unique<LocalDeviceState>(
executor, xla_client, LocalDeviceState::kComputeSynchronized,
asynchronous,
/*allow_event_reuse=*/true));
/*allow_event_reuse=*/true, /*use_callback_stream=*/true));
}
return std::move(addressable_devices);
}
@@ -46,7 +46,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() {
client->backend().stream_executor(0).ValueOrDie();
auto device_state = absl::make_unique<LocalDeviceState>(
executor, client, LocalDeviceState::kSynchronous, /*asynchronous=*/false,
/*allow_event_reuse=*/false);
/*allow_event_reuse=*/false, /*use_callback_stream=*/false);
auto device =
absl::make_unique<InterpreterDevice>(0, std::move(device_state));
devices.push_back(std::move(device));
@@ -31,7 +31,8 @@ namespace xla {
LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor,
LocalClient* client,
AllocationModel allocation_model,
bool asynchronous, bool allow_event_reuse)
bool asynchronous, bool allow_event_reuse,
bool use_callback_stream)
: allocation_model_(allocation_model),
event_pool_(allow_event_reuse),
compute_semaphore_(/*capacity=*/asynchronous ? 32 : 1),

@@ -40,28 +41,30 @@ LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor,
prng_seed_generator_(prng_seed_device_()),
prng_seed_distribution_(std::numeric_limits<int>::min(),
std::numeric_limits<int>::max()) {
compute_stream_ = absl::make_unique<se::Stream>(executor);
host_to_device_stream_ = absl::make_unique<se::Stream>(executor);
callback_stream_ = absl::make_unique<se::Stream>(executor);
compute_stream_ = std::make_unique<se::Stream>(executor);
host_to_device_stream_ = std::make_unique<se::Stream>(executor);
compute_stream_->Init();
host_to_device_stream_->Init();
callback_stream_->Init();
if (use_callback_stream) {
callback_stream_ = std::make_unique<se::Stream>(executor);
callback_stream_->Init();
}
device_to_host_streams_.reserve(kNumDeviceToHostStreams);
for (int i = 0; i < kNumDeviceToHostStreams; ++i) {
auto stream = absl::make_unique<se::Stream>(executor);
auto stream = std::make_unique<se::Stream>(executor);
stream->Init();
device_to_host_streams_.push_back(std::move(stream));
}
device_to_device_streams_.reserve(kNumDeviceToDeviceStreams);
for (int i = 0; i < kNumDeviceToDeviceStreams; ++i) {
auto stream = absl::make_unique<se::Stream>(executor);
auto stream = std::make_unique<se::Stream>(executor);
stream->Init();
device_to_device_streams_.push_back(std::move(stream));
}
execute_thread_ = absl::make_unique<WorkerThread>(tensorflow::Env::Default(),
"py_xla_execute");
callback_thread_ = absl::make_unique<WorkerThread>(tensorflow::Env::Default(),
"py_xla_callback");
execute_thread_ = std::make_unique<WorkerThread>(tensorflow::Env::Default(),
"py_xla_execute");
callback_thread_ = std::make_unique<WorkerThread>(tensorflow::Env::Default(),
"py_xla_callback");
}

LocalDeviceState::~LocalDeviceState() {

@@ -79,7 +82,9 @@ Status LocalDeviceState::SynchronizeAllActivity() {
// stopped, also block on the compute stream. If SynchronizeAllActivity is
// fixed, we could remove the BlockHostUntilDone call.
status.Update(compute_stream_->BlockHostUntilDone());
status.Update(callback_stream_->BlockHostUntilDone());
if (callback_stream_) {
status.Update(callback_stream_->BlockHostUntilDone());
}
bool ok = compute_stream_->parent()->SynchronizeAllActivity();
if (!ok) {
status.Update(Unknown("SynchronizeAllActivity failed."));

@@ -97,9 +102,13 @@ Status LocalDeviceState::ThenMemcpyDeviceToDevice(
return Status::OK();
}

void LocalDeviceState::ThenExecuteOnCallbackThread(
void LocalDeviceState::ThenExecuteCallback(
se::Stream* stream, std::function<void()> callback) const {
tensorflow::profiler::TraceMe traceme("ThenExecuteOnCallbackThread");
tensorflow::profiler::TraceMe traceme("ThenExecuteCallback");
if (callback_stream_ && callback_stream_.get() != stream) {
callback_stream_->ThenWaitFor(stream);
stream = callback_stream_.get();
}
stream->ThenDoHostCallback([this, callback{std::move(callback)}]() mutable {
callback_thread_->Schedule(std::move(callback));
});

@@ -124,7 +133,7 @@ se::Stream* LocalDeviceState::GetDeviceToDeviceStream() {
std::unique_ptr<se::Stream> LocalDeviceState::BorrowStreamFromPool() {
absl::MutexLock lock(&mu_);
if (usage_stream_pool_.empty()) {
auto stream = absl::make_unique<se::Stream>(compute_stream_->parent());
auto stream = std::make_unique<se::Stream>(compute_stream_->parent());
stream->Init();
return stream;
} else {
@@ -90,7 +90,7 @@ class LocalDeviceState {
// each execution or transfer. This is intended for debugging only.
LocalDeviceState(se::StreamExecutor* executor, LocalClient* client,
AllocationModel allocation_model, bool asynchronous,
bool allow_event_reuse);
bool allow_event_reuse, bool use_callback_stream);
virtual ~LocalDeviceState();

se::StreamExecutor* executor() const { return executor_; }

@@ -107,7 +107,6 @@ class LocalDeviceState {
se::Stream* host_to_device_stream() const {
return host_to_device_stream_.get();
}
se::Stream* callback_stream() const { return callback_stream_.get(); }

// Returns a device to host stream. Allocates streams in a round-robin fashion
// amongst the available streams.

@@ -132,12 +131,17 @@ class LocalDeviceState {

WorkerThread* execute_thread() const { return execute_thread_.get(); }

// Enqueues a host callback on 'stream', to be executed by callback_thread_.
// ThenDoHostCallback is often constrained in what it can do, in particular,
// on GPU the callback runs on a thread belonging to the GPU runtime and
// cannot perform GPU operations itself.
void ThenExecuteOnCallbackThread(se::Stream* stream,
std::function<void()> callback) const;
// Enqueues a host callback on 'stream'. `stream` may, but need not, wait for
// `callback` to complete. It is safe to call runtime methods from the
// callback.
// This API differs from ThenDoHostCallback in two ways:
// a) ThenDoHostCallback is often constrained in what it can do, in
// particular, on GPU the callback runs on a thread belonging to the GPU
// runtime and cannot perform GPU operations itself. On GPU, callbacks
// execute in a separate thread.
// b) ThenDoHostCallback waits for the callback to complete.
void ThenExecuteCallback(se::Stream* stream,
std::function<void()> callback) const;

// Helpers for releasing values on a worker thread at the tail of a stream on
// a worker thread. Copies `object`, and destroys the copy when the tail of

@@ -147,12 +151,8 @@ class LocalDeviceState {
// (e.g., GPU objects).
template <typename T>
void ThenRelease(se::Stream* stream, T&& object) const {
if (callback_stream_.get() != stream) {
callback_stream_->ThenWaitFor(stream);
}
ThenExecuteOnCallbackThread(
callback_stream_.get(),
[object = std::forward<T>(object)]() { /* releases object */ });
ThenExecuteCallback(
stream, [object = std::forward<T>(object)]() { /* releases object */ });
}

Semaphore& compute_semaphore() { return compute_semaphore_; }
@@ -833,10 +833,8 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
local_device, std::move(device_buffer), event,
local_device->host_to_device_stream()));

local_device->callback_stream()->ThenWaitFor(
local_device->host_to_device_stream());
local_device->ThenExecuteOnCallbackThread(
local_device->callback_stream(),
local_device->ThenExecuteCallback(
local_device->host_to_device_stream(),
[staging_buffer{std::move(staging_buffer)},
on_done_with_host_buffer{std::move(on_done_with_host_buffer)}]() {
if (on_done_with_host_buffer) {

@@ -1128,7 +1126,7 @@ PjRtStreamExecutorBuffer::Release(bool wait_for_operations_to_complete) {
}
if (block_stream != nullptr) {
se::Stream* block_stream_ptr = block_stream.release();
local_device_state->ThenExecuteOnCallbackThread(
local_device_state->ThenExecuteCallback(
block_stream_ptr,
[device_buffer, block_stream_ptr, local_device_state]() {
local_device_state->ReturnStreamToPool(

@@ -1997,11 +1995,9 @@ PjRtStreamExecutorExecutable::ExecuteHelper(
}

if (!compute_callbacks.empty()) {
device_state->callback_stream()->ThenWaitFor(stream);
device_state->ThenExecuteOnCallbackThread(
device_state->callback_stream(),
[callbacks{std::move(compute_callbacks)},
buffers_to_release{std::move(buffers_to_release)}]() {
device_state->ThenExecuteCallback(
stream, [callbacks{std::move(compute_callbacks)},
buffers_to_release{std::move(buffers_to_release)}]() {
for (auto& fn : callbacks) {
fn();
}
@@ -61,7 +61,8 @@ TpuDeviceState::TpuDeviceState(se::StreamExecutor* executor,
LocalClient* client, bool asynchronous)
: LocalDeviceState(executor, client, LocalDeviceState::kAsynchronous,
asynchronous,
/*allow_event_reuse=*/false) {}
/*allow_event_reuse=*/false,
/*use_callback_stream=*/true) {}

Status TpuDeviceState::ThenMemcpyDeviceToDevice(
se::Stream* transfer_stream, se::Stream* dst_stream,
@@ -262,6 +262,7 @@ cc_library(
visibility = ["//visibility:private"],
deps = [
":py_client",
":python_ref_manager",
":pytree",
":types",
"//tensorflow/compiler/xla:shape_util",
@@ -46,6 +46,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_executable.h"
#include "tensorflow/compiler/xla/python/py_values.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/pytree.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/shape_util.h"

@@ -131,13 +132,23 @@ struct GlobalJitState {
GlobalJitState& global_state = *new GlobalJitState();

struct ThreadLocalJitState {
~ThreadLocalJitState() {
if (extra_jit_context) {
// We likely do not hold the GIL, so we hand the Python object to the
// global reference manager to destroy.
py::object o = std::move(*extra_jit_context);
xla::GlobalPyRefManager()->AddGarbage(absl::MakeSpan(&o, 1));
extra_jit_context = absl::nullopt;
}
}
absl::optional<bool> disable_jit;
absl::optional<bool> enable_x64;
absl::optional<py::object> extra_jit_context;
};

thread_local ThreadLocalJitState& thread_local_state =
*new ThreadLocalJitState();
// TODO(phawkins): Google style guide forbids thread-local values with
// non-trivial destructors.
ABSL_CONST_INIT thread_local ThreadLocalJitState thread_local_state;  // NOLINT

bool JitIsDisabled() {
return thread_local_state.disable_jit.value_or(global_state.disable_jit);
@@ -238,11 +238,17 @@ void BuildOpsSubmodule(py::module* m) {
py::arg("computation"), py::arg("dimensions_to_reduce"));
ops.def("ReducePrecision", &ReducePrecision, py::arg("operand"),
py::arg("exponent_bits"), py::arg("mantissa_bits"));
ops.def("ReduceWindowWithGeneralPadding", &ReduceWindowWithGeneralPadding,
py::arg("operand"), py::arg("init_value"), py::arg("computation"),
py::arg("window_dimensions"), py::arg("window_strides"),
py::arg("base_dilations"), py::arg("window_dilations"),
py::arg("padding"));
ops.def(
"ReduceWindowWithGeneralPadding",
static_cast<XlaOp (*)(XlaOp, XlaOp, const XlaComputation&,
absl::Span<const int64>, absl::Span<const int64>,
absl::Span<const int64>, absl::Span<const int64>,
absl::Span<const std::pair<int64, int64>>)>(
&ReduceWindowWithGeneralPadding),
py::arg("operand"), py::arg("init_value"), py::arg("computation"),
py::arg("window_dimensions"), py::arg("window_strides"),
py::arg("base_dilations"), py::arg("window_dilations"),
py::arg("padding"));
ops.def("RemoveDynamicDimension", &RemoveDynamicDimension, py::arg("operand"),
py::arg("dimension"));
ops.def("ReplicaId", &ReplicaId, py::arg("builder"));
@@ -91,7 +91,7 @@ StatusOr<std::unique_ptr<PjRtClient>> GetCpuClientWithNonLocalDevice() {
platform->GetExecutor(config));
auto device_state = absl::make_unique<LocalDeviceState>(
executor, client, LocalDeviceState::kSynchronous, /*asynchronous=*/true,
/*allow_event_reuse=*/false);
/*allow_event_reuse=*/false, /*use_callback_stream=*/false);

std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
devices.push_back(absl::make_unique<CpuDevice>(0, std::move(device_state)));
@@ -237,6 +237,9 @@ cc_library(
"//tensorflow:linux_ppc64le": [
"@llvm-project//llvm:PowerPCCodeGen",  # fixdeps: keep
],
"//tensorflow:macos_arm64": [
"@llvm-project//llvm:AArch64CodeGen",  # fixdeps: keep
],
"//conditions:default": [
],
}) + if_llvm_system_z_available([
@@ -208,24 +208,6 @@ cc_library(
],
)

cc_library(
name = "thunk_emitter",
srcs = ["thunk_emitter.cc"],
hdrs = ["thunk_emitter.h"],
deps = [
":backend_configs_cc",
":buffer_allocations",
":gpu_constants",
":gpu_executable",
":ir_emission_utils",
":thunk",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:hlo_casting_utils",
],
)

cc_library(
name = "gpu_device_info",
hdrs = ["gpu_device_info.h"],

@@ -260,7 +242,6 @@ cc_library(
":parallel_loop_emitter",
":target_util",
":thunk",
":thunk_emitter",
"//tensorflow/compiler/mlir:name_utils",
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/hlo:lhlo",
@@ -832,12 +832,6 @@ Status IrEmitterUnnested::EmitConditionalFromMlir(MlirEmitterInput mlir_input) {
return Status::OK();
}

Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) {
AddThunkToThunkSequence(
BuildKernelThunk(convolution, /*implements_whole_instruction=*/true));
return IrEmitter::HandleConvolution(convolution);
}

// Input = {dynamic array(with dynamic dimension meta data at the end)}
// Output = {static array, dynamic_dim0, dynamic_dim1}
Status IrEmitterUnnested::EmitPadToStaticFromMlir(MlirEmitterInput mlir_input) {

@@ -3552,34 +3546,6 @@ IrEmitterUnnested::BuildKernelThunkFromBufferSlices(
std::string(kernel->getName()));
}

std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
const HloInstruction* inst, bool implements_whole_instruction) {
std::vector<HloBufferSlice> hlo_slices =
GetHloBufferSlices(inst, ir_emitter_context_->buffer_assignment());

std::vector<BufferSlice*> slice_ptrs;
slice_ptrs.reserve(hlo_slices.size());
for (auto& slice : hlo_slices) {
slice_ptrs.push_back(&slice);
}

return BuildKernelThunkFromBufferSlices(
inst->name(),
implements_whole_instruction ? GetThunkInfo(inst) : Thunk::ThunkInfo(),
slice_ptrs, [this](const BufferSlice* slice, llvm::Value* value) {
const HloBufferSlice* hlo_buffer_slice =
static_cast<const HloBufferSlice*>(slice);
const HloInstruction* instr = hlo_buffer_slice->instr;
const ShapeIndex& index = hlo_buffer_slice->hlo_index;
VLOG(3) << "Buffer for " << instr->ToString() << " at "
<< index.ToString() << " is found in slice "
<< hlo_buffer_slice->buffer_slice.ToString() << " at GTE index "
<< hlo_buffer_slice->gte_index.ToString();

bindings_.BindHloToIrValue(*instr, value, index);
});
}

std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunkForMlirImpl(
absl::string_view name, Thunk::ThunkInfo thunk_info,
absl::Span<const MlirBufferSlice> slices,

@@ -5813,7 +5779,10 @@ Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(

Thunk::ThunkInfo IrEmitterUnnested::GetThunkInfo(
const HloInstruction* hlo) const {
auto info = ThunkEmitter::EmissionContext::GetThunkInfo(hlo);
CHECK(hlo);
Thunk::ThunkInfo info;
info.profile_annotation = absl::StrFormat(
"Thunk:#hlo_op=%s,hlo_module=%s#", hlo->name(), hlo->GetModule()->name());
if (const auto* index_map = ir_emitter_context_->profile_index_map()) {
info.profile_index.emplace(
static_cast<int64>(index_map->GetProfileIndexFor(*hlo)));
@@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/gpu/thunk_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"

@@ -93,8 +92,7 @@ struct MlirEmitterContext {
// within a kernel function using FusedIrEmitter. (FusedIrEmitter is not
// really an IrEmitter, but is more an "IR generator generator".)
//
class IrEmitterUnnested : public IrEmitter,
private ThunkEmitter::EmissionContext {
class IrEmitterUnnested : public IrEmitter {
public:
struct ThreadIdInfo {
// Raw thread id.

@@ -110,7 +108,7 @@ class IrEmitterUnnested : public IrEmitter,
llvm::Value* lane_id;
};

absl::string_view platform_name() const override {
absl::string_view platform_name() const {
return ir_emitter_context_->platform_name();
}

@@ -160,7 +158,6 @@ class IrEmitterUnnested : public IrEmitter,

Status HandleConditional(HloInstruction* conditional) override;
Status EmitConditionalFromMlir(MlirEmitterInput mlir_input);
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleCustomCall(HloInstruction* custom_call) override;
Status EmitCustomCallFromMlir(MlirEmitterInput input);
Status EmitConvolutionThunkFromMlir(MlirEmitterInput input);

@@ -235,7 +232,7 @@ class IrEmitterUnnested : public IrEmitter,
IrEmitterContext* ir_emitter_context);

// Add a owning Thunk object to the thunk sequence.
void AddThunkToThunkSequence(std::unique_ptr<Thunk> thunk) override {
void AddThunkToThunkSequence(std::unique_ptr<Thunk> thunk) {
thunk_sequence_.emplace_back(std::move(thunk));
}

@@ -333,7 +330,7 @@ class IrEmitterUnnested : public IrEmitter,

// A convenient helper for calling BufferAssignment::GetUniqueSlice.
StatusOr<BufferAllocation::Slice> MaybeGetAllocationSlice(
const HloInstruction& hlo, const ShapeIndex& index) const override {
const HloInstruction& hlo, const ShapeIndex& index) const {
return ir_emitter_context_->buffer_assignment().GetUniqueSlice(&hlo, index);
}

@@ -344,7 +341,7 @@ class IrEmitterUnnested : public IrEmitter,

StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(mlir::Value v);

int64 ByteSizeOf(const Shape& shape) const override {
int64 ByteSizeOf(const Shape& shape) const {
return llvm_ir::ByteSizeOf(
shape, ir_emitter_context_->llvm_module()->getDataLayout());
}

@@ -637,14 +634,6 @@ class IrEmitterUnnested : public IrEmitter,
std::function<void(const BufferSlice*, llvm::Value*)>
bind_slice_to_ir_value);

// Returns a KernelThunk that invokes the kernel emitted for `inst`. The
// caller needs to make sure `inst` outlives the lifetime of the returned
// Thunk object. 'implements_whole_instruction' specifies whether this
// KernelThunk implements the whole 'inst' HloInstruction. In some cases
// 'inst' will be implemented by a sequence of Thunks.
std::unique_ptr<KernelThunk> BuildKernelThunk(
const HloInstruction* inst, bool implements_whole_instruction);

std::unique_ptr<KernelThunk> BuildKernelThunkForMlirImpl(
absl::string_view name, Thunk::ThunkInfo thunk_info,
absl::Span<const MlirBufferSlice> slices,

@@ -726,7 +715,7 @@ class IrEmitterUnnested : public IrEmitter,
// Returns the last generated thunk.
Thunk* LastThunk() const { return thunk_sequence_.back().get(); }

Thunk::ThunkInfo GetThunkInfo(const HloInstruction* hlo) const override;
Thunk::ThunkInfo GetThunkInfo(const HloInstruction* hlo) const;

Status AssertNonDeterminismIsOkay(const string& op_name);
@@ -34,12 +34,6 @@
// CHECK: store float %[[VAL_32]], float* %[[VAL_34]], align 4
// CHECK: br label %[[VAL_30]]
// CHECK: }
// CHECK: ; Function Attrs: nounwind readnone
// CHECK: declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #0
// CHECK: ; Function Attrs: nounwind readnone
// CHECK: declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() #0
// CHECK: ; Function Attrs: nofree nosync nounwind willreturn
// CHECK: declare void @llvm.assume(i1 noundef) #1

// CHECK: define void @fusion__1(i8* noalias align 16 dereferenceable(8192) %[[VAL_0:.*]], i8* noalias align 64 dereferenceable(256) %[[VAL_1:.*]], i8* noalias align 64 dereferenceable(256) %[[VAL_2:.*]], i8* noalias align 16 dereferenceable(4) %[[VAL_3:.*]], i8* noalias align 16 dereferenceable(4) %[[VAL_4:.*]]) {
// CHECK: entry:

@@ -787,8 +781,6 @@
// CHECK: %[[VAL_559:.*]] = extractvalue { i32, i1 } %[[VAL_557]], 1
// CHECK: br i1 %[[VAL_559]], label %[[VAL_549]], label %[[VAL_554]]
// CHECK: }
// CHECK: ; Function Attrs: nounwind readnone
// CHECK: declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.y() #0

// CHECK: define internal void @region_1_4(float* dereferenceable(4) %[[VAL_0:.*]], float* dereferenceable(4) %[[VAL_1:.*]], float* dereferenceable(4) %[[VAL_2:.*]]) {
// CHECK: entry:
@@ -55,12 +55,6 @@ compare {
// CHECK: store float %[[VAL_32]], float* %[[VAL_34]], align 4
// CHECK: br label %[[VAL_25]]
// CHECK: }
// CHECK: ; Function Attrs: nounwind readnone
// CHECK: declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #0
// CHECK: ; Function Attrs: nounwind readnone
// CHECK: declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() #0
// CHECK: ; Function Attrs: nofree nosync nounwind willreturn
// CHECK: declare void @llvm.assume(i1 noundef) #1

// CHECK: define internal void @region_0_4(float* dereferenceable(4) %[[VAL_0:.*]], float* dereferenceable(4) %[[VAL_1:.*]], i8* dereferenceable(1) %[[VAL_2:.*]]) {
// CHECK: entry:

@@ -244,12 +238,6 @@ compare {
// CHECK: store float %[[VAL_44]], float* %[[VAL_46]], align 4
// CHECK: br label %[[VAL_31]]
// CHECK: }
// CHECK: ; Function Attrs: nounwind readnone
// CHECK: declare i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #0
// CHECK: ; Function Attrs: nounwind readnone
// CHECK: declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() #0
// CHECK: ; Function Attrs: nofree nosync nounwind willreturn
// CHECK: declare void @llvm.assume(i1 noundef) #1

// CHECK: define internal void @region_0_6(i32* dereferenceable(4) %[[VAL_0:.*]], i32* dereferenceable(4) %[[VAL_1:.*]], float* dereferenceable(4) %[[VAL_2:.*]], float* dereferenceable(4) %[[VAL_3:.*]], i8* dereferenceable(1) %[[VAL_4:.*]]) {
// CHECK: entry:
@@ -1,76 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/thunk_emitter.h"

#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"

namespace xla {
namespace gpu {

std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
const HloInstruction* inst) {
GpuGemmConfig config = GetGpuGemmConfig(inst);
const HloInstruction* lhs = inst->operand(0);
const HloInstruction* rhs = inst->operand(1);

// The bias is passed inside the output buffer. If those buffers are shared
// we can just use it, otherwise copy the bias values into the output buffer
// first.
if (config.backend_config.beta() != 0.0) {
const HloInstruction* bias = inst->operand(2);
CHECK_EQ(bias->shape(), inst->shape());
if (GetAllocationSlice(*bias) != GetAllocationSlice(*inst)) {
std::vector<std::unique_ptr<Thunk>> thunks;
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
Thunk::ThunkInfo(),
/*source_buffer=*/GetAllocationSlice(*bias),
/*destination_buffer=*/GetAllocationSlice(*inst),
/*mem_size=*/ShapeUtil::ByteSizeOf(inst->shape())));
thunks.push_back(absl::make_unique<GemmThunk>(
context_->GetThunkInfo(inst), std::move(config),
GetAllocationSlice(*lhs), // The buffer assigned to LHS.
GetAllocationSlice(*rhs), // The buffer assigned to RHS.
GetAllocationSlice(*inst), // The output buffer.
/*implements_whole_instruction=*/false));
return absl::make_unique<SequentialThunk>(context_->GetThunkInfo(inst),
std::move(thunks));
}
}

return absl::make_unique<GemmThunk>(
context_->GetThunkInfo(inst), std::move(config),
GetAllocationSlice(*lhs), // The buffer assigned to LHS.
GetAllocationSlice(*rhs), // The buffer assigned to RHS.
GetAllocationSlice(*inst), // The output buffer.
/*implements_whole_instruction=*/true);
}

Thunk::ThunkInfo ThunkEmitter::EmissionContext::GetThunkInfo(
const HloInstruction* hlo) const {
CHECK(hlo);
Thunk::ThunkInfo info;
info.profile_annotation = absl::StrFormat(
"Thunk:#hlo_op=%s,hlo_module=%s#", hlo->name(), hlo->GetModule()->name());
return info;
}
}  // namespace gpu
}  // namespace xla
@@ -1,80 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_EMITTER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_EMITTER_H_

#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {
namespace gpu {

// Implements handling of GPU execution for HLO operations that are handed off
// to specialized thunks that do not require code generation. Intended to be
// mixed into GPU emitters.
class ThunkEmitter {
public:
class EmissionContext {
public:
virtual void AddThunkToThunkSequence(std::unique_ptr<Thunk> thunk) = 0;
virtual StatusOr<BufferAllocation::Slice> MaybeGetAllocationSlice(
const HloInstruction& hlo, const ShapeIndex& index) const = 0;
virtual int64 ByteSizeOf(const Shape& shape) const = 0;
virtual absl::string_view platform_name() const = 0;
virtual Thunk::ThunkInfo GetThunkInfo(const HloInstruction* hlo) const;

virtual ~EmissionContext() = default;
};

explicit ThunkEmitter(EmissionContext* context) : context_(context) {}

Status HandleTriangularSolve(HloInstruction* hlo);

private:
EmissionContext* context_;

void AddThunkToThunkSequence(std::unique_ptr<Thunk> thunk) {
return context_->AddThunkToThunkSequence(std::move(thunk));
}

StatusOr<BufferAllocation::Slice> MaybeGetAllocationSlice(
const HloInstruction& hlo, const ShapeIndex& index) const {
return context_->MaybeGetAllocationSlice(hlo, index);
}

int64 ByteSizeOf(const Shape& shape) { return context_->ByteSizeOf(shape); }

absl::string_view platform_name() const { return context_->platform_name(); }

BufferAllocation::Slice GetAllocationSlice(
const HloInstruction& hlo, const ShapeIndex& index = {}) const {
return MaybeGetAllocationSlice(hlo, index).ValueOrDie();
}

// Returns a CholeskyThunk that calls cuSolver to implement `inst`.
std::unique_ptr<Thunk> BuildCholeskyThunk(const HloInstruction* inst);

// Returns a GemmThunk that calls gemm to implement `inst`. The caller needs
// to make sure `inst` outlives the lifetime of the returned Thunk object.
std::unique_ptr<Thunk> BuildGemmThunk(const HloInstruction* inst);
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_EMITTER_H_
@@ -417,11 +417,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(

TF_RET_CHECK(proto.dimensions_size() == 1)
<< "AllGather cannot have more than 1 all-gather dimensions";
TF_RET_CHECK(all_operands().size() == 1)
<< "AllGather must have a single operand";
int64 all_gather_dimension = proto.dimensions(0);
instruction = CreateAllGather(
shape, operands(0), all_gather_dimension,
shape, all_operands(), all_gather_dimension,
std::vector<ReplicaGroup>(proto.replica_groups().begin(),
proto.replica_groups().end()),
proto.constrain_layout(), channel_id, proto.use_global_device_ids());

@@ -1041,11 +1039,12 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllGather(
const Shape& shape, HloInstruction* operand, int64 all_gather_dimension,
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
const absl::optional<int64>& channel_id, bool use_global_device_ids) {
const Shape& shape, absl::Span<HloInstruction* const> operands,
int64 all_gather_dimension, const std::vector<ReplicaGroup>& replica_groups,
bool constrain_layout, const absl::optional<int64>& channel_id,
bool use_global_device_ids) {
return absl::make_unique<HloAllGatherInstruction>(
shape, operand, all_gather_dimension, replica_groups, constrain_layout,
shape, operands, all_gather_dimension, replica_groups, constrain_layout,
channel_id, use_global_device_ids);
}
@@ -674,7 +674,8 @@ class HloInstruction {
// except that the order of the group members determines the concatenation
// order of inputs from different participants.
static std::unique_ptr<HloInstruction> CreateAllGather(
const Shape& shape, HloInstruction* operand, int64 all_gather_dimension,
const Shape& shape, absl::Span<HloInstruction* const> operands,
int64 all_gather_dimension,
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
const absl::optional<int64>& channel_id, bool use_global_device_ids);
@@ -617,10 +617,11 @@ bool HloCollectiveInstruction::IdenticalSlowPathIgnoringChannelIdValues(
}

HloAllGatherInstruction::HloAllGatherInstruction(
const Shape& shape, HloInstruction* operand, int64 all_gather_dimension,
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
const absl::optional<int64>& channel_id, bool use_global_device_ids)
: HloCollectiveInstruction(HloOpcode::kAllGather, shape, {operand},
const Shape& shape, absl::Span<HloInstruction* const> operands,
int64 all_gather_dimension, const std::vector<ReplicaGroup>& replica_groups,
bool constrain_layout, const absl::optional<int64>& channel_id,
bool use_global_device_ids)
: HloCollectiveInstruction(HloOpcode::kAllGather, shape, operands,
replica_groups, constrain_layout, channel_id),
all_gather_dimension_(all_gather_dimension),
use_global_device_ids_(use_global_device_ids) {}

@@ -641,7 +642,7 @@ HloAllGatherInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* /*context*/) const {
return absl::make_unique<HloAllGatherInstruction>(
shape, new_operands[0], all_gather_dimension(), replica_groups(),
shape, new_operands, all_gather_dimension(), replica_groups(),
constrain_layout(), channel_id(), use_global_device_ids());
}
@@ -388,7 +388,8 @@ class HloCollectiveInstruction : public HloChannelInstruction {
class HloAllGatherInstruction : public HloCollectiveInstruction {
public:
explicit HloAllGatherInstruction(
const Shape& shape, HloInstruction* operand, int64 all_gather_dimension,
const Shape& shape, absl::Span<HloInstruction* const> operands,
int64 all_gather_dimension,
const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
const absl::optional<int64>& channel_id, bool use_global_device_ids);
// Same as HloAllReduceInstruction::use_global_device_ids.
@ -48,7 +48,7 @@ namespace xla {
|
|||
V(kAdd, "add", 2) \
|
||||
V(kAddDependency, "add-dependency", 2) \
|
||||
V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \
|
||||
V(kAllGather, "all-gather", 1) \
|
||||
V(kAllGather, "all-gather", kHloOpcodeIsVariadic) \
|
||||
V(kAllReduce, "all-reduce", kHloOpcodeIsVariadic) \
|
||||
V(kAllToAll, "all-to-all", kHloOpcodeIsVariadic) \
|
||||
V(kAtan2, "atan2", 2) \
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ TEST(HloOpcodeTest, OpcodeProperties) {
|
|||
}
|
||||
switch (opcode) {
|
||||
case HloOpcode::kAfterAll:
|
||||
case HloOpcode::kAllGather:
|
||||
case HloOpcode::kAllReduce:
|
||||
case HloOpcode::kAllToAll:
|
||||
case HloOpcode::kCall:
|
||||
|
|
|
|||
|
|
@ -1217,7 +1217,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
|||
replica_groups = CreateReplicaGroups(*tmp_groups);
|
||||
}
|
||||
instruction = builder->AddInstruction(HloInstruction::CreateAllGather(
|
||||
shape, operands[0], dimensions->at(0), replica_groups,
|
||||
shape, operands, dimensions->at(0), replica_groups,
|
||||
constrain_layout ? *constrain_layout : false, channel_id,
|
||||
use_global_device_ids ? *use_global_device_ids : false));
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -276,11 +276,20 @@ Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
|
|||
ag->use_global_device_ids()));
|
||||
TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, group_mode));
|
||||
TF_RET_CHECK(ag->all_gather_dimension() >= 0);
|
||||
TF_RET_CHECK(ag->all_gather_dimension() < ag->shape().rank());
|
||||
TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(0)->shape().rank());
|
||||
|
||||
for (int64_t i = 0; i < ag->operand_count(); ++i) {
|
||||
TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(i)->shape().rank());
|
||||
|
||||
const Shape& output_shape =
|
||||
(ag->operand_count() == 1) ? ag->shape() : ag->shape().tuple_shapes(i);
|
||||
TF_RET_CHECK(ag->all_gather_dimension() < output_shape.rank());
|
||||
}
|
||||
|
||||
const Shape& output0_shape =
|
||||
(ag->operand_count() == 1) ? ag->shape() : ag->shape().tuple_shapes(0);
|
||||
|
||||
int64 shard_count = CeilOfRatio(
|
||||
ag->shape().dimensions(ag->all_gather_dimension()),
|
||||
output0_shape.dimensions(ag->all_gather_dimension()),
|
||||
ag->operand(0)->shape().dimensions(ag->all_gather_dimension()));
|
||||
const HloModuleConfig& config = hlo->GetModule()->config();
|
||||
// empty replica groups imply all replicas form a single group.
|
||||
|
|
@ -312,9 +321,13 @@ Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
|
|||
<< "shard_count = " << shard_count
|
||||
<< ", subgroup_size = " << subgroup_size << ", " << hlo->ToString();
|
||||
|
||||
return CheckShape(ag, ShapeInference::InferAllGatherShape(
|
||||
ag->operand(0)->shape(), ag->all_gather_dimension(),
|
||||
shard_count));
|
||||
std::vector<const Shape*> operand_shapes;
|
||||
for (const HloInstruction* operand : hlo->operands()) {
|
||||
operand_shapes.push_back(&operand->shape());
|
||||
}
|
||||
return CheckShape(
|
||||
ag, ShapeInference::InferAllGatherShape(
|
||||
operand_shapes, ag->all_gather_dimension(), shard_count));
|
||||
}
|
||||
|
||||
Status ShapeVerifier::HandleAllReduce(HloInstruction* hlo) {
|
||||
|
|
|
|||
|
|
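To make the shard-count arithmetic used by the verifier concrete, here is a small self-contained check with invented sizes (this is an illustration, not code from the change):

    #include <cassert>
    #include <cstdint>

    // CeilOfRatio as used above, reimplemented for this standalone example.
    int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

    int main() {
      const int64_t output_dim_size = 512;   // output0_shape.dimensions(ag_dim)
      const int64_t operand_dim_size = 128;  // operand(0)->shape().dimensions(ag_dim)
      const int64_t shard_count = CeilOfRatio(output_dim_size, operand_dim_size);
      assert(shard_count == 4);  // must then match the replica subgroup size
      return 0;
    }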
@ -2032,14 +2032,27 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
}

/* static */ StatusOr<Shape> ShapeInference::InferAllGatherShape(
const Shape& operand_shape, int64 all_gather_dimension, int64 shard_count) {
absl::Span<const Shape* const> operand_shapes, int64 all_gather_dimension,
int64 shard_count) {
TF_RET_CHECK(all_gather_dimension >= 0);
TF_RET_CHECK(all_gather_dimension < operand_shape.rank());
TF_RET_CHECK(shard_count > 0);
auto shape = operand_shape;
shape.set_dimensions(all_gather_dimension,
shard_count * shape.dimensions(all_gather_dimension));
return shape;
std::vector<Shape> output_shapes;
output_shapes.reserve(operand_shapes.size());
for (const Shape* operand_shape : operand_shapes) {
TF_RET_CHECK(all_gather_dimension < operand_shape->rank());
TF_RETURN_IF_ERROR(ExpectArray(*operand_shape, "operand of all-gather"));

Shape output_shape = *operand_shape;
output_shape.set_dimensions(
all_gather_dimension,
shard_count * output_shape.dimensions(all_gather_dimension));
output_shapes.push_back(output_shape);
}
if (output_shapes.size() == 1) {
return output_shapes[0];
}
return ShapeUtil::MakeTupleShape(output_shapes);
}

/* static */ StatusOr<Shape> ShapeInference::InferAllReduceShape(

@ -127,9 +127,9 @@ class ShapeInference {

// Infers the shape produced by an all-gather with the given operand shape,
// concat dimension, and shard count.
static StatusOr<Shape> InferAllGatherShape(const Shape& operand_shape,
int64 all_gather_dimension,
int64 shard_count);
static StatusOr<Shape> InferAllGatherShape(
absl::Span<const Shape* const> operand_shapes, int64 all_gather_dimension,
int64 shard_count);

// Infers the shape produced by a cross replica sum with the given operand
// shapes.
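A short sketch of what the new multi-operand overload is expected to return, under assumed operand shapes (the `f32_8x128`/`f32_8x256` names are hypothetical placeholders, not identifiers from the commit):

    // Assumed inputs: f32[8,128] and f32[8,256], gathered along dimension 0
    // over 4 shards.
    std::vector<const Shape*> operand_shapes = {&f32_8x128, &f32_8x256};
    StatusOr<Shape> inferred = ShapeInference::InferAllGatherShape(
        operand_shapes, /*all_gather_dimension=*/0, /*shard_count=*/4);
    // Expected result: the tuple (f32[32,128], f32[32,256]).
    // With a single operand the overload returns the array shape directly,
    // matching the previous single-operand behaviour.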
@ -2099,8 +2099,6 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
// Create the new convolution dim numbers.
auto new_dim_numbers = permuted_conv_dims_numbers;

VLOG(1) << "spatial size " << c.spatial_size;

const int64 num_splits = kNumSplits;
const int64 output_offsets = convolution->shape().dimensions(
permuted_conv_dims_numbers.output_spatial_dimensions(

@ -2111,6 +2109,8 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
int64 spatial_split_size =
CeilOfRatio(output_offsets_per_split, c.base_dilation_factor) * c.stride;

VLOG(1) << "spatial size " << c.spatial_size << " halo size " << c.halo_size
<< " spatial_split_size " << spatial_split_size;
// Keep increasing the split size so that overall size isn't smaller than the
// original spatial dimension. Unlike for the first space-to-batch'ed
// convolution, while propagating, we can use the last halo_size as available

@ -2118,7 +2118,7 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
while (spatial_split_size * num_splits + c.halo_size - c.spatial_size < 0) {
spatial_split_size += c.stride;
}

VLOG(1) << "Modified spatial_split_size " << spatial_split_size;
const int64 new_space_size =
activations_new->shape().dimensions(c.spatial_dimension_to_split);

@ -2132,23 +2132,14 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
activations_new, activations_batch_dim, old_batch_size,
c.spatial_dimension_to_split, spatial_split_size));

TF_ASSIGN_OR_RETURN(
activations_new,
HaloDuplicateWithSlice(activations_new, c.spatial_dimension_to_split,
activations_batch_dim, old_batch_size,
/*low_padding=*/c.base_dilation_factor != 1 &&
c.inherent_low_padding != 0
? c.base_dilation_factor - 1
: c.inherent_low_padding,
c.inherent_high_padding,
slice_size - spatial_split_size,
old_split_dim_size));
} else {
// If the ideal spatial_split_size was smaller than the incoming spatial
// dimension size, we don't need reshaping. Instead, we determine the
// additional space available, and adjust the required slice size (and
// thereby the halo size).
VLOG(3) << "Decreasing the spatial size while propagating";
VLOG(3)
<< "Decreasing the spatial size while propagating spatial_split_size "
<< spatial_split_size << " new_space_size " << new_space_size;
if (spatial_split_size < new_space_size) {
// If there's a stride mismatch, we change the new_space_size be
// smaller (equal to spatial_split_size).

@ -2167,20 +2158,21 @@ Status ConvolutionVisitor::PropagateOnConv(HloInstruction* convolution) {
static_cast<int64>(0));
}
}

TF_ASSIGN_OR_RETURN(
activations_new,
HaloDuplicateWithSlice(activations_new, c.spatial_dimension_to_split,
activations_batch_dim, old_batch_size,
/*low_padding=*/c.base_dilation_factor != 1 &&
c.inherent_low_padding != 0
? c.base_dilation_factor - 1
: c.inherent_low_padding,
c.inherent_high_padding,
slice_size - spatial_split_size,
old_split_dim_size));
}

// For space-to-batch supported base-dilated convolutions, the low padding
// is passed on to the new convolutions. Halo does not have to account for it.
TF_ASSIGN_OR_RETURN(activations_new,
HaloDuplicateWithSlice(
activations_new, c.spatial_dimension_to_split,
activations_batch_dim, old_batch_size,
/*low_padding=*/c.base_dilation_factor != 1 &&
c.inherent_low_padding != 0
? 0
: c.inherent_low_padding,
c.inherent_high_padding,
slice_size - spatial_split_size, old_split_dim_size));

// We will generate output such that batch is followed by the split spatial
// dimension.
const int64 rank = (convolution->shape().rank());
@ -51,7 +51,8 @@ StatusOr<bool> CanonicalizeAllGatherForCSE::RunOnComputation(
HloAllGatherInstruction* ag = DynCast<HloAllGatherInstruction>(hlo);
// Only supporting AllGather on dimension 0 as it's the only case currently
// happening and additional cases needs more complexity.
if (!ag || ag->all_gather_dimension() != 0) {
// TODO(cjfj): Support all-gathers with more than one operand.
if (!ag || ag->all_gather_dimension() != 0 || ag->operand_count() > 1) {
continue;
}
HloInstruction* real_data = ag->mutable_operand(0);

@ -91,7 +92,7 @@ StatusOr<bool> CanonicalizeAllGatherForCSE::RunOnComputation(
comp->AddInstruction(HloInstruction::CreateAllGather(
ShapeUtil::MakeShape(real_data->shape().element_type(),
new_dimensions),
ag_input, /*all_gather_dimension=*/0, ag->replica_groups(),
{ag_input}, /*all_gather_dimension=*/0, ag->replica_groups(),
ag->constrain_layout(), new_channel_id,
ag->use_global_device_ids()));
HloInstruction* new_formatting = comp->AddInstruction(

@ -218,6 +218,7 @@ HloInstruction* SpmdBuilder::AddInstruction(
HloInstruction* hlo =
HloComputation::Builder::AddInstruction(std::move(instruction));
if (visiting_hlo_) {
hlo->set_metadata(visiting_hlo_->metadata());
instructions_[visiting_hlo_].push_back(hlo);
}
if (hlo->opcode() == HloOpcode::kBroadcast) {

@ -1397,7 +1398,6 @@ Status SpmdPartitioningVisitor::DefaultAction(HloInstruction* hlo) {
auto clone =
b_.AddInstruction(hlo->CloneWithNewOperands(hlo->shape(), new_operands));
clone->set_sharding(sharding);
clone->set_metadata(hlo->metadata());
SetPartitionedHlo(hlo,
PartitionedHlo(clone, hlo->shape(), MakePartitioningState())
.Reshard(hlo->sharding()));

@ -2139,7 +2139,6 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) {
input_shard_shape.dimensions(input_sharded_dim) * merge_factor);
auto tmp_reshape = b_.AddInstruction(
HloInstruction::CreateReshape(tmp_shard_shape, operand.hlo()));
tmp_reshape->set_metadata(hlo->metadata());
tmp_reshape->set_sharding(hlo->sharding());
auto tmp_full_shape = tmp_shard_shape;
tmp_full_shape.set_dimensions(

@ -2793,7 +2792,6 @@ Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) {
}
auto local_reduce = b_.AddInstruction(HloInstruction::CreateReduce(
reduce_shape, input_hlos, inits, hlo->dimensions(), hlo->to_apply()));
local_reduce->set_metadata(hlo->metadata());

SetPartitionedHlo(hlo, [&]() {
HloInstruction* reduce = local_reduce;

@ -3527,7 +3525,7 @@ SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64 num_partitions,
}
}
return b->AddInstruction(HloInstruction::CreateAllGather(
ag_shape, operand, all_gather_dimension, device_groups,
ag_shape, {operand}, all_gather_dimension, device_groups,
/*constrain_layout=*/false, channel_id,
/*use_global_device_ids=*/true));
},

@ -471,7 +471,6 @@ class SpmdPartitioningVisitor : public DfsHloVisitorWithDefault {
const std::function<HloInstruction*()>& func) {
HloInstruction* new_hlo = func();
new_hlo->set_sharding(hlo->sharding());
new_hlo->set_metadata(hlo->metadata());
SetPartitionedHlo(
hlo, PartitionedHlo(new_hlo, hlo->shape(), MakePartitioningState()));
changed_ = true;

@ -665,6 +665,7 @@ cc_library(
hdrs = [
"meta_optimizer.h",
],
copts = tf_copts(),
visibility = ["//visibility:public"],
deps = [
":arithmetic_optimizer",

@ -354,10 +354,13 @@ Status MetaOptimizer::InitializeOptimizers(
optimizers->push_back(
MakeUnique<AutoParallel>(cfg_.auto_parallel().num_replicas()));
}

#ifndef ENABLE_MKL
if (BOTH_ARE_ON(scoped_allocator_optimization)) {
optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>(
cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts()));
}
#endif

#undef USER_IS_ON
#undef USER_NOT_OFF

@ -680,11 +683,13 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, GrapplerItem&& item,
GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
// Some optimizers can run only once.
if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue;
#ifndef ENABLE_MKL
// Some must run only on the last iteration.
if (optimizer->name() == "scoped_allocator_optimizer") {
if (sa_optimizer == nullptr) sa_optimizer = optimizer.get();
continue;
}
#endif

TF_RETURN_IF_ERROR(RunOptimizer(optimizer.get(), cluster, &item,
optimized_graph, &optimization_result));

@ -716,13 +721,14 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, GrapplerItem&& item,
TF_RETURN_IF_ERROR(verifier->Verify(*optimized_graph));
}
}

#ifndef ENABLE_MKL
// ScopedAllocatorOptimizer must run last.
if (sa_optimizer != nullptr) {
TF_RETURN_IF_ERROR(RunOptimizer(sa_optimizer, cluster, &item,
optimized_graph, &optimization_result));
GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
}
#endif

bool is_optimized = std::find_if(optimization_result.results.begin(),
optimization_result.results.end(),

@ -1147,7 +1153,9 @@ bool MetaOptimizerEnabled(const ConfigProto& cfg) {
rewrite_cfg.auto_parallel().enable() ||
rewrite_cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
rewrite_cfg.debug_stripper() == RewriterConfig::ON ||
#ifndef ENABLE_MKL
rewrite_cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
#endif
rewrite_cfg.pin_to_host_optimization() == RewriterConfig::ON ||
AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision()) ||
AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision_mkl()) ||

@ -356,6 +356,7 @@ class ScopedAllocatorOptimizerTest : public ::testing::Test {
return num_control_inputs;
}
};
#ifndef ENABLE_MKL

TEST_F(ScopedAllocatorOptimizerTest, UnaryRewriteOnly) {
// Tests that Rewrite of program with parallel unary Ops is done as

@ -595,6 +596,7 @@ TEST_F(ScopedAllocatorOptimizerTest, ConstInput) {
}
EXPECT_EQ(num_identity_ops, 2);
}
#endif // ENABLE_MKL

} // namespace
} // namespace grappler

@ -816,6 +816,7 @@ cc_library(
"//tensorflow:ios": [],
"//tensorflow:linux_ppc64le": [],
"//tensorflow:linux_s390x": [],
"//tensorflow:macos_arm64": [],
"//conditions:default": [
"TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL",
"TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL",

@ -831,6 +832,7 @@ cc_library(
"//tensorflow:ios": [],
"//tensorflow:linux_ppc64le": [],
"//tensorflow:linux_s390x": [],
"//tensorflow:macos_arm64": [],
"//conditions:default": ["@mkl_dnn_v1//:mkl_dnn"],
}),
)
@ -157,8 +157,19 @@ class MklMatMulOp : public OpKernel {
VLOG(2) << "MKL DNN SGEMM called";
#ifndef ENABLE_ONEDNN_OPENMP
MklDnnThreadPool eigen_tp(ctx);
dnnl::threadpool_interop::sgemm(char_transa, char_transb, m, n, k, alpha, a,
lda, b, ldb, beta, c, ldc, &eigen_tp);
// With a threadpool, the runtime overhead is comparable to the kernel
// execution for small kernel sizes. For such sizes, it may be better to run
// the kernel single threaded. Here we are coming up with a cost model based
// on L1 sizes. If we find that matrices are small enough, we will execute
// single threaded. This may need tuning.
if (ExecuteSingleThreadedGemm(m, n, k)) {
// For now, call single-threaded gemm.
dnnl::threadpool_interop::sgemm(char_transa, char_transb, m, n, k, alpha,
a, lda, b, ldb, beta, c, ldc, nullptr);
} else {
dnnl::threadpool_interop::sgemm(char_transa, char_transb, m, n, k, alpha,
a, lda, b, ldb, beta, c, ldc, &eigen_tp);
}
#else
dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta,
c, ldc);

@ -188,7 +199,7 @@ class MklMatMulOp : public OpKernel {
.Label(mkl_op_registry::kMklNameChangeOpLabel), \
MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);

// TODO(inteltf) Consider template specialization when adding/removing
// TODO(intel-tf): Consider template specialization when adding/removing
// additional types
TF_CALL_float(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);
@ -33,8 +33,19 @@ using mkldnn::stream;

namespace tensorflow {

#define L1_SIZE 32 * 1024
typedef Eigen::ThreadPoolDevice CPUDevice;

inline bool ExecuteSingleThreadedGemm(int m, int n, int k) {
// Ideally we would like to determine blocking and then come up with
// a heuristic but what we are targeting are very small models whose
// total size is < few L1's. So we will do this simple calculation
// to determine if the matrix multiplication should be run on a single thread.
constexpr int kHeuristicMultiplier = 8;
return ((sizeof(float) * (m * n + k * (m + n))) <
L1_SIZE * kHeuristicMultiplier);
}

// This structure aggregates multiple inputs to MklDnnMatMul* methods.
struct MklDnnMatMulFwdParams {
memory::dims src_dims;
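A worked example of the single-thread cut-off above (the sizes are invented for illustration; the sketch uses a constexpr instead of the unparenthesized L1_SIZE macro):

    // Standalone restatement of the ExecuteSingleThreadedGemm() heuristic.
    constexpr size_t kL1Size = 32 * 1024;
    constexpr size_t kHeuristicMultiplier = 8;
    bool SingleThreaded(size_t m, size_t n, size_t k) {
      return sizeof(float) * (m * n + k * (m + n)) < kL1Size * kHeuristicMultiplier;
    }
    // SingleThreaded(64, 64, 64):    4 * (4096 + 8192)    = 49152  < 262144 -> true
    // SingleThreaded(256, 256, 256): 4 * (65536 + 131072) = 786432 >= 262144 -> false

So a 64x64x64 matmul would be run single threaded, while a 256x256x256 one would still go through the threadpool path.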
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <limits>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
#include "tensorflow/core/kernels/mlir_generated/base_ops_test.h"
|
||||
|
|
@ -734,6 +736,13 @@ TEST_F(UnaryOpsTest, TanhSmallAndLarge) {
|
|||
test::OpsTestConfig().ExpectStrictlyEqual().SuppressTolerance());
|
||||
}
|
||||
|
||||
TEST_F(UnaryOpsTest, TanhNaN) {
|
||||
Test<float, float, float, float>(
|
||||
"Tanh", test::DefaultInputShape(),
|
||||
test::InputAsVector<float>({std::numeric_limits<float>::quiet_NaN()}),
|
||||
std::tanh, test::OpsTestConfig().ExpectStrictlyEqual());
|
||||
}
|
||||
|
||||
/// Test `tf.Square`.
|
||||
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
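The new TanhNaN test uses std::tanh as the reference, so it leans on the standard library propagating NaN rather than clamping it; a minimal standalone check of that behaviour (illustration only):

    #include <cmath>
    #include <cstdio>
    #include <limits>

    int main() {
      const float nan = std::numeric_limits<float>::quiet_NaN();
      // std::tanh(NaN) is NaN, so the generated kernel must also return NaN
      // instead of taking the "too large"/"too small" +/-1 branches.
      std::printf("isnan(tanh(NaN)) = %d\n", std::isnan(std::tanh(nan)));
      return 0;
    }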
@ -30,7 +30,9 @@ limitations under the License.

#if GOOGLE_CUDA
#include "tensorflow/core/util/cuda_solvers.h" // For ScratchSpace
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/rocm.h"
#include "tensorflow/core/util/rocm_solvers.h"
#endif

@ -326,6 +328,9 @@ class UniqueOpGPU : public AsyncOpKernel {
const GPUDevice& device = context->eigen_gpu_device();
int64 uniq_size = (*last_idx_host.data()) + 1;

se::cuda::ScopedActivateExecutorContext scoped_activation{
context->op_device_context()->stream()->parent()};

Tensor unique_input_inds;
TIndex* unique_input_inds_ptr = nullptr;
AllocateTemp(context, uniq_size, &unique_input_inds,

@ -108,7 +108,7 @@ limitations under the License.

#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
#define TF_GRAPH_DEF_VERSION 714 // Updated: 2021/3/23
#define TF_GRAPH_DEF_VERSION 715 // Updated: 2021/3/24

// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
//

@ -469,7 +469,6 @@ class MklDnnShape {
} else {
auto format_tag =
MklTensorFormatToMklDnnDataFormat(data_.tf_data_format_);
DCHECK_NE(format_tag, memory::format_tag::undef);
return memory::desc(dims, data_.T_, format_tag);
}
}
57
tensorflow/lite/core/shims/cc_library_with_tflite.bzl
Normal file

@ -0,0 +1,57 @@
"""Definitions for cc_library/cc_test targets that use the TFLite shims."""

def cc_library_with_tflite(
name,
deps = [],
tflite_deps = [],
**kwargs):
"""Defines a cc_library that uses the TFLite shims.

This is a hook to allow applying different build flags (etc.)
for targets that use the TFLite shims.

Note that this build rule doesn't itself add any dependencies on
TF Lite; this macro should normally be used in conjunction with a
direct or indirect 'tflite_deps' dependency on one of the "shim"
library targets from //tensorflow/lite/core/shims:*.

Args:
name: as for cc_library.
deps: as for cc_library.
tflite_deps: dependencies on rules that are themselves defined using
'cc_library_with_tflite'.
**kwargs: Additional cc_library parameters.
"""
native.cc_library(
name = name,
deps = deps + tflite_deps,
**kwargs
)

def cc_test_with_tflite(
name,
deps = [],
tflite_deps = [],
**kwargs):
"""Defines a cc_test that uses the TFLite shims.

This is a hook to allow applying different build flags (etc.)
for targets that use the TFLite shims.

Note that this build rule doesn't itself add any dependencies on
TF Lite; this macro should normally be used in conjunction with a
direct or indirect 'tflite_deps' dependency on one of the "shim"
library targets from //third_party/tensorflow/lite/core/shims:*.

Args:
name: as for cc_test.
deps: as for cc_test.
tflite_deps: dependencies on rules that are themselves defined using
'cc_library_with_tflite'.
**kwargs: Additional cc_test parameters.
"""
native.cc_test(
name = name,
deps = deps + tflite_deps,
**kwargs
)
@ -434,6 +434,7 @@ cc_library(
hdrs = ["linear_storage.h"],
deps = [
":cl_context",
":cl_image_format",
":gpu_object",
":opencl_wrapper",
":util",

@ -557,6 +558,7 @@ cc_library(
deps = [
":cl_command_queue",
":cl_context",
":cl_image_format",
":gpu_object",
":opencl_wrapper",
":util",

@ -46,7 +46,7 @@ std::vector<cl_image_format> GetSupportedImage2DFormats(cl_context context,
bool IsEqualToImageFormat(cl_image_format image_format, DataType data_type,
int num_channels) {
return image_format.image_channel_data_type ==
ToImageChannelType(data_type) &&
DataTypeToChannelType(data_type) &&
image_format.image_channel_order == ToChannelOrder(num_channels);
}

@ -131,7 +131,7 @@ bool CLContext::IsFloatTexture2DSupported(int num_channels, DataType data_type,
cl_mem_flags flags) const {
auto supported_formats = GetSupportedImage2DFormats(context_, flags);
for (auto format : supported_formats) {
if (format.image_channel_data_type == ToImageChannelType(data_type) &&
if (format.image_channel_data_type == DataTypeToChannelType(data_type) &&
format.image_channel_order == ToChannelOrder(num_channels)) {
return true;
}

@ -34,14 +34,26 @@ cl_channel_order ToChannelOrder(int num_channels) {
}
}

cl_channel_type ToImageChannelType(DataType data_type) {
switch (data_type) {
cl_channel_type DataTypeToChannelType(DataType type, bool normalized) {
switch (type) {
case DataType::FLOAT32:
return CL_FLOAT;
case DataType::FLOAT16:
return CL_HALF_FLOAT;
case DataType::INT8:
return normalized ? CL_SNORM_INT8 : CL_SIGNED_INT8;
case DataType::UINT8:
return normalized ? CL_UNORM_INT8 : CL_UNSIGNED_INT8;
case DataType::INT16:
return normalized ? CL_SNORM_INT16 : CL_SIGNED_INT16;
case DataType::UINT16:
return normalized ? CL_UNORM_INT16 : CL_UNSIGNED_INT16;
case DataType::INT32:
return CL_SIGNED_INT32;
case DataType::UINT32:
return CL_UNSIGNED_INT32;
default:
return -1;
return CL_FLOAT;
}
}

@ -25,7 +25,7 @@ namespace cl {

cl_channel_order ToChannelOrder(int num_channels);

cl_channel_type ToImageChannelType(DataType data_type);
cl_channel_type DataTypeToChannelType(DataType type, bool normalized = false);

} // namespace cl
} // namespace gpu
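A brief usage sketch of the consolidated DataTypeToChannelType() helper declared above; the surrounding setup is hypothetical and only meant to show the normalized flag:

    // FLOAT16 maps to CL_HALF_FLOAT; INT8 with normalized=true maps to CL_SNORM_INT8.
    cl_image_format format;
    format.image_channel_order = CL_RGBA;
    format.image_channel_data_type =
        tflite::gpu::cl::DataTypeToChannelType(tflite::gpu::DataType::FLOAT16);
    cl_channel_type snorm8 = tflite::gpu::cl::DataTypeToChannelType(
        tflite::gpu::DataType::INT8, /*normalized=*/true);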
@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/linear_storage.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/lite/delegates/gpu/cl/cl_image_format.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"

@ -68,7 +68,8 @@ absl::Status AllocateTensorMemory(const CLContext& context, const BHWDC& shape,

cl_image_format format;
format.image_channel_order = CL_RGBA;
format.image_channel_data_type = ToImageChannelType(descriptor.data_type);
format.image_channel_data_type =
DataTypeToChannelType(descriptor.data_type);

cl_int error_code;
cl_mem memory =

@ -97,7 +98,8 @@ absl::Status AllocateTensorMemory(const CLContext& context, const BHWDC& shape,

cl_image_format format;
format.image_channel_order = CL_RGBA;
format.image_channel_data_type = ToImageChannelType(descriptor.data_type);
format.image_channel_data_type =
DataTypeToChannelType(descriptor.data_type);

cl_int error_code;
cl_mem memory =

@ -127,7 +129,8 @@ absl::Status AllocateTensorMemory(const CLContext& context, const BHWDC& shape,

cl_image_format format;
format.image_channel_order = CL_RGBA;
format.image_channel_data_type = ToImageChannelType(descriptor.data_type);
format.image_channel_data_type =
DataTypeToChannelType(descriptor.data_type);

cl_int error_code;
cl_mem memory =

@ -164,7 +167,7 @@ absl::Status AllocateTensorMemory(const CLContext& context, const BHWDC& shape,
if (context.IsFloatTexture2DSupported(shape.c, descriptor.data_type)) {
format.image_channel_order = ToChannelOrder(shape.c);
format.image_channel_data_type =
ToImageChannelType(descriptor.data_type);
DataTypeToChannelType(descriptor.data_type);
} else {
return absl::InvalidArgumentError(absl::StrCat(
"This device doesn't support ", shape.c, "-channel textures."));

@ -199,7 +202,7 @@ absl::Status CreateImageBufferFromBuffer(const CLContext& context,
desc.image_width = width;
desc.mem_object = memory;

format.image_channel_data_type = ToImageChannelType(data_type);
format.image_channel_data_type = DataTypeToChannelType(data_type);
format.image_channel_order = CL_RGBA;

cl_int error_code;

@ -485,35 +488,33 @@ cl_mem Tensor::GetMemoryPtr() const {
cl_mem Tensor::GetMemoryPtrForWriting() const { return memory_; }

absl::Status Tensor::WriteDataBHWDC(const float* in, CLCommandQueue* queue) {
void* data_ptr = nullptr;
const int aligned_channels = GetAlignedChannels();
const int elements_count =
shape_.b * shape_.w * shape_.h * shape_.d * aligned_channels;

const size_t data_size = elements_count * SizeOf(descriptor_.data_type);
std::unique_ptr<float[]> data_f;
std::unique_ptr<half[]> data_h;
std::unique_ptr<uint8_t[]> data_copy;
data_copy.reset(new uint8_t[data_size]);
if (descriptor_.data_type == DataType::FLOAT32) {
data_f.reset(new float[elements_count]);
data_ptr = data_f.get();
DataFromBHWDC(in, shape_, descriptor_, data_f.get());
DataFromBHWDC(in, shape_, descriptor_,
reinterpret_cast<float*>(data_copy.get()));
} else {
data_h.reset(new half[elements_count]);
data_ptr = data_h.get();
DataFromBHWDC(in, shape_, descriptor_, data_h.get());
DataFromBHWDC(in, shape_, descriptor_,
reinterpret_cast<half*>(data_copy.get()));
}

switch (descriptor_.storage_type) {
case TensorStorageType::BUFFER:
case TensorStorageType::IMAGE_BUFFER:
RETURN_IF_ERROR(queue->EnqueueWriteBuffer(memory_, data_size, data_ptr));
RETURN_IF_ERROR(
queue->EnqueueWriteBuffer(memory_, data_size, data_copy.get()));
break;
case TensorStorageType::TEXTURE_ARRAY:
case TensorStorageType::TEXTURE_2D:
case TensorStorageType::TEXTURE_3D:
case TensorStorageType::SINGLE_TEXTURE_2D:
RETURN_IF_ERROR(
queue->EnqueueWriteImage(memory_, GetFullTensorRegion(), data_ptr));
RETURN_IF_ERROR(queue->EnqueueWriteImage(memory_, GetFullTensorRegion(),
data_copy.get()));
break;
default:
return absl::InternalError("Unsupported tensor storage type");

@ -547,41 +548,36 @@ absl::Status Tensor::WriteData(CLCommandQueue* queue,
}

absl::Status Tensor::ReadDataBHWDC(float* out, CLCommandQueue* queue) const {
void* data_ptr = nullptr;
const int aligned_channels = GetAlignedChannels();
const int elements_count =
shape_.b * shape_.w * shape_.h * shape_.d * aligned_channels;
const size_t data_size = elements_count * SizeOf(descriptor_.data_type);
std::unique_ptr<float[]> data_f;
std::unique_ptr<half[]> data_h;
if (descriptor_.data_type == DataType::FLOAT32) {
data_f.reset(new float[elements_count]);
data_ptr = data_f.get();
} else {
data_h.reset(new half[elements_count]);
data_ptr = data_h.get();
}
std::unique_ptr<uint8_t[]> data_copy;
data_copy.reset(new uint8_t[data_size]);

switch (descriptor_.storage_type) {
case TensorStorageType::BUFFER:
case TensorStorageType::IMAGE_BUFFER:
RETURN_IF_ERROR(queue->EnqueueReadBuffer(memory_, data_size, data_ptr));
RETURN_IF_ERROR(
queue->EnqueueReadBuffer(memory_, data_size, data_copy.get()));
break;
case TensorStorageType::TEXTURE_ARRAY:
case TensorStorageType::TEXTURE_2D:
case TensorStorageType::TEXTURE_3D:
case TensorStorageType::SINGLE_TEXTURE_2D:
RETURN_IF_ERROR(
queue->EnqueueReadImage(memory_, GetFullTensorRegion(), data_ptr));
RETURN_IF_ERROR(queue->EnqueueReadImage(memory_, GetFullTensorRegion(),
data_copy.get()));
break;
default:
return absl::InternalError("Unsupported tensor storage type");
}

if (descriptor_.data_type == DataType::FLOAT32) {
DataToBHWDC(data_f.get(), shape_, descriptor_, out);
DataToBHWDC(reinterpret_cast<float*>(data_copy.get()), shape_, descriptor_,
out);
} else {
DataToBHWDC(data_h.get(), shape_, descriptor_, out);
DataToBHWDC(reinterpret_cast<half*>(data_copy.get()), shape_, descriptor_,
out);
}

return absl::OkStatus();
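The rewrite above replaces the separate float/half staging allocations with one type-erased uint8_t buffer that is reinterpreted per element type before a single enqueue call. A minimal sketch of that pattern outside the tensor class (all names hypothetical, not from the commit):

    #include <cstdint>
    #include <cstring>
    #include <memory>

    // One staging buffer sized in bytes, reinterpreted per element type.
    void StageAndCopy(const float* src, size_t elements, size_t element_size,
                      bool is_float32, void* destination) {
      const size_t bytes = elements * element_size;
      std::unique_ptr<uint8_t[]> staging(new uint8_t[bytes]);
      if (is_float32) {
        // The real code converts via DataFromBHWDC() into float...
        std::memcpy(staging.get(), src, bytes);
      } else {
        // ...or into half, writing through reinterpret_cast<half*>(staging.get()).
      }
      // A single copy/enqueue afterwards, independent of the element type.
      std::memcpy(destination, staging.get(), bytes);
    }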
@ -15,6 +15,8 @@ limitations under the License.

#include "tensorflow/lite/delegates/gpu/cl/texture2d.h"

#include "tensorflow/lite/delegates/gpu/cl/cl_image_format.h"

namespace tflite {
namespace gpu {
namespace cl {

@ -184,29 +184,6 @@ absl::Status CreateCLBuffer(cl_context context, int size_in_bytes,
return absl::OkStatus();
}

cl_channel_type DataTypeToChannelType(DataType type, bool normalized) {
switch (type) {
case DataType::FLOAT32:
return CL_FLOAT;
case DataType::FLOAT16:
return CL_HALF_FLOAT;
case DataType::INT8:
return normalized ? CL_SNORM_INT8 : CL_SIGNED_INT8;
case DataType::UINT8:
return normalized ? CL_UNORM_INT8 : CL_UNSIGNED_INT8;
case DataType::INT16:
return normalized ? CL_SNORM_INT16 : CL_SIGNED_INT16;
case DataType::UINT16:
return normalized ? CL_UNORM_INT16 : CL_UNSIGNED_INT16;
case DataType::INT32:
return CL_SIGNED_INT32;
case DataType::UINT32:
return CL_UNSIGNED_INT32;
default:
return CL_FLOAT;
}
}

absl::Status CreateRGBAImage2D(cl_context context, int width, int height,
cl_channel_type channel_type, void* data,
cl_mem* result) {

@ -52,7 +52,6 @@ void CopyLinearFLT4(const tflite::gpu::Tensor<Linear, S>& src,
absl::Status CreateCLBuffer(cl_context context, int size_in_bytes,
bool read_only, void* data, cl_mem* result);

cl_channel_type DataTypeToChannelType(DataType type, bool normalized = false);
absl::Status CreateRGBAImage2D(cl_context context, int width, int height,
cl_channel_type channel_type, void* data,
cl_mem* result);

@ -725,6 +725,7 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
case OperationType::COS:
case OperationType::ELU:
case OperationType::EXP:
case OperationType::FLOOR:
case OperationType::LOG:
case OperationType::NEG:
case OperationType::RSQRT:

@ -742,6 +743,8 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
bool IsTwoArgumentOperation() const {
switch (operation_type_) {
case OperationType::DIV:
case OperationType::FLOOR_DIV:
case OperationType::FLOOR_MOD:
case OperationType::MAXIMUM:
case OperationType::MINIMUM:
case OperationType::POW:

@ -756,6 +759,8 @@ class ElementwiseOperationParser : public TFLiteOperationParser {
bool IsTwoArgumentOperationWithConst() const {
switch (operation_type_) {
case OperationType::DIV:
case OperationType::FLOOR_DIV:
case OperationType::FLOOR_MOD:
case OperationType::MAXIMUM:
case OperationType::MINIMUM:
case OperationType::POW:

@ -2078,6 +2083,27 @@ class StridedSliceOperationParser : public TFLiteOperationParser {
}
};

class TileOperationParser : public TFLiteOperationParser {
public:
absl::Status IsSupported(const TfLiteContext* context,
const TfLiteNode* tflite_node,
const TfLiteRegistration* registration) final {
RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
/*runtime_inputs=*/1, /*outputs=*/1));
return absl::OkStatus();
}

absl::Status Parse(const TfLiteNode* tflite_node,
const TfLiteRegistration* registration,
GraphFloat32* graph, ObjectReader* reader) final {
Node* node = graph->NewNode();
node->operation.type = ToString(OperationType::TILE);
RETURN_IF_ERROR(reader->AddInput(node, 0));
RETURN_IF_ERROR(reader->AddOutputs(node));
return absl::OkStatus();
}
};

// Builtin op version of TRANSPOSE_CONV.
class TransposeConvBuiltinOperationParser : public TFLiteOperationParser {
public:

@ -2367,6 +2393,14 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
return std::make_unique<ElementwiseOperationParser>(OperationType::ELU);
case kTfLiteBuiltinExp:
return std::make_unique<ElementwiseOperationParser>(OperationType::EXP);
case kTfLiteBuiltinFloor:
return std::make_unique<ElementwiseOperationParser>(OperationType::FLOOR);
case kTfLiteBuiltinFloorDiv:
return std::make_unique<ElementwiseOperationParser>(
OperationType::FLOOR_DIV);
case kTfLiteBuiltinFloorMod:
return std::make_unique<ElementwiseOperationParser>(
OperationType::FLOOR_MOD);
case kTfLiteBuiltinFullyConnected:
return std::make_unique<FullyConnectedOperationParser>();
case kTfLiteBuiltinHardSwish:

@ -2456,6 +2490,8 @@ std::unique_ptr<TFLiteOperationParser> NewOperationParser(
return std::make_unique<ReduceOperationParser>(OperationType::REDUCE_SUM);
case kTfLiteBuiltinTanh:
return std::make_unique<ElementwiseOperationParser>(OperationType::TANH);
case kTfLiteBuiltinTile:
return std::make_unique<TileOperationParser>();
case kTfLiteBuiltinTranspose:
return std::make_unique<TransposeOperationParser>();
case kTfLiteBuiltinTransposeConv:

@ -987,9 +987,9 @@ int GetChannelsAlignment(const TensorDescriptor& desc, const BHWDC& shape) {
}
} // namespace

template <typename T>
void DataFromBHWDC(const float* src, const BHWDC& shape,
const TensorDescriptor& desc, T* dst) {
template <typename FromType, typename ToType>
void DataFromBHWDC(const FromType* src, const BHWDC& shape,
const TensorDescriptor& desc, ToType* dst) {
const int channels_alignment = GetChannelsAlignment(desc, shape);
const int slices = DivideRoundUp(shape.c, 4);
for (int b = 0; b < shape.b; ++b) {

@ -998,13 +998,13 @@ void DataFromBHWDC(const float* src, const BHWDC& shape,
for (int x = 0; x < shape.w; ++x) {
for (int d = 0; d < shape.d; ++d) {
for (int c = 0; c < channels_alignment; ++c) {
float value;
FromType value;
if (s * 4 + c < shape.c) {
const int cpu_index =
shape.LinearIndex({b, y, x, d, s * 4 + c});
value = src[cpu_index];
} else {
value = 0.0f;
value = 0;
}
int gpu_index = GetLinearIndex(desc, shape, b, x, y, d, s, c);
dst[gpu_index] = value;

@ -1016,14 +1016,16 @@ void DataFromBHWDC(const float* src, const BHWDC& shape,
}
}

template void DataFromBHWDC<float>(const float* src, const BHWDC& shape,
const TensorDescriptor& desc, float* dst);
template void DataFromBHWDC<half>(const float* src, const BHWDC& shape,
const TensorDescriptor& desc, half* dst);
template void DataFromBHWDC<float, float>(const float* src, const BHWDC& shape,
const TensorDescriptor& desc,
float* dst);
template void DataFromBHWDC<float, half>(const float* src, const BHWDC& shape,
const TensorDescriptor& desc,
half* dst);

template <typename T>
void DataToBHWDC(const T* src, const BHWDC& shape, const TensorDescriptor& desc,
float* dst) {
template <typename FromType, typename ToType>
void DataToBHWDC(const FromType* src, const BHWDC& shape,
const TensorDescriptor& desc, ToType* dst) {
const int channels_alignment = GetChannelsAlignment(desc, shape);
const int slices = DivideRoundUp(shape.c, 4);
for (int b = 0; b < shape.b; ++b) {

@ -1046,10 +1048,12 @@ void DataToBHWDC(const T* src, const BHWDC& shape, const TensorDescriptor& desc,
}
}

template void DataToBHWDC<float>(const float* src, const BHWDC& shape,
const TensorDescriptor& desc, float* dst);
template void DataToBHWDC<half>(const half* src, const BHWDC& shape,
const TensorDescriptor& desc, float* dst);
template void DataToBHWDC<float, float>(const float* src, const BHWDC& shape,
const TensorDescriptor& desc,
float* dst);
template void DataToBHWDC<half, float>(const half* src, const BHWDC& shape,
const TensorDescriptor& desc,
float* dst);

} // namespace gpu
} // namespace tflite

@ -186,13 +186,13 @@ struct TensorDescriptor : public GPUObjectDescriptor {
void UploadData(const float* src);
};

template <typename T>
void DataFromBHWDC(const float* src, const BHWDC& shape,
const TensorDescriptor& desc, T* dst);
template <typename FromType, typename ToType>
void DataFromBHWDC(const FromType* src, const BHWDC& shape,
const TensorDescriptor& desc, ToType* dst);

template <typename T>
void DataToBHWDC(const T* src, const BHWDC& shape, const TensorDescriptor& desc,
float* dst);
template <typename FromType, typename ToType>
void DataToBHWDC(const FromType* src, const BHWDC& shape,
const TensorDescriptor& desc, ToType* dst);

std::string ToString(TensorStorageType type);

@ -149,6 +149,9 @@ class Convolution : public NodeShader {
int SelectMultiplier(int32_t input_width,
const NodeShader::GenerationContext& ctx) {
std::vector<int> multipliers = {4, 2};
if (ctx.gpu_info->IsAMD()) {
return 1;
}
if (!ctx.compiler_options.allow_precision_loss && ctx.gpu_info->IsMali()) {
multipliers = {2};
}

@ -381,37 +381,34 @@ int MetalSpatialTensor::GetAlignedChannels() const {

absl::Status MetalSpatialTensor::WriteDataBHWDC(id<MTLDevice> device,
const float* in) {
void* data_ptr = nullptr;
const int aligned_channels = GetAlignedChannels();
const int elements_count =
shape_.b * shape_.w * shape_.h * shape_.d * aligned_channels;

const size_t data_size = elements_count * SizeOf(descriptor_.data_type);
std::unique_ptr<float[]> data_f;
std::unique_ptr<half[]> data_h;
std::unique_ptr<uint8_t[]> data_copy;
data_copy.reset(new uint8_t[data_size]);
if (descriptor_.data_type == DataType::FLOAT32) {
data_f.reset(new float[elements_count]);
data_ptr = data_f.get();
DataFromBHWDC(in, shape_, descriptor_, data_f.get());
DataFromBHWDC(in, shape_, descriptor_,
reinterpret_cast<float*>(data_copy.get()));
} else {
data_h.reset(new half[elements_count]);
data_ptr = data_h.get();
DataFromBHWDC(in, shape_, descriptor_, data_h.get());
DataFromBHWDC(in, shape_, descriptor_,
reinterpret_cast<half*>(data_copy.get()));
}

switch (descriptor_.storage_type) {
case TensorStorageType::BUFFER:
case TensorStorageType::IMAGE_BUFFER:
std::memcpy([memory_ contents], data_ptr, data_size);
std::memcpy([memory_ contents], data_copy.get(), data_size);
break;
case TensorStorageType::TEXTURE_2D:
WriteDataToTexture2D(texture_mem_, device, data_ptr);
WriteDataToTexture2D(texture_mem_, device, data_copy.get());
break;
case TensorStorageType::TEXTURE_3D:
WriteDataToTexture3D(texture_mem_, device, data_ptr);
WriteDataToTexture3D(texture_mem_, device, data_copy.get());
break;
case TensorStorageType::TEXTURE_ARRAY:
WriteDataToTexture2DArray(texture_mem_, device, data_ptr);
WriteDataToTexture2DArray(texture_mem_, device, data_copy.get());
break;
case TensorStorageType::SINGLE_TEXTURE_2D:
default:

@ -447,34 +444,26 @@ absl::Status MetalSpatialTensor::WriteData(id<MTLDevice> device,

absl::Status MetalSpatialTensor::ReadDataBHWDC(id<MTLDevice> device,
float* out) const {
void* data_ptr = nullptr;
const int aligned_channels = GetAlignedChannels();
const int elements_count =
shape_.b * shape_.w * shape_.h * shape_.d * aligned_channels;
const size_t data_size = elements_count * SizeOf(descriptor_.data_type);
std::unique_ptr<float[]> data_f;
std::unique_ptr<half[]> data_h;
if (descriptor_.data_type == DataType::FLOAT32) {
data_f.reset(new float[elements_count]);
data_ptr = data_f.get();
} else {
data_h.reset(new half[elements_count]);
data_ptr = data_h.get();
}
std::unique_ptr<uint8_t[]> data_copy;
data_copy.reset(new uint8_t[data_size]);

switch (descriptor_.storage_type) {
case TensorStorageType::BUFFER:
case TensorStorageType::IMAGE_BUFFER:
std::memcpy(data_ptr, [memory_ contents], data_size);
std::memcpy(data_copy.get(), [memory_ contents], data_size);
break;
case TensorStorageType::TEXTURE_2D:
ReadDataFromTexture2D(texture_mem_, device, data_ptr);
ReadDataFromTexture2D(texture_mem_, device, data_copy.get());
break;
case TensorStorageType::TEXTURE_3D:
ReadDataFromTexture3D(texture_mem_, device, data_ptr);
ReadDataFromTexture3D(texture_mem_, device, data_copy.get());
break;
case TensorStorageType::TEXTURE_ARRAY:
ReadDataFromTexture2DArray(texture_mem_, device, data_ptr);
ReadDataFromTexture2DArray(texture_mem_, device, data_copy.get());
break;
case TensorStorageType::SINGLE_TEXTURE_2D:
default:

@ -482,9 +471,11 @@ absl::Status MetalSpatialTensor::ReadDataBHWDC(id<MTLDevice> device,
}

if (descriptor_.data_type == DataType::FLOAT32) {
DataToBHWDC(data_f.get(), shape_, descriptor_, out);
DataToBHWDC(reinterpret_cast<float*>(data_copy.get()), shape_, descriptor_,
out);
} else {
DataToBHWDC(data_h.get(), shape_, descriptor_, out);
DataToBHWDC(reinterpret_cast<half*>(data_copy.get()), shape_, descriptor_,
out);
}

return absl::OkStatus();

@ -130,8 +130,8 @@ class StatefulNnApiDelegate : public TfLiteDelegate {
// Burst mode allows accelerators to efficiently manage resources, which
// would significantly reduce overhead especially if the same delegate
// instance is to be used for multiple inferences.
// Default: Enabled.
bool use_burst_computation = true;
// Default: Disabled.
bool use_burst_computation = false;
};

// Uses default options.

@ -259,7 +259,7 @@ class StatefulNnApiDelegate : public TfLiteDelegate {
// Whether to allow dynamic dimension sizes without re-compilation.
bool allow_dynamic_dimensions = false;
// Whether to use NNAPI Burst mode.
bool use_burst_computation = true;
bool use_burst_computation = false;

explicit Data(const NnApi* nnapi);
~Data();
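Because the default flips to disabled above, applications that still want NNAPI burst execution now have to opt in through the delegate options. A sketch of what that opt-in would look like with the usual StatefulNnApiDelegate setup (the surrounding interpreter wiring is assumed, not shown):

    #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"

    // Explicitly re-enable burst mode now that the default is false.
    tflite::StatefulNnApiDelegate::Options options;
    options.use_burst_computation = true;
    tflite::StatefulNnApiDelegate delegate(options);
    // Hand &delegate to the interpreter via ModifyGraphWithDelegate() as usual.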
@ -16,6 +16,7 @@

import csv
import io
import re

from unittest import mock
from absl.testing import parameterized

@ -143,11 +144,15 @@ class QuantizationDebuggerTest(test_util.TensorFlowTestCase,
'tensor_idx': 7 if quantized_io else 8,
'scales': [0.15686275],
'zero_points': [-128],
'tensor_name': 'Identity' if quantized_io else 'Identity4'
'tensor_name': r'Identity[1-9]?$'
})
for key, value in expected_values.items():
if isinstance(value, str):
self.assertEqual(value, actual_values[key])
self.assertIsNotNone(
re.match(value, actual_values[key]),
'String is different from expected string. Please fix test code if'
" it's being affected by graph manipulation changes."
)
elif isinstance(value, list):
self.assertAlmostEqual(
value[0], float(actual_values[key][1:-1]), places=5)

@ -1093,6 +1093,7 @@ cc_test(
name = "conv_mem_test",
srcs = ["conv_mem_test.cc"],
data = [
"//tensorflow/lite:testdata/conv3d_huge_im2col.bin",
"//tensorflow/lite:testdata/conv_huge_im2col.bin",
],
tags = [

@ -1104,7 +1105,6 @@ cc_test(
"//tensorflow/lite:framework",
"//tensorflow/lite/c:common",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/profiling:memory_info",
"//tensorflow/lite/tools:logging",
"@com_google_googletest//:gtest_main",
],

@ -21,33 +21,104 @@ limitations under the License.

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace conv3d {

enum KernelType {
kReference,
kGenericOptimized,
};

// Struct to carry data from Prepare to Eval.
const int kTensorNotAllocated = -1;
static constexpr size_t kMaxIm2colBufferSizeMobile = 1024 * 1024 * 1024;  // 1GB

struct OpData {
Padding3DValues padding;
int im2col_tensor_id = kTensorNotAllocated;
int transposed_filter_tensor_id = kTensorNotAllocated;

bool need_im2col = false;
bool need_transposed_filter = false;

// Disable im2col if the temporary im2col tensor requires too much memory
// (i.e. >= kMaxIm2colBufferSizeMobile).
bool im2col_oversized = false;

int32_t im2col_index;
int32_t transposed_filter_index;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* data = new OpData;
return data;
auto* opdata = new OpData;
return opdata;
}

void Free(TfLiteContext* context, void* buffer) {
delete static_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus AllocateTemporaryTensorsIfRequired(
KernelType kernel_type, TfLiteContext* context, TfLiteNode* node,
OpData* opdata, TfLiteConv3DParams* params, const TfLiteTensor* filter,
size_t im2col_bytes) {
int temporaries_count = 0;
const bool need_dilated_im2col = params->dilation_width_factor != 1 ||
params->dilation_height_factor != 1 ||
params->dilation_depth_factor != 1;
const bool need_non_dilated_im2col =
params->stride_depth != 1 || params->stride_width != 1 ||
params->stride_height != 1 || filter->dims->data[2] != 1 ||
filter->dims->data[1] != 1 || filter->dims->data[0] != 1;

opdata->need_im2col = (kernel_type == kGenericOptimized) &&
(need_dilated_im2col || need_non_dilated_im2col);
// TODO(b/183455632): Add transposing logic in converter so constant folding
// might work on constant filter tensor.
opdata->need_transposed_filter = (kernel_type == kGenericOptimized);

// On mobile platforms, the generic optimized kernel will not be used if the
// temporary im2col tensor requires too much memory.
if (IsMobilePlatform() && opdata->need_im2col &&
im2col_bytes >= kMaxIm2colBufferSizeMobile) {
opdata->need_im2col = false;
opdata->need_transposed_filter = false;
opdata->im2col_oversized = true;
}

if (opdata->need_im2col && opdata->im2col_tensor_id == kTensorNotAllocated) {
TF_LITE_ENSURE_OK(
context, context->AddTensors(context, 1, &opdata->im2col_tensor_id));
opdata->im2col_index = temporaries_count++;
}

if (opdata->need_transposed_filter &&
opdata->transposed_filter_tensor_id == kTensorNotAllocated) {
TF_LITE_ENSURE_OK(
context,
context->AddTensors(context, 1, &opdata->transposed_filter_tensor_id));
opdata->transposed_filter_index = temporaries_count++;
}

TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(temporaries_count);
return kTfLiteOk;
}

TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
TfLiteNode* node) {
auto* params = static_cast<TfLiteConv3DParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
OpData* opdata = reinterpret_cast<OpData*>(node->user_data);

// Check number of inputs/outputs.
TF_LITE_ENSURE(context, node->inputs->size == 2 || node->inputs->size == 3);

@ -89,10 +160,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int filter_depth = filter->dims->data[0];
int filter_height = filter->dims->data[1];
int filter_width = filter->dims->data[2];
int input_channel = filter->dims->data[3];

// Matching GetWindowedOutputSize in TensorFlow.
int out_width, out_height, out_depth;
data->padding = ComputePadding3DValues(
opdata->padding = ComputePadding3DValues(
params->stride_height, params->stride_width, params->stride_depth,
params->dilation_height_factor, params->dilation_width_factor,
params->dilation_depth_factor, height, width, depth, filter_height,

@ -105,12 +177,111 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
output_size->data[2] = out_height;
output_size->data[3] = out_width;
output_size->data[4] = channels_out;
return context->ResizeTensor(context, output, output_size);
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size));

// Allocate temporary tensors.
size_t input_type_size;
TF_LITE_ENSURE_STATUS(GetSizeOfType(context, input->type, &input_type_size));
const size_t im2col_bytes = batches * out_depth * out_height * out_width *
input_channel * filter_depth * filter_height *
filter_width * input_type_size;
TF_LITE_ENSURE_OK(context, AllocateTemporaryTensorsIfRequired(
kernel_type, context, node, opdata, params,
filter, im2col_bytes));

if (opdata->need_im2col) {
TfLiteIntArray* im2col_size = TfLiteIntArrayCreate(5);
im2col_size->data[0] = output_size->data[0];
im2col_size->data[1] = output_size->data[1];
im2col_size->data[2] = output_size->data[2];
im2col_size->data[3] = output_size->data[3];
im2col_size->data[4] =
input_channel * filter_depth * filter_height * filter_width;

TfLiteTensor* im2col;
node->temporaries->data[opdata->im2col_index] = opdata->im2col_tensor_id;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node,
opdata->im2col_index, &im2col));
im2col->type = input->type;
im2col->allocation_type = kTfLiteArenaRw;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, im2col, im2col_size));
}

if (opdata->need_transposed_filter) {
TfLiteIntArray* transposed_filter_size = TfLiteIntArrayCreate(5);
transposed_filter_size->data[0] = filter->dims->data[4];
transposed_filter_size->data[1] = filter->dims->data[0];
transposed_filter_size->data[2] = filter->dims->data[1];
transposed_filter_size->data[3] = filter->dims->data[2];
transposed_filter_size->data[4] = filter->dims->data[3];

TfLiteTensor* transposed_filter;
node->temporaries->data[opdata->transposed_filter_index] =
opdata->transposed_filter_tensor_id;
TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node,
opdata->transposed_filter_index,
&transposed_filter));
transposed_filter->type = filter->type;
transposed_filter->allocation_type = kTfLiteArenaRw;
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, transposed_filter,
transposed_filter_size));
}
return kTfLiteOk;
}
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
template <KernelType kernel_type>
|
||||
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
||||
return Prepare(kernel_type, context, node);
|
||||
}
|
||||
|
||||
void EvalFloat(KernelType kernel_type, TfLiteContext* context, TfLiteNode* node,
|
||||
TfLiteConv3DParams* params, OpData* opdata,
|
||||
const TfLiteTensor* input, const TfLiteTensor* filter,
|
||||
const TfLiteTensor* bias, TfLiteTensor* im2col,
|
||||
TfLiteTensor* transposed_filter, TfLiteTensor* output) {
|
||||
float output_activation_min, output_activation_max;
|
||||
CalculateActivationRange(params->activation, &output_activation_min,
|
||||
&output_activation_max);
|
||||
|
||||
Conv3DParams runtime_params;
|
||||
runtime_params.padding_values = opdata->padding;
|
||||
runtime_params.stride_depth = params->stride_depth;
|
||||
runtime_params.stride_height = params->stride_height;
|
||||
runtime_params.stride_width = params->stride_width;
|
||||
runtime_params.dilation_depth = params->dilation_depth_factor;
|
||||
runtime_params.dilation_height = params->dilation_height_factor;
|
||||
runtime_params.dilation_width = params->dilation_width_factor;
|
||||
runtime_params.float_activation_min = output_activation_min;
|
||||
runtime_params.float_activation_max = output_activation_max;
|
||||
switch (kernel_type) {
|
||||
case kReference: {
|
||||
reference_ops::Conv3D(runtime_params, GetTensorShape(input),
|
||||
GetTensorData<float>(input), GetTensorShape(filter),
|
||||
GetTensorData<float>(filter), GetTensorShape(bias),
|
||||
GetTensorData<float>(bias), GetTensorShape(output),
|
||||
GetTensorData<float>(output));
|
||||
break;
|
||||
}
|
||||
case kGenericOptimized: {
|
||||
optimized_ops::Conv3D(
|
||||
runtime_params, GetTensorShape(input), GetTensorData<float>(input),
|
||||
GetTensorShape(filter), GetTensorData<float>(filter),
|
||||
GetTensorShape(bias), GetTensorData<float>(bias),
|
||||
GetTensorShape(output), GetTensorData<float>(output),
|
||||
GetTensorShape(im2col), GetTensorData<float>(im2col),
|
||||
GetTensorShape(transposed_filter),
GetTensorData<float>(transposed_filter),
|
||||
CpuBackendContext::GetFromContext(context));
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
TfLiteStatus Eval(KernelType kernel_type, TfLiteContext* context,
|
||||
TfLiteNode* node) {
|
||||
auto* params = reinterpret_cast<TfLiteConv3DParams*>(node->builtin_data);
|
||||
OpData* data = reinterpret_cast<OpData*>(node->user_data);
|
||||
OpData* opdata = reinterpret_cast<OpData*>(node->user_data);
|
||||
|
||||
TfLiteTensor* output;
|
||||
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
|
||||
|
|
@ -120,28 +291,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));
|
||||
const TfLiteTensor* bias = GetInput(context, node, 2);
|
||||
|
||||
float output_activation_min, output_activation_max;
|
||||
CalculateActivationRange(params->activation, &output_activation_min,
|
||||
&output_activation_max);
|
||||
TfLiteTensor* im2col = opdata->need_im2col
|
||||
? &context->tensors[opdata->im2col_tensor_id]
|
||||
: nullptr;
|
||||
TfLiteTensor* transposed_filter =
|
||||
opdata->need_transposed_filter
|
||||
? &context->tensors[opdata->transposed_filter_tensor_id]
|
||||
: nullptr;
|
||||
|
||||
Conv3DParams runtime_params;
|
||||
runtime_params.padding_values = data->padding;
|
||||
runtime_params.stride_depth = params->stride_depth;
|
||||
runtime_params.stride_height = params->stride_height;
|
||||
runtime_params.stride_width = params->stride_width;
|
||||
runtime_params.dilation_depth = params->dilation_depth_factor;
|
||||
runtime_params.dilation_height = params->dilation_height_factor;
|
||||
runtime_params.dilation_width = params->dilation_width_factor;
|
||||
runtime_params.float_activation_min = output_activation_min;
|
||||
runtime_params.float_activation_max = output_activation_max;
|
||||
// Fall back to the reference execution path when im2col is needed but disabled.
|
||||
if (opdata->im2col_oversized) {
|
||||
kernel_type = kReference;
|
||||
}
|
||||
|
||||
switch (input->type) {
|
||||
case kTfLiteFloat32:
|
||||
reference_ops::Conv3D(runtime_params, GetTensorShape(input),
|
||||
GetTensorData<float>(input), GetTensorShape(filter),
|
||||
GetTensorData<float>(filter), GetTensorShape(bias),
|
||||
GetTensorData<float>(bias), GetTensorShape(output),
|
||||
GetTensorData<float>(output));
|
||||
EvalFloat(kernel_type, context, node, params, opdata, input, filter, bias,
|
||||
im2col, transposed_filter, output);
|
||||
break;
|
||||
default:
|
||||
TF_LITE_KERNEL_LOG(context, "Type %s currently not supported.",
|
||||
|
|
@ -151,14 +317,31 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
template <KernelType kernel_type>
|
||||
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
||||
return Eval(kernel_type, context, node);
|
||||
}
|
||||
|
||||
} // namespace conv3d
|
||||
|
||||
TfLiteRegistration* Register_CONV_3D() {
|
||||
static TfLiteRegistration r = {conv3d::Init, conv3d::Free, conv3d::Prepare,
|
||||
conv3d::Eval};
|
||||
TfLiteRegistration* Register_CONV_3D_REF() {
|
||||
static TfLiteRegistration r = {conv3d::Init, conv3d::Free,
|
||||
conv3d::Prepare<conv3d::kReference>,
|
||||
conv3d::Eval<conv3d::kReference>};
|
||||
return &r;
|
||||
}
|
||||
|
||||
TfLiteRegistration* Register_CONV_3D_GENERIC_OPT() {
|
||||
static TfLiteRegistration r = {conv3d::Init, conv3d::Free,
|
||||
conv3d::Prepare<conv3d::kGenericOptimized>,
|
||||
conv3d::Eval<conv3d::kGenericOptimized>};
|
||||
return &r;
|
||||
}
|
||||
|
||||
TfLiteRegistration* Register_CONV_3D() {
|
||||
return Register_CONV_3D_GENERIC_OPT();
|
||||
}
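
As a usage sketch (not part of this change), an application that wants the reference Conv3D path regardless of platform could register it explicitly through a MutableOpResolver; the builder call below assumes a model that is already loaded.

#include <memory>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/builtin_op_kernels.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/mutable_op_resolver.h"

std::unique_ptr<tflite::Interpreter> BuildWithReferenceConv3D(
    const tflite::FlatBufferModel& model) {
  tflite::MutableOpResolver resolver;
  // Route CONV_3D to the reference kernel registered above instead of the
  // default generic optimized variant.
  resolver.AddBuiltin(tflite::BuiltinOperator_CONV_3D,
                      tflite::ops::builtin::Register_CONV_3D_REF());
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(model, resolver)(&interpreter);
  return interpreter;
}
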
|
||||
|
||||
} // namespace builtin
|
||||
} // namespace ops
|
||||
} // namespace tflite
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ class Conv3dOpModel : public SingleOpModel {
|
|||
BuildInterpreter({GetShape(input_), GetShape(filter_)});
|
||||
}
|
||||
|
||||
void SetFilter(std::initializer_list<float> f) { PopulateTensor(filter_, f); }
|
||||
void SetFilter(std::vector<float> f) { PopulateTensor(filter_, f); }
|
||||
|
||||
void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
|
||||
|
||||
|
|
@ -225,13 +225,13 @@ TEST(Conv3dOpModel, DilationTest) {
|
|||
/*dilation_height=*/2);
|
||||
|
||||
m.SetInput(CreateRangeVector<float>(96));
|
||||
m.SetFilter({1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1,
|
||||
1, -1, 1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, 1});
|
||||
m.SetFilter(CreateRangeVector<float>(32));
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAre(2, 1, 1, 3, 2));
|
||||
EXPECT_THAT(m.GetOutput(),
|
||||
ElementsAreArray({52, 8, 60, 8, 68, 8, 244, 8, 252, 8, 260, 8}));
|
||||
ElementsAreArray({7248, 7592, 7728, 8104, 8208, 8616, 18768,
|
||||
19880, 19248, 20392, 19728, 20904}));
|
||||
}
|
||||
|
||||
TEST(Conv3dOpModel, BiasTest) {
|
||||
|
|
@ -252,5 +252,22 @@ TEST(Conv3dOpModel, BiasTest) {
|
|||
ElementsAreArray({53, 10, 69, 10, 245, 10, 261, 10}));
|
||||
}
|
||||
|
||||
TEST(Conv3dOpModel, NoIm2ColTensorTest) {
|
||||
Conv3dOpModel m({TensorType_FLOAT32, {1, 2, 2, 2, 4}},
|
||||
{TensorType_FLOAT32, {1, 1, 1, 4, 4}},
|
||||
{TensorType_FLOAT32, {}}, Padding_VALID);
|
||||
|
||||
m.SetInput(CreateRangeVector<float>(32));
|
||||
m.SetFilter(CreateRangeVector<float>(16));
|
||||
m.Invoke();
|
||||
|
||||
EXPECT_THAT(m.GetOutputShape(), ElementsAre(1, 2, 2, 2, 4));
|
||||
EXPECT_THAT(
|
||||
m.GetOutput(),
|
||||
ElementsAreArray({56, 62, 68, 74, 152, 174, 196, 218, 248, 286, 324,
|
||||
362, 344, 398, 452, 506, 440, 510, 580, 650, 536, 622,
|
||||
708, 794, 632, 734, 836, 938, 728, 846, 964, 1082}));
|
||||
}
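
A brief sketch of why no im2col temporary is expected here: with a 1x1x1 filter and unit strides (and no dilation), the optimized path can feed the input to the GEMM directly. The helper below just restates the predicate used by the optimized kernel with this test's shapes.

bool NeedIm2colForThisTest() {
  // Shapes taken from the test above: 1x1x1 filter, default strides of 1.
  const int stride_d = 1, stride_h = 1, stride_w = 1;
  const int filter_d = 1, filter_h = 1, filter_w = 1;
  return stride_d != 1 || stride_h != 1 || stride_w != 1 || filter_d != 1 ||
         filter_h != 1 || filter_w != 1;  // false, so no im2col buffer
}
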
|
||||
|
||||
} // namespace
|
||||
} // namespace tflite
|
||||
|
|
|
|||
|
|
@ -20,55 +20,60 @@ limitations under the License.
|
|||
#include "tensorflow/lite/kernels/kernel_util.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model_builder.h"
|
||||
#include "tensorflow/lite/profiling/memory_info.h"
|
||||
#include "tensorflow/lite/tools/logging.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
TEST(ConvMemUsage, HugeIm2ColData) {
|
||||
// As the test validates memory usage, skip if unsupported.
|
||||
if (!profiling::memory::MemoryUsage::IsSupported()) {
|
||||
return;
|
||||
}
|
||||
|
||||
void TestMemoryThreshold(const std::string& model_path,
|
||||
size_t threshold_in_kb) {
|
||||
// The Im2Col optimization is only applied on mobile platforms, so only
|
||||
// validate on such platforms.
|
||||
if (!IsMobilePlatform()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// The model has a conv op will require a temporary tensor of ~3.5GB if
|
||||
// The model has a conv op that will require a huge temporary tensor if
|
||||
// im2col is performed and it's possible to cause OOM on devices. To prevent
|
||||
// this from happening, a size cap (i.e. kMaxIm2colBufferSizeMobile) of
|
||||
// to-be-allocated im2col data is used to determine whether to disable im2col.
|
||||
// This test will check the memory footprint before/after interpreter Invoke
|
||||
// to ensure the size cap is correctly enforced on mobile platforms.
|
||||
auto model = FlatBufferModel::BuildFromFile(
|
||||
"tensorflow/lite/testdata/conv_huge_im2col.bin");
|
||||
// to-be-allocated im2col data is used to determine whether to disable
|
||||
// im2col. This test will check the memory footprint before/after
|
||||
// interpreter Invoke to ensure the size cap is correctly enforced on mobile
|
||||
// platforms.
|
||||
auto model = FlatBufferModel::BuildFromFile(model_path.c_str());
|
||||
ASSERT_TRUE(model);
|
||||
|
||||
const auto mem_before = profiling::memory::GetMemoryUsage();
|
||||
std::unique_ptr<Interpreter> interpreter;
|
||||
|
||||
// Note that we explicitly set 1 thread here to avoid extra memory footprint
|
||||
// caused by multithreading, which will make the memory usage threshold check
|
||||
// later more reliable.
|
||||
// caused by multithreading, which will make the memory usage threshold
|
||||
// check later more reliable.
|
||||
ASSERT_EQ(InterpreterBuilder(*model, ops::builtin::BuiltinOpResolver())(
|
||||
&interpreter, /*num_threads*/ 1),
|
||||
kTfLiteOk);
|
||||
ASSERT_TRUE(interpreter);
|
||||
ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
|
||||
// Note that we skip invocation on such a large model as it can be
|
||||
// prohibitively slow for tests.
|
||||
const auto mem_after = profiling::memory::GetMemoryUsage();
|
||||
TFLITE_LOG(INFO) << "HugeIm2ColData Memory usage info: "
|
||||
<< mem_after - mem_before;
|
||||
|
||||
// The 3GB threshold is still < 3.5GB, chosen to suit different testing
|
||||
// configurations, such as MSan/TSan related tests where extra system-level
|
||||
// memory footprint usage might be counted as well. Note that the im2col
|
||||
// buffer limit is only applied *only on mobile platforms*.
|
||||
EXPECT_LE((mem_after - mem_before).max_rss_kb, 3 * 1024 * 1024);
|
||||
// Memory required for all tensors should be smaller than the threshold.
|
||||
int64_t accumulate_tensor_memory = 0;
|
||||
for (int i = 0; i < interpreter->tensors_size(); ++i) {
|
||||
accumulate_tensor_memory += interpreter->tensor(i)->bytes;
|
||||
}
|
||||
EXPECT_LE(accumulate_tensor_memory, threshold_in_kb * 1024);
|
||||
}
|
||||
|
||||
TEST(ConvMemUsage, HugeIm2ColData) {
|
||||
TestMemoryThreshold(
|
||||
// The model has a conv op that will require a temporary tensor of ~3.5GB if
|
||||
// im2col is performed.
|
||||
"tensorflow/lite/testdata/conv_huge_im2col.bin",
|
||||
/*threshold_in_kb=*/3 * 1024 * 1024);
|
||||
}
|
||||
|
||||
TEST(Conv3DMemUsage, HugeIm2ColData) {
|
||||
TestMemoryThreshold(
|
||||
// The model has a Conv3D op that will require a temporary tensor of ~1.3GB
// if im2col is performed. If not, it will use about 450MB.
|
||||
"tensorflow/lite/testdata/conv3d_huge_im2col.bin",
|
||||
/*threshold_in_kb=*/1 * 1024 * 1024);
|
||||
}
|
||||
|
||||
} // namespace tflite
|
||||
|
|
|
|||
|
|
@ -282,6 +282,229 @@ void Im2col(const ConvParams& params, int kheight, int kwidth,
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void ExtractPatchIntoBufferColumn3D(
|
||||
int b, int d, int h, int w, // Output indexes.
|
||||
int kdepth, int kheight, int kwidth, // Kernel params.
|
||||
int stride_depth, int stride_height, int stride_width, // Stride params.
|
||||
int pad_depth, int pad_height, int pad_width, // Padding params.
|
||||
int in_depth, int in_height, int in_width, int in_channel, // Input shape.
|
||||
int output_row_offset, const T* in_data, T* conv_buffer_data,
|
||||
uint8 zero_byte) {
|
||||
ruy::profiler::ScopeLabel label("ExtractPatchIntoBufferColumn3D");
|
||||
|
||||
// This chunk of code reshapes all the inputs corresponding to
|
||||
// output (b, d, h, w) to a column vector in conv_buffer(:, buffer_id).
|
||||
const int id_ungated_start = d * stride_depth - pad_depth;
|
||||
const int id_start = std::max(0, id_ungated_start);
|
||||
const int id_ungated_end = (id_ungated_start + kdepth);
|
||||
const int id_end = std::min(id_ungated_end, in_depth);
|
||||
|
||||
const int ih_ungated_start = h * stride_height - pad_height;
|
||||
const int ih_start = std::max(0, ih_ungated_start);
|
||||
const int ih_ungated_end = (ih_ungated_start + kheight);
|
||||
const int ih_end = std::min(ih_ungated_end, in_height);
|
||||
|
||||
const int iw_ungated_start = w * stride_width - pad_width;
|
||||
const int iw_start = std::max(0, iw_ungated_start);
|
||||
const int iw_ungated_end = (iw_ungated_start + kwidth);
|
||||
const int iw_end = std::min(iw_ungated_end, in_width);
|
||||
|
||||
// Calculate the padding sizes.
|
||||
const int d_padding_before = std::max(0, -id_ungated_start);
|
||||
const int d_padding_after = (id_ungated_end - id_end);
|
||||
const int h_padding_before = std::max(0, -ih_ungated_start);
|
||||
const int h_padding_after = (ih_ungated_end - ih_end);
|
||||
const int w_padding_before = std::max(0, -iw_ungated_start);
|
||||
const int w_padding_after = (iw_ungated_end - iw_end);
|
||||
|
||||
// Memset if there are paddings in the depth dimension.
|
||||
const int kd_stride_size = kheight * kwidth * in_channel;
|
||||
const int id_stride_size = in_height * in_width * in_channel;
|
||||
|
||||
if (d_padding_before > 0) {
|
||||
const int d_padding_before_elements = (d_padding_before * kd_stride_size);
|
||||
memset(conv_buffer_data + output_row_offset, zero_byte,
|
||||
(d_padding_before_elements * sizeof(T)));
|
||||
}
|
||||
|
||||
if (d_padding_after > 0) {
|
||||
const int d_padding_after_elements = (d_padding_after * kd_stride_size);
|
||||
const int bottom_start =
|
||||
output_row_offset + (kdepth - d_padding_after) * kd_stride_size;
|
||||
memset(conv_buffer_data + bottom_start, zero_byte,
|
||||
(d_padding_after_elements * sizeof(T)));
|
||||
}
|
||||
|
||||
// If there are paddings in the height or width dimensions, memset the entire area
|
||||
// to take advantage of sequential memory handling performance.
|
||||
int out_offset = output_row_offset + d_padding_before * kd_stride_size;
|
||||
if (h_padding_before > 0 || h_padding_after > 0 || w_padding_before > 0 ||
|
||||
w_padding_after > 0) {
|
||||
const int middle_elements = (id_end - id_start) * kd_stride_size;
|
||||
memset(conv_buffer_data + out_offset, zero_byte,
|
||||
(middle_elements * sizeof(T)));
|
||||
}
|
||||
|
||||
// Copy the valid data from the input tensor.
|
||||
const int kh_stride_size = kwidth * in_channel;
|
||||
const int ih_stride_size = in_width * in_channel;
|
||||
const int h_padding = h_padding_before + h_padding_after;
|
||||
const int w_padding = w_padding_before + w_padding_after;
|
||||
const int single_row_num = (kwidth - w_padding) * in_channel;
|
||||
out_offset +=
|
||||
h_padding_before * kh_stride_size + w_padding_before * in_channel;
|
||||
const int in_offset_without_d = b * in_depth * id_stride_size +
|
||||
ih_start * ih_stride_size +
|
||||
iw_start * in_channel;
|
||||
for (int id = id_start; id < id_end; ++id) {
|
||||
int in_offset = in_offset_without_d + id * id_stride_size;
|
||||
for (int ih = ih_start; ih < ih_end; ++ih) {
|
||||
memcpy(conv_buffer_data + out_offset, in_data + in_offset,
|
||||
single_row_num * sizeof(T));
|
||||
out_offset += kh_stride_size;
|
||||
in_offset += ih_stride_size;
|
||||
}
|
||||
out_offset += h_padding * kh_stride_size;
|
||||
}
|
||||
}
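
A worked example (assumed values) of the gating arithmetic above for one output position in the depth dimension:

#include <algorithm>

void PatchBoundsExample() {
  // Assume kdepth = 3, stride_depth = 2, pad_depth = 1, in_depth = 5, d = 0.
  const int d = 0, stride_depth = 2, pad_depth = 1, kdepth = 3, in_depth = 5;
  const int id_ungated_start = d * stride_depth - pad_depth;         // -1
  const int id_start = std::max(0, id_ungated_start);                // 0
  const int id_end = std::min(id_ungated_start + kdepth, in_depth);  // 2
  const int d_padding_before = std::max(0, -id_ungated_start);       // 1
  const int d_padding_after = (id_ungated_start + kdepth) - id_end;  // 0
  // One leading kheight*kwidth*in_channel block is zero-filled, then input
  // planes 0 and 1 are copied; nothing is padded at the end of the window.
  (void)id_start;
  (void)d_padding_before;
  (void)d_padding_after;
}
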
|
||||
|
||||
template <typename T>
|
||||
void Im2col3D(const Conv3DParams& params, int kdepth, int kheight, int kwidth,
|
||||
uint8 zero_byte, const RuntimeShape& input_shape,
|
||||
const T* input_data, const RuntimeShape& im2col_shape,
|
||||
T* im2col_data) {
|
||||
ruy::profiler::ScopeLabel label("Im2col3D");
|
||||
const int stride_depth = params.stride_depth;
|
||||
const int stride_width = params.stride_width;
|
||||
const int stride_height = params.stride_height;
|
||||
const int pad_depth = params.padding_values.depth;
|
||||
const int pad_width = params.padding_values.width;
|
||||
const int pad_height = params.padding_values.height;
|
||||
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 5);
|
||||
TFLITE_DCHECK_EQ(im2col_shape.DimensionsCount(), 5);
|
||||
|
||||
const int batches = MatchingDim(input_shape, 0, im2col_shape, 0);
|
||||
const int input_depth = input_shape.Dims(1);
|
||||
const int input_height = input_shape.Dims(2);
|
||||
const int input_width = input_shape.Dims(3);
|
||||
const int input_channel = input_shape.Dims(4);
|
||||
|
||||
const int output_depth = im2col_shape.Dims(1);
|
||||
const int output_height = im2col_shape.Dims(2);
|
||||
const int output_width = im2col_shape.Dims(3);
|
||||
const int output_channel = im2col_shape.Dims(4);
|
||||
|
||||
int buffer_id = 0;
|
||||
// Loop over the output nodes.
|
||||
for (int b = 0; b < batches; ++b) {
|
||||
for (int d = 0; d < output_depth; ++d) {
|
||||
for (int h = 0; h < output_height; ++h) {
|
||||
for (int w = 0; w < output_width; ++w) {
|
||||
ExtractPatchIntoBufferColumn3D(
|
||||
b, d, h, w, kdepth, kheight, kwidth, stride_depth, stride_height,
|
||||
stride_width, pad_depth, pad_height, pad_width, input_depth,
|
||||
input_height, input_width, input_channel, buffer_id, input_data,
|
||||
im2col_data, zero_byte);
|
||||
buffer_id += output_channel;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
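
Viewed as a matrix, the buffer filled above has one row per output position and one column per filter tap; a small sketch of the bookkeeping (names here are illustrative only):

struct Im2colMatrixDims {
  int rows;  // batches * output_depth * output_height * output_width
  int cols;  // kdepth * kheight * kwidth * in_channel
};

Im2colMatrixDims Im2colDimsFor(int batches, int out_d, int out_h, int out_w,
                               int kdepth, int kheight, int kwidth,
                               int in_channel) {
  // buffer_id in the loops above advances by `cols` per output position, so
  // after the loops it equals rows * cols, the flat size of the buffer.
  return {batches * out_d * out_h * out_w,
          kdepth * kheight * kwidth * in_channel};
}
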
|
||||
|
||||
template <typename T>
|
||||
inline void DilatedIm2col3D(const Conv3DParams& params, int filter_depth,
|
||||
int filter_height, int filter_width,
|
||||
uint8 zero_byte, const RuntimeShape& input_shape,
|
||||
const T* input_data,
|
||||
const RuntimeShape& im2col_shape, T* im2col_data) {
|
||||
ruy::profiler::ScopeLabel label("DilatedIm2col3D");
|
||||
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 5);
|
||||
TFLITE_DCHECK_EQ(im2col_shape.DimensionsCount(), 5);
|
||||
|
||||
// Only NDHWC format is currently supported.
|
||||
const int batches = MatchingDim(input_shape, 0, im2col_shape, 0);
|
||||
const int input_channels = input_shape.Dims(4);
|
||||
const int input_width = input_shape.Dims(3);
|
||||
const int input_height = input_shape.Dims(2);
|
||||
const int input_depth = input_shape.Dims(1);
|
||||
|
||||
const int output_width = im2col_shape.Dims(3);
|
||||
const int output_height = im2col_shape.Dims(2);
|
||||
const int output_depth = im2col_shape.Dims(1);
|
||||
|
||||
const int pad_width = params.padding_values.width;
|
||||
const int pad_height = params.padding_values.height;
|
||||
const int pad_depth = params.padding_values.depth;
|
||||
|
||||
// Construct the MxN sized im2col matrix.
|
||||
// The rows M, are sub-ordered B x D x H x W.
|
||||
const RuntimeShape row_shape(
|
||||
{1, batches, output_depth, output_height, output_width});
|
||||
// The columns, N, are sub-ordered Kd x Kh x Kw x Din.
|
||||
const RuntimeShape col_shape(
|
||||
{1, filter_depth, filter_height, filter_width, input_channels});
|
||||
// Use dimensions M and N to construct dims for indexing directly into im2col.
|
||||
const RuntimeShape im2col_reshaped(
|
||||
{1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
|
||||
|
||||
for (int batch = 0; batch < batches; ++batch) {
|
||||
for (int out_d = 0; out_d < output_depth; ++out_d) {
|
||||
const int in_d_origin = (out_d * params.stride_depth) - pad_depth;
|
||||
for (int out_y = 0; out_y < output_height; ++out_y) {
|
||||
const int in_y_origin = (out_y * params.stride_height) - pad_height;
|
||||
for (int out_x = 0; out_x < output_width; ++out_x) {
|
||||
const int in_x_origin = (out_x * params.stride_width) - pad_width;
|
||||
const int row_offset =
|
||||
Offset(row_shape, 0, batch, out_d, out_y, out_x);
|
||||
for (int filter_d = 0; filter_d < filter_depth; ++filter_d) {
|
||||
const int in_d = in_d_origin + params.dilation_depth * filter_d;
|
||||
if ((in_d >= 0) && (in_d < input_depth)) {
|
||||
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
|
||||
const int in_y =
|
||||
in_y_origin + params.dilation_height * filter_y;
|
||||
if ((in_y >= 0) && (in_y < input_height)) {
|
||||
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
|
||||
const int in_x =
|
||||
in_x_origin + params.dilation_width * filter_x;
|
||||
int col_offset =
|
||||
Offset(col_shape, 0, filter_d, filter_y, filter_x, 0);
|
||||
T* dst = im2col_data + Offset(im2col_reshaped, 0, 0,
|
||||
row_offset, col_offset);
|
||||
if ((in_x >= 0) && (in_x < input_width)) {
|
||||
// Filter pixel is within the input, copy the input data.
|
||||
T const* src = input_data + Offset(input_shape, batch,
|
||||
in_d, in_y, in_x, 0);
|
||||
memcpy(dst, src, input_channels * sizeof(T));
|
||||
} else {
|
||||
// Filter pixel is outside the input, zero it out.
|
||||
memset(dst, zero_byte, input_channels * sizeof(T));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int col_offset =
|
||||
Offset(col_shape, 0, filter_d, filter_y, 0, 0);
|
||||
T* dst = im2col_data + Offset(im2col_reshaped, 0, 0,
|
||||
row_offset, col_offset);
|
||||
memset(dst, zero_byte,
|
||||
filter_width * input_channels * sizeof(T));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const int col_offset = Offset(col_shape, 0, filter_d, 0, 0, 0);
|
||||
T* dst = im2col_data +
|
||||
Offset(im2col_reshaped, 0, 0, row_offset, col_offset);
|
||||
memset(dst, zero_byte,
|
||||
filter_height * filter_width * input_channels * sizeof(T));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace optimized_ops
|
||||
} // namespace tflite
|
||||
|
||||
|
|
|
|||
|
|
@ -7960,6 +7960,98 @@ inline void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
|
|||
ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
|
||||
}
|
||||
|
||||
inline void Conv3D(const Conv3DParams& params, const RuntimeShape& input_shape,
|
||||
const float* input_data, const RuntimeShape& filter_shape,
|
||||
const float* filter_data, const RuntimeShape& bias_shape,
|
||||
const float* bias_data, const RuntimeShape& output_shape,
|
||||
float* output_data, const RuntimeShape& im2col_shape,
|
||||
float* im2col_data,
|
||||
const RuntimeShape& transposed_filter_shape,
|
||||
float* transposed_filter_data,
|
||||
CpuBackendContext* cpu_backend_context) {
|
||||
const int stride_depth = params.stride_depth;
|
||||
const int stride_height = params.stride_height;
|
||||
const int stride_width = params.stride_width;
|
||||
const int dilation_depth_factor = params.dilation_depth;
|
||||
const int dilation_height_factor = params.dilation_height;
|
||||
const int dilation_width_factor = params.dilation_width;
|
||||
const float output_activation_min = params.float_activation_min;
|
||||
const float output_activation_max = params.float_activation_max;
|
||||
TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 5);
|
||||
TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 5);
|
||||
TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 5);
|
||||
|
||||
ruy::profiler::ScopeLabel label("Conv3D");
|
||||
|
||||
// NB: the float 0.0f value is represented by all zero bytes.
|
||||
const uint8 float_zero_byte = 0x00;
|
||||
const float* gemm_input_data = nullptr;
|
||||
const RuntimeShape* gemm_input_shape = nullptr;
|
||||
const int filter_width = filter_shape.Dims(2);
|
||||
const int filter_height = filter_shape.Dims(1);
|
||||
const int filter_depth = filter_shape.Dims(0);
|
||||
const bool need_dilated_im2col = dilation_width_factor != 1 ||
|
||||
dilation_height_factor != 1 ||
|
||||
dilation_depth_factor != 1;
|
||||
const bool need_im2col = stride_depth != 1 || stride_height != 1 ||
|
||||
stride_width != 1 || filter_depth != 1 ||
|
||||
filter_height != 1 || filter_width != 1;
|
||||
|
||||
if (need_dilated_im2col) {
|
||||
DilatedIm2col3D(params, filter_depth, filter_height, filter_width,
|
||||
float_zero_byte, input_shape, input_data, im2col_shape,
|
||||
im2col_data);
|
||||
gemm_input_data = im2col_data;
|
||||
gemm_input_shape = &im2col_shape;
|
||||
} else if (need_im2col) {
|
||||
TFLITE_DCHECK(im2col_data);
|
||||
Im2col3D(params, filter_depth, filter_height, filter_width, float_zero_byte,
|
||||
input_shape, input_data, im2col_shape, im2col_data);
|
||||
gemm_input_data = im2col_data;
|
||||
gemm_input_shape = &im2col_shape;
|
||||
} else {
|
||||
TFLITE_DCHECK(!im2col_data);
|
||||
gemm_input_data = input_data;
|
||||
gemm_input_shape = &input_shape;
|
||||
}
|
||||
|
||||
// Transpose the filter tensor.
|
||||
TransposeParams transpose_params;
|
||||
transpose_params.perm_count = 5;
|
||||
transpose_params.perm[0] = 4;
|
||||
transpose_params.perm[1] = 0;
|
||||
transpose_params.perm[2] = 1;
|
||||
transpose_params.perm[3] = 2;
|
||||
transpose_params.perm[4] = 3;
|
||||
Transpose<float, 5>(transpose_params, filter_shape, filter_data,
|
||||
transposed_filter_shape, transposed_filter_data);
|
||||
|
||||
const int gemm_input_dims = gemm_input_shape->DimensionsCount();
|
||||
int m = FlatSizeSkipDim(*gemm_input_shape, gemm_input_dims - 1);
|
||||
int n = output_shape.Dims(4);
|
||||
int k = gemm_input_shape->Dims(gemm_input_dims - 1);
|
||||
|
||||
cpu_backend_gemm::MatrixParams<float> lhs_params;
|
||||
lhs_params.order = cpu_backend_gemm::Order::kRowMajor;
|
||||
lhs_params.rows = n;
|
||||
lhs_params.cols = k;
|
||||
cpu_backend_gemm::MatrixParams<float> rhs_params;
|
||||
rhs_params.order = cpu_backend_gemm::Order::kColMajor;
|
||||
rhs_params.rows = k;
|
||||
rhs_params.cols = m;
|
||||
cpu_backend_gemm::MatrixParams<float> dst_params;
|
||||
dst_params.order = cpu_backend_gemm::Order::kColMajor;
|
||||
dst_params.rows = n;
|
||||
dst_params.cols = m;
|
||||
cpu_backend_gemm::GemmParams<float, float> gemm_params;
|
||||
gemm_params.bias = bias_data;
|
||||
gemm_params.clamp_min = output_activation_min;
|
||||
gemm_params.clamp_max = output_activation_max;
|
||||
cpu_backend_gemm::Gemm(lhs_params, transposed_filter_data, rhs_params,
|
||||
gemm_input_data, dst_params, output_data, gemm_params,
|
||||
cpu_backend_context);
|
||||
}
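
For orientation, the GEMM above maps the 5-D shapes onto a plain matrix multiply; the concrete numbers below are hypothetical, chosen only to show the m/n/k bookkeeping.

// Hypothetical shapes: filter [Kd, Kh, Kw, Cin, Cout] = [3, 3, 3, 8, 16],
// output [N, Do, Ho, Wo, Cout] = [1, 16, 16, 16, 16].
void Conv3DGemmDimsExample(int* m, int* n, int* k) {
  *n = 16;                // Cout, i.e. output_shape.Dims(4): rows of lhs
  *k = 3 * 3 * 3 * 8;     // Kd * Kh * Kw * Cin = 216, taps per output element
  *m = 1 * 16 * 16 * 16;  // all output positions (batch * Do * Ho * Wo)
  // lhs is the transposed filter (n x k), rhs is the im2col matrix or, when
  // no im2col is needed, the raw input (k x m); dst is the output (n x m).
}
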
|
||||
|
||||
} // namespace optimized_ops
|
||||
} // namespace tflite
|
||||
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ TfLiteRegistration* Register_DEPTH_TO_SPACE_REF();
|
|||
TfLiteRegistration* Register_SELECT_V2();
|
||||
TfLiteRegistration* Register_SEGMENT_SUM();
|
||||
TfLiteRegistration* Register_BROADCAST_TO();
|
||||
TfLiteRegistration* Register_CONV_3D();
|
||||
TfLiteRegistration* Register_CONV_3D_REF();
|
||||
TfLiteRegistration* Register_IMAG();
|
||||
TfLiteRegistration* Register_REAL();
|
||||
TfLiteRegistration* Register_COMPLEX_ABS();
|
||||
|
|
@ -465,7 +465,7 @@ BuiltinRefOpResolver::BuiltinRefOpResolver() {
|
|||
AddBuiltin(BuiltinOperator_BATCH_MATMUL, Register_BATCH_MATMUL_REF(),
|
||||
/* min_version = */ 1,
|
||||
/* max_version = */ 3);
|
||||
AddBuiltin(BuiltinOperator_CONV_3D, Register_CONV_3D());
|
||||
AddBuiltin(BuiltinOperator_CONV_3D, Register_CONV_3D_REF());
|
||||
AddBuiltin(BuiltinOperator_IMAG, Register_IMAG());
|
||||
AddBuiltin(BuiltinOperator_REAL, Register_REAL());
|
||||
AddBuiltin(BuiltinOperator_COMPLEX_ABS, Register_COMPLEX_ABS());
|
||||
|
|
|
|||
|
|
@ -37,9 +37,16 @@ namespace tflite {
|
|||
using KeywordBenchmarkRunner = MicroBenchmarkRunner<int16_t>;
|
||||
using KeywordOpResolver = MicroMutableOpResolver<6>;
|
||||
|
||||
#if defined(HEXAGON)
|
||||
// TODO(b/174781826): reduce arena usage for optimized Hexagon kernels.
|
||||
constexpr int kOptimizedKernelArenaIncrement = 21000;
|
||||
#else
|
||||
constexpr int kOptimizedKernelArenaIncrement = 0;
|
||||
#endif
|
||||
|
||||
// Create an area of memory to use for input, output, and intermediate arrays.
|
||||
// Align arena to 16 bytes to avoid alignment warnings on certain platforms.
|
||||
constexpr int kTensorArenaSize = 21 * 1024;
|
||||
constexpr int kTensorArenaSize = 21 * 1024 + kOptimizedKernelArenaIncrement;
|
||||
alignas(16) uint8_t tensor_arena[kTensorArenaSize];
|
||||
|
||||
uint8_t benchmark_runner_buffer[sizeof(KeywordBenchmarkRunner)];
|
||||
|
|
|
|||
26
tensorflow/lite/micro/ceva/micro_time.cc
Normal file
|
|
@ -0,0 +1,26 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/micro/micro_time.h"

#include <ceva-time.h>

namespace tflite {

int32_t ticks_per_second() { return 100e6; }

int32_t GetCurrentTimeTicks() { return clock(); }

} // namespace tflite
|
||||
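
A hypothetical usage sketch of the CEVA clock hooks added above (the helper name is made up; ticks_per_second() and GetCurrentTimeTicks() are the functions declared in tensorflow/lite/micro/micro_time.h and defined here):

#include <cstdint>

#include "tensorflow/lite/micro/micro_time.h"

int64_t ElapsedMicroseconds(int32_t start_ticks, int32_t end_ticks) {
  // On this target a tick is 10 ns (100 MHz), so scale to microseconds.
  return (static_cast<int64_t>(end_ticks - start_ticks) * 1000000) /
         tflite::ticks_per_second();
}
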
28
tensorflow/lite/micro/ceva/system_setup.cc
Normal file
|
|
@ -0,0 +1,28 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/micro/system_setup.h"

#include <ceva-time.h>

namespace tflite {

void InitializeTarget() {
  // start clock for profiler
  reset_clock();
  start_clock();
}

} // namespace tflite
|
||||
7
tensorflow/lite/micro/tools/make/ext_libs/hexagon.inc
Normal file
|
|
@ -0,0 +1,7 @@
HEXAGON_TFLM_LIB_PATH = tensorflow/lite/micro/kernels/hexagon/lib/
HEXAGON_TFLM_INC_PATH = tensorflow/lite/micro/kernels/hexagon/inc/

HEXAGON_TFLM_CORE_LIB_NAME = hexagon_tflm_core.a
HEXAGON_TFLM_CORE_LIB_FULLNAME = $(HEXAGON_TFLM_LIB_PATH)$(HEXAGON_TFLM_CORE_LIB_NAME)
MICROLITE_LIBS += $(HEXAGON_TFLM_CORE_LIB_FULLNAME)
INCLUDES += -I$(HEXAGON_TFLM_INC_PATH)
|
||||
13
tensorflow/lite/micro/tools/make/targets/hexagon/download_hexagon.sh
Normal file → Executable file
|
|
@ -24,18 +24,17 @@
|
|||
# Clone hexagon kernels to temp directory and check out known-good commit.
|
||||
HEXAGON_DIR=/tmp/hexagon_optimized
|
||||
|
||||
mkdir -p ${HEXAGON_DIR}
|
||||
if [ ! -d ${HEXAGON_DIR} ]; then
|
||||
mkdir -p ${HEXAGON_DIR}
|
||||
git clone -b release_v2 https://source.codeaurora.org/quic/embedded_ai/tensorflow ${HEXAGON_DIR}
|
||||
fi
|
||||
|
||||
git clone -b release_v2 https://source.codeaurora.org/quic/embedded_ai/tensorflow ${HEXAGON_DIR}
|
||||
pushd ${HEXAGON_DIR}
|
||||
pushd ${HEXAGON_DIR} > /dev/null
|
||||
git checkout 2d052806c211144875c89315a4fc6f1393064cf6
|
||||
popd
|
||||
popd > /dev/null
|
||||
|
||||
# Copy optimized kernels from checkout, copy prebuilt lib.
|
||||
rm -rf tensorflow/lite/micro/kernels/hexagon
|
||||
cp -R ${HEXAGON_DIR}/tensorflow/lite/micro/kernels/hexagon tensorflow/lite/micro/kernels/hexagon
|
||||
cp -R ${HEXAGON_DIR}/tensorflow/lite/micro/hexagon tensorflow/lite/micro
|
||||
cp ${HEXAGON_DIR}/tensorflow/lite/micro/tools/make/ext_libs/hexagon_library.inc tensorflow/lite/micro/tools/make/ext_libs/hexagon_library.inc
|
||||
cp ${HEXAGON_DIR}/tensorflow/lite/micro/tools/make/targets/hexagon_makefile.inc tensorflow/lite/micro/tools/make/targets/hexagon_makefile.inc
|
||||
mkdir tensorflow/lite/micro/kernels/hexagon/lib
|
||||
cp ${1} tensorflow/lite/micro/kernels/hexagon/lib/
|
||||
|
|
|
|||
|
|
@ -18,75 +18,82 @@
|
|||
# Unlike other targets, there is not currently a way to automatically download
|
||||
# the Hexagon SDK. For this reason, users are required to manually download
|
||||
# and configure the SDK.
|
||||
ifeq ($(TARGET), hexagon)
|
||||
TARGET_ARCH := hexagon
|
||||
|
||||
ifndef HEXAGON_SDK_ROOT
|
||||
$(error HEXAGON_SDK_ROOT is undefined)
|
||||
endif
|
||||
TARGET_ARCH := hexagon
|
||||
|
||||
ifndef HEXAGON_TOOL_VER
|
||||
$(error HEXAGON_TOOL_VER is undefined)
|
||||
endif
|
||||
|
||||
ifndef HEXAGON_ROOT
|
||||
$(error HEXAGON_ROOT is undefined)
|
||||
endif
|
||||
|
||||
ifndef HEXAGON_CPU_VER
|
||||
$(error HEXAGON_CPU_VER is undefined)
|
||||
endif
|
||||
|
||||
PLATFORM_ARGS = \
|
||||
-DTF_LITE_MCU_DEBUG_LOG \
|
||||
-DTF_LITE_USE_CTIME \
|
||||
-DHEXAGON_ASM \
|
||||
-DMALLOC_IN_STDLIB \
|
||||
-DPTHREAD_STUBS \
|
||||
-DUSE_PREALLOCATED_BUFFER \
|
||||
-D_HAS_C9X \
|
||||
-MMD \
|
||||
-DHEXAGON \
|
||||
-Wall \
|
||||
-Wextra \
|
||||
-Wno-missing-field-initializers \
|
||||
-Wno-sign-compare \
|
||||
-Wno-unused-parameter \
|
||||
-Wno-write-strings \
|
||||
-Wunused-function \
|
||||
-Wno-unused-private-field \
|
||||
-Wvla \
|
||||
-fdata-sections \
|
||||
-ffunction-sections \
|
||||
-fmessage-length=0 \
|
||||
-fno-delete-null-pointer-checks \
|
||||
-fno-exceptions \
|
||||
-fno-register-global-dtors-with-atexit \
|
||||
-fno-rtti \
|
||||
-fno-short-enums \
|
||||
-fno-threadsafe-statics \
|
||||
-fno-unwind-tables \
|
||||
-fno-use-cxa-atexit \
|
||||
-fomit-frame-pointer \
|
||||
-fpermissive \
|
||||
-funsigned-char \
|
||||
-mcpu=$(HEXAGON_CPU_VER) \
|
||||
-m$(HEXAGON_CPU_VER)
|
||||
|
||||
export PATH := $(HEXAGON_ROOT)/$(HEXAGON_TOOL_VER)/Tools/bin:$(PATH)
|
||||
TARGET_TOOLCHAIN_PREFIX := hexagon-
|
||||
CXX_TOOL := clang++
|
||||
CC_TOOL := clang
|
||||
|
||||
CXXFLAGS += $(PLATFORM_ARGS)
|
||||
CCFLAGS += $(PLATFORM_ARGS)
|
||||
LDFLAGS += \
|
||||
-Wl,--gc-sections -lhexagon \
|
||||
$(HEXAGON_ROOT)/$(HEXAGON_TOOL_VER)/Tools/target/hexagon/lib/v66/libstdc++.a
|
||||
|
||||
INCLUDES += \
|
||||
-I$(HEXAGON_SDK_ROOT)/libs/common/qurt/computev66/include/posix \
|
||||
-I$(HEXAGON_SDK_ROOT)/libs/common/qurt/computev66/include/qurt
|
||||
|
||||
TEST_SCRIPT := tensorflow/lite/micro/testing/test_hexagon_binary.sh
|
||||
ifndef HEXAGON_SDK_ROOT
|
||||
$(error HEXAGON_SDK_ROOT is undefined)
|
||||
endif
|
||||
|
||||
ifndef HEXAGON_TOOL_VER
|
||||
$(error HEXAGON_TOOL_VER is undefined)
|
||||
endif
|
||||
|
||||
ifndef HEXAGON_ROOT
|
||||
$(error HEXAGON_ROOT is undefined)
|
||||
endif
|
||||
|
||||
ifndef HEXAGON_CPU_VER
|
||||
$(error HEXAGON_CPU_VER is undefined)
|
||||
endif
|
||||
|
||||
HEXAGON_LPI_BUILD :=
|
||||
|
||||
PLATFORM_ARGS = \
|
||||
-DTF_LITE_MCU_DEBUG_LOG \
|
||||
-DTF_LITE_USE_CTIME \
|
||||
-DHEXAGON_ASM \
|
||||
-DMALLOC_IN_STDLIB \
|
||||
-DPTHREAD_STUBS \
|
||||
-DUSE_PREALLOCATED_BUFFER \
|
||||
-D_HAS_C9X \
|
||||
-DTF_LITE_USE_CTIME \
|
||||
-MMD \
|
||||
-DHEXAGON \
|
||||
-Wall \
|
||||
-Wextra \
|
||||
-Wno-missing-field-initializers \
|
||||
-Wno-sign-compare \
|
||||
-Wno-unused-parameter \
|
||||
-Wno-write-strings \
|
||||
-Wunused-function \
|
||||
-Wno-unused-private-field \
|
||||
-Wvla \
|
||||
-fdata-sections \
|
||||
-ffunction-sections \
|
||||
-fmessage-length=0 \
|
||||
-fno-delete-null-pointer-checks \
|
||||
-fno-exceptions \
|
||||
-fno-register-global-dtors-with-atexit \
|
||||
-fno-rtti \
|
||||
-fno-short-enums \
|
||||
-fno-threadsafe-statics \
|
||||
-fno-unwind-tables \
|
||||
-fno-use-cxa-atexit \
|
||||
-fomit-frame-pointer \
|
||||
-fpermissive \
|
||||
-funsigned-char \
|
||||
-mcpu=$(HEXAGON_CPU_VER) \
|
||||
-m$(HEXAGON_CPU_VER)
|
||||
|
||||
# See http://b/183462077 for more details on why we need -G0 for an LPI build.
|
||||
ifeq ($(HEXAGON_LPI_BUILD), true)
|
||||
PLATFORM_ARGS += -G0
|
||||
endif
|
||||
|
||||
export PATH := $(HEXAGON_ROOT)/$(HEXAGON_TOOL_VER)/Tools/bin:$(PATH)
|
||||
TARGET_TOOLCHAIN_PREFIX := hexagon-
|
||||
CXX_TOOL := clang++
|
||||
CC_TOOL := clang
|
||||
|
||||
CXXFLAGS += $(PLATFORM_ARGS)
|
||||
CCFLAGS += $(PLATFORM_ARGS)
|
||||
LDFLAGS += \
|
||||
-Wl,--gc-sections -lhexagon \
|
||||
$(HEXAGON_ROOT)/$(HEXAGON_TOOL_VER)/Tools/target/hexagon/lib/v66/libstdc++.a
|
||||
|
||||
INCLUDES += \
|
||||
-I$(HEXAGON_SDK_ROOT)/libs/common/qurt/computev66/include/posix \
|
||||
-I$(HEXAGON_SDK_ROOT)/libs/common/qurt/computev66/include/qurt
|
||||
|
||||
TEST_SCRIPT := tensorflow/lite/micro/testing/test_hexagon_binary.sh
|
||||
|
|
|
|||
BIN
tensorflow/lite/testdata/conv3d_huge_im2col.bin
vendored
Normal file
Binary file not shown.
|
|
@ -294,7 +294,7 @@ void BenchmarkPerformanceOptions::CreatePerformanceOptions() {
|
|||
if (!nnapi_accelerators.empty()) {
|
||||
std::vector<std::string> device_names;
|
||||
util::SplitAndParse(nnapi_accelerators, ',', &device_names);
|
||||
for (const auto name : device_names) {
|
||||
for (const auto& name : device_names) {
|
||||
BenchmarkParams params;
|
||||
params.AddParam("use_nnapi", BenchmarkParam::Create<bool>(true));
|
||||
params.AddParam("nnapi_accelerator_name",
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ tensorflow/lite/core/shims/cc/kernels/register.h
|
|||
tensorflow/lite/core/shims/cc/model.h
|
||||
tensorflow/lite/core/shims/cc/model_builder.h
|
||||
tensorflow/lite/core/shims/cc/shims_test_util.h
|
||||
tensorflow/lite/core/shims/cc_library_with_tflite.bzl
|
||||
tensorflow/lite/delegates/gpu/cl/compiled_program_cache_generated.h
|
||||
tensorflow/lite/delegates/gpu/cl/serialization_generated.h
|
||||
tensorflow/lite/delegates/gpu/common/task/serialization_base_generated.h
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export
|
|||
# This value changes every day with an automatic CL. It can be modified in code
|
||||
# via `forward_compatibility_horizon()` or with the environment variable
|
||||
# TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
|
||||
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2021, 3, 23)
|
||||
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2021, 3, 24)
|
||||
_FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
|
||||
_FORWARD_COMPATIBILITY_DATE_NUMBER = None
|
||||
|
||||
|
|
|
|||
|
|
@ -1923,13 +1923,16 @@ distribute_py_test(
|
|||
":multi_worker_test_base",
|
||||
":parameter_server_strategy_v2",
|
||||
":sharded_variable",
|
||||
"//tensorflow:tensorflow_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:extra_py_tests_deps",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:init_ops_v2",
|
||||
"//tensorflow/python:linalg_ops_impl",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/compat:v2_compat",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
|
|
|
|||
|
|
@ -175,6 +175,9 @@ class GPUCombination(combinations_lib.TestCombination):
|
|||
[required_gpus] + [required_physical_gpus] +
|
||||
[d.required_physical_gpus or 0 for d in distributions] +
|
||||
[d.required_gpus or 0 for d in distributions])
|
||||
number_of_required_physical_gpus = max(
|
||||
[required_physical_gpus] +
|
||||
[d.required_physical_gpus or 0 for d in distributions])
|
||||
|
||||
if (required_physical_gpus and required_gpus):
|
||||
raise ValueError("Only one of `required_physical_gpus`(number of physical"
|
||||
|
|
@ -186,7 +189,8 @@ class GPUCombination(combinations_lib.TestCombination):
|
|||
and context.num_gpus() < number_of_required_gpus):
|
||||
return (False, ("Only {} of {} required GPUs are available.".format(
|
||||
context.num_gpus(), number_of_required_gpus)))
|
||||
elif required_physical_gpus > len(config.list_physical_devices("GPU")):
|
||||
elif number_of_required_physical_gpus > len(
|
||||
config.list_physical_devices("GPU")):
|
||||
return (False,
|
||||
("Only {} of {} required physical GPUs are available.".format(
|
||||
config.list_physical_devices("GPU"), required_physical_gpus)))
|
||||
|
|
|
|||
|
|
@ -640,7 +640,11 @@ class WorkerPreemptionHandler(object):
|
|||
|
||||
def _validate_preemption_failure(self, e):
|
||||
"""Validates that the given exception represents worker preemption."""
|
||||
if _is_worker_failure(e):
|
||||
|
||||
# Only categorize the failure as a worker preemption if the cancellation
|
||||
# manager did not attempt to cancel the blocking operations.
|
||||
if _is_worker_failure(e) and (
|
||||
not self._cluster._closure_queue._cancellation_mgr.is_cancelled): # pylint: disable=protected-access
|
||||
return
|
||||
raise e
|
||||
|
||||
|
|
|
|||
|
|
@ -196,7 +196,7 @@ class ParameterServerStrategyV2Test(test.TestCase):
|
|||
strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
|
||||
self.cluster_resolver)
|
||||
strategy.extended._allow_run_without_coordinator = True
|
||||
dataset = dataset_ops.DatasetV2.range(3)
|
||||
dataset = dataset_ops.DatasetV2.range(15)
|
||||
with strategy.scope():
|
||||
v = variables.Variable(1, dtype=dtypes.int64)
|
||||
|
||||
|
|
|
|||
|
|
@ -299,15 +299,16 @@ class ShardedVariableMixin(trackable.Trackable):
|
|||
raise ValueError(
|
||||
'All `Variables`s must have the same shapes except for the first '
|
||||
'axis, found {}'.format([v.shape for v in variables]))
|
||||
first_dim = sum(int(v.shape[0]) for v in variables)
|
||||
self._shape = tensor_shape.TensorShape([first_dim] + first_var.shape[1:])
|
||||
first_dim = sum(int(v.shape.as_list()[0]) for v in variables)
|
||||
self._shape = tensor_shape.TensorShape([first_dim] +
|
||||
first_var.shape.as_list()[1:])
|
||||
self._var_offsets = [
|
||||
[0 for _ in range(len(first_var.shape))] for _ in range(len(variables))
|
||||
]
|
||||
for i in range(1, len(variables)):
|
||||
# Always partition on the first axis. Offsets on other axes are 0.
|
||||
self._var_offsets[i][0] += (
|
||||
self._var_offsets[i - 1][0] + variables[i - 1].shape[0])
|
||||
self._var_offsets[i - 1][0] + variables[i - 1].shape.as_list()[0])
|
||||
|
||||
save_slice_info = [v._get_save_slice_info() for v in variables] # pylint: disable=protected-access
|
||||
if any(slice_info is not None for slice_info in save_slice_info):
|
||||
|
|
|
|||
|
|
@ -12,15 +12,11 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Implementation of the Keras API meant to be a high-level API for TensorFlow.
|
||||
"""Implementation of the Keras API, the high-level API of TensorFlow.
|
||||
|
||||
Detailed documentation and user guides are available at
|
||||
[tensorflow.org](https://www.tensorflow.org/guide/keras).
|
||||
[keras.io](https://keras.io).
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.keras import distribute
|
||||
|
|
@ -28,7 +24,6 @@ from tensorflow.python.keras import distribute
|
|||
# See b/110718070#comment18 for more details about this import.
|
||||
from tensorflow.python.keras import models
|
||||
|
||||
|
||||
from tensorflow.python.keras.engine.input_layer import Input
|
||||
from tensorflow.python.keras.engine.sequential import Sequential
|
||||
from tensorflow.python.keras.engine.training import Model
|
||||
|
|
|
|||
|
|
@ -8,25 +8,20 @@
|
|||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Built-in activation functions."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.keras.layers import advanced_activations
|
||||
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
|
||||
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.util import dispatch
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
from tensorflow.python.keras.layers import advanced_activations
|
||||
|
||||
# b/123041942
|
||||
# In TF 2.x, if the `tf.nn.softmax` is used as an activation function in Keras
|
||||
|
|
@ -34,7 +29,6 @@ from tensorflow.python.keras.layers import advanced_activations
|
|||
# internal method name is returned in serialization. This results in errors in
|
||||
# model exporting and loading as Keras can't find any activation function with
|
||||
# the name of `softmax_v2`.
|
||||
|
||||
# This dict maps the activation function name from its v2 version to its
|
||||
# canonical name.
|
||||
_TF_ACTIVATIONS_V2 = {
|
||||
|
|
@ -146,7 +140,7 @@ def elu(x, alpha=1.0):
|
|||
[Fast and Accurate Deep Network Learning by Exponential Linear Units
|
||||
(ELUs) (Clevert et al, 2016)](https://arxiv.org/abs/1511.07289)
|
||||
"""
|
||||
return K.elu(x, alpha)
|
||||
return backend.elu(x, alpha)
|
||||
|
||||
|
||||
@keras_export('keras.activations.selu')
|
||||
|
|
@ -198,7 +192,7 @@ def selu(x):
|
|||
`tf.keras.layers.AlphaDropout` (not regular dropout).
|
||||
|
||||
References:
|
||||
- [Klambauer et al., 2017](https://arxiv.org/abs/1706.02515)
|
||||
- [Klambauer et al., 2017](https://arxiv.org/abs/1706.02515)
|
||||
"""
|
||||
return nn.selu(x)
|
||||
|
||||
|
|
@ -315,7 +309,7 @@ def relu(x, alpha=0., max_value=None, threshold=0):
|
|||
transformed by the relu activation function.
|
||||
Tensor will be of the same shape and dtype of input `x`.
|
||||
"""
|
||||
return K.relu(x, alpha=alpha, max_value=max_value, threshold=threshold)
|
||||
return backend.relu(x, alpha=alpha, max_value=max_value, threshold=threshold)
|
||||
|
||||
|
||||
@keras_export('keras.activations.gelu', v1=[])
|
||||
|
|
@ -458,7 +452,7 @@ def hard_sigmoid(x):
|
|||
- `if x > 2.5: return 1`
|
||||
- `if -2.5 <= x <= 2.5: return 0.2 * x + 0.5`
|
||||
"""
|
||||
return K.hard_sigmoid(x)
|
||||
return backend.hard_sigmoid(x)
|
||||
|
||||
|
||||
@keras_export('keras.activations.linear')
|
||||
|
|
@ -588,7 +582,7 @@ def get(identifier):
|
|||
"""
|
||||
if identifier is None:
|
||||
return linear
|
||||
if isinstance(identifier, six.string_types):
|
||||
if isinstance(identifier, str):
|
||||
identifier = str(identifier)
|
||||
return deserialize(identifier)
|
||||
elif isinstance(identifier, dict):
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff.