mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Use applyPartialConversion instead of applyPatternsAndFoldGreedily
since the patterns for converting TF -> MHLO uniform quantized graphs are based on OpConversionPattern, which must be driven by the dialect conversion framework. PiperOrigin-RevId: 540439199
This commit is contained in:
parent
b6a69dda3a
commit
76f4e0a8f9
|
|
@ -472,6 +472,7 @@ cc_library(
|
|||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
"@llvm-project//mlir:Transforms",
|
||||
"@stablehlo//:chlo_ops",
|
||||
],
|
||||
# Alwayslink is required for registering the MLIR passes.
|
||||
# TODO(b/255530126): Split the pass registration from the definitions to avoid binary size bloat.
|
||||
|
|
|
|||
|
|
@ -19,7 +19,9 @@ limitations under the License.
|
|||
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
|
||||
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
|
||||
#include "stablehlo/dialect/ChloOps.h" // from @stablehlo
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
|
||||
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
|
||||
|
|
@ -55,6 +57,7 @@ class ConvertTFQuantOpsToMHLOPass
|
|||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<TF::TensorFlowDialect>();
|
||||
registry.insert<mhlo::MhloDialect>();
|
||||
registry.insert<chlo::ChloDialect>();
|
||||
registry.insert<tf_type::TFTypeDialect>();
|
||||
registry.insert<quant::QuantizationDialect>();
|
||||
registry.insert<quantfork::QuantizationForkDialect>();
|
||||
|
|
@ -68,10 +71,19 @@ static PassRegistration<ConvertTFQuantOpsToMHLOPass> pass;
|
|||
void ConvertTFQuantOpsToMHLOPass::runOnOperation() {
|
||||
MLIRContext *ctx = &getContext();
|
||||
func::FuncOp func = getOperation();
|
||||
ConversionTarget target(*ctx);
|
||||
target.addLegalDialect<TF::TensorFlowDialect, mhlo::MhloDialect,
|
||||
chlo::ChloDialect>();
|
||||
target.addIllegalOp<
|
||||
TF::UniformQuantizeOp, TF::UniformRequantizeOp, TF::UniformDequantizeOp,
|
||||
TF::UniformQuantizedDotOp, TF::UniformQuantizedDotHybridOp,
|
||||
TF::UniformQuantizedConvolutionOp,
|
||||
TF::UniformQuantizedConvolutionHybridOp, TF::UniformQuantizedAddOp,
|
||||
TF::UniformQuantizedClipByValueOp>();
|
||||
|
||||
RewritePatternSet patterns(ctx);
|
||||
mhlo::PopulateLegalizeTfQuantizationPatterns(ctx, &patterns);
|
||||
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
|
||||
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,53 @@ func.func @quantized_matmul_fn(%input: tensor<*xf32>) -> tensor<*xf32> {
|
|||
}
|
||||
|
||||
// CHECK: func @quantized_matmul_fn
|
||||
// CHECK: "tf.AddV2"
|
||||
// CHECK: mhlo.constant
|
||||
// CHECK-SAME{LITERAL}: dense<[[1, 2], [3, 4]]> : tensor<2x2xi8>
|
||||
// CHECK-NEXT: "tf.AddV2"
|
||||
// CHECK-NEXT: "mhlo.dot"(%1, %0) : (tensor<*xf32>, tensor<2x2x!quant.uniform<i8:f32, 1.000000e+00:3>>) -> tensor<*xf32>
|
||||
// CHECK: "mhlo.dot"
|
||||
// CHECK-SAME: (tensor<*xf32>, tensor<2x2x!quant.uniform<i8:f32, 1.000000e+00:3>>) -> tensor<*xf32>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @uniform_quantized_add
|
||||
func.func @uniform_quantized_add(%input: tensor<3x2xf32>) -> () {
|
||||
%input_scales = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
|
||||
%input_zps = "tf.Const"() { value = dense<4> : tensor<i32> } : () -> tensor<i32>
|
||||
// tensor_proto that points to dense<127> of type !tf_type.qint32.
|
||||
%bias = "tf.Const"() { value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F51494E5433322074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A20225C3137375C3030305C3030305C30303022"> : tensor<2x!tf_type.qint32> } : () -> tensor<2x!tf_type.qint32>
|
||||
%bias_scales = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
|
||||
%bias_zps = "tf.Const"() { value = dense<4> : tensor<i32> } : () -> tensor<i32>
|
||||
|
||||
%output_scales = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
|
||||
%output_zps = "tf.Const"() { value = dense<4> : tensor<i32> } : () -> tensor<i32>
|
||||
|
||||
// CHECK-DAG: %[[LHS:.*]] = mhlo.uniform_quantize %arg0 : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform<i32:f32, 2.000000e+00:4>>
|
||||
// CHECK-DAG: %[[RHS:.*]] = mhlo.constant()
|
||||
// CHECK-SAME{LITERAL}: {value = dense<127> : tensor<2xi32>} : () -> tensor<2x!quant.uniform<i32:f32, 2.000000e+00:4>>
|
||||
// CHECK: chlo.broadcast_add %[[LHS]], %[[RHS]] {broadcast_dimensions = dense<1> : tensor<1xi64>} :
|
||||
// CHECK-SAME: (tensor<3x2x!quant.uniform<i32:f32, 2.000000e+00:4>>, tensor<2x!quant.uniform<i32:f32, 2.000000e+00:4>>)
|
||||
// CHECK-SAME: -> tensor<3x2x!quant.uniform<i32:f32, 2.000000e+00:4>>
|
||||
|
||||
%0 = "tf.UniformQuantize"(%input, %input_scales, %input_zps) {
|
||||
quantization_axis = -1 : i64, quantization_min_val = -2147483648 : i64, quantization_max_val = 2147483647 : i64
|
||||
} : (tensor<3x2xf32>, tensor<f32>, tensor<i32>) -> tensor<3x2x!tf_type.qint32>
|
||||
%1 = "tf.UniformQuantizedAdd"(
|
||||
%0, %bias,
|
||||
%input_scales, %input_zps,
|
||||
%bias_scales, %bias_zps,
|
||||
%output_scales, %output_zps) {
|
||||
lhs_quantization_axis = -1 : i64,
|
||||
lhs_quantization_min_val = -2147483648 : i64,
|
||||
lhs_quantization_max_val = 2147483647 : i64,
|
||||
rhs_quantization_axis = -1 : i64,
|
||||
rhs_quantization_min_val = -2147483648 : i64,
|
||||
rhs_quantization_max_val = 2147483647 : i64,
|
||||
output_quantization_axis = -1 : i64,
|
||||
output_quantization_min_val = -2147483648 : i64,
|
||||
output_quantization_max_val = 2147483647 : i64} : (
|
||||
tensor<3x2x!tf_type.qint32>, tensor<2x!tf_type.qint32>,
|
||||
tensor<f32>, tensor<i32>,
|
||||
tensor<f32>, tensor<i32>,
|
||||
tensor<f32>, tensor<i32>) -> tensor<3x2x!tf_type.qint32>
|
||||
func.return
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user