Use applyPartialConversion instead of applyPatternsAndFoldGreedily

since the patterns for converting TF -> MHLO uniform quantized graphs
are based on OpConversionPattern, which requires the dialect conversion driver.

PiperOrigin-RevId: 540439199
This commit is contained in:
Jaesung Chung 2023-06-14 18:36:00 -07:00 committed by TensorFlower Gardener
parent b6a69dda3a
commit 76f4e0a8f9
3 changed files with 62 additions and 3 deletions

View File

@ -472,6 +472,7 @@ cc_library(
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
"@stablehlo//:chlo_ops",
],
# Alwayslink is required for registering the MLIR passes.
# TODO(b/255530126): Split the pass registration from the definitions to avoid binary size bloat.

View File

@ -19,7 +19,9 @@ limitations under the License.
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "stablehlo/dialect/ChloOps.h" // from @stablehlo
#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
@ -55,6 +57,7 @@ class ConvertTFQuantOpsToMHLOPass
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<TF::TensorFlowDialect>();
registry.insert<mhlo::MhloDialect>();
registry.insert<chlo::ChloDialect>();
registry.insert<tf_type::TFTypeDialect>();
registry.insert<quant::QuantizationDialect>();
registry.insert<quantfork::QuantizationForkDialect>();
@ -68,10 +71,19 @@ static PassRegistration<ConvertTFQuantOpsToMHLOPass> pass;
void ConvertTFQuantOpsToMHLOPass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  func::FuncOp func = getOperation();

  // Use the dialect-conversion driver rather than the greedy rewrite driver:
  // the TF -> MHLO uniform-quantization patterns are OpConversionPatterns,
  // which only fire under a ConversionTarget-driven legalization.
  ConversionTarget target(*ctx);
  // All ops in these dialects are legal, except the uniform-quantized TF ops
  // explicitly marked illegal below, which must be converted away.
  target.addLegalDialect<TF::TensorFlowDialect, mhlo::MhloDialect,
                         chlo::ChloDialect>();
  target.addIllegalOp<
      TF::UniformQuantizeOp, TF::UniformRequantizeOp, TF::UniformDequantizeOp,
      TF::UniformQuantizedDotOp, TF::UniformQuantizedDotHybridOp,
      TF::UniformQuantizedConvolutionOp,
      TF::UniformQuantizedConvolutionHybridOp, TF::UniformQuantizedAddOp,
      TF::UniformQuantizedClipByValueOp>();

  RewritePatternSet patterns(ctx);
  mhlo::PopulateLegalizeTfQuantizationPatterns(ctx, &patterns);
  // Partial conversion: ops not mentioned by the target are left untouched.
  // NOTE: the flattened diff retained the superseded
  // applyPatternsAndFoldGreedily call alongside this one; only the
  // applyPartialConversion call (the post-commit state) is kept here.
  if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
    signalPassFailure();
  }
}

View File

@ -11,7 +11,53 @@ func.func @quantized_matmul_fn(%input: tensor<*xf32>) -> tensor<*xf32> {
}
// CHECK: func @quantized_matmul_fn
// CHECK: "tf.AddV2"
// CHECK: mhlo.constant
// CHECK-SAME{LITERAL}: dense<[[1, 2], [3, 4]]> : tensor<2x2xi8>
// CHECK-NEXT: "tf.AddV2"
// CHECK-NEXT: "mhlo.dot"(%1, %0) : (tensor<*xf32>, tensor<2x2x!quant.uniform<i8:f32, 1.000000e+00:3>>) -> tensor<*xf32>
// CHECK: "mhlo.dot"
// CHECK-SAME: (tensor<*xf32>, tensor<2x2x!quant.uniform<i8:f32, 1.000000e+00:3>>) -> tensor<*xf32>
// -----
// CHECK-LABEL: func @uniform_quantized_add
func.func @uniform_quantized_add(%input: tensor<3x2xf32>) -> () {
// Quantization parameters for the float LHS input: scale 2.0, zero point 4.
%input_scales = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
%input_zps = "tf.Const"() { value = dense<4> : tensor<i32> } : () -> tensor<i32>
// tensor_proto that points to dense<127> of type !tf_type.qint32.
%bias = "tf.Const"() { value = #tf_type<tensor_proto : "0x746674656E736F722464747970653A2044545F51494E5433322074656E736F725F7368617065207B207D2074656E736F725F636F6E74656E743A20225C3137375C3030305C3030305C30303022"> : tensor<2x!tf_type.qint32> } : () -> tensor<2x!tf_type.qint32>
// Bias and output share the same quantization parameters as the input
// (scale 2.0, zero point 4), so no requantization is needed for the add.
%bias_scales = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
%bias_zps = "tf.Const"() { value = dense<4> : tensor<i32> } : () -> tensor<i32>
%output_scales = "tf.Const"() { value = dense<2.0> : tensor<f32> } : () -> tensor<f32>
%output_zps = "tf.Const"() { value = dense<4> : tensor<i32> } : () -> tensor<i32>
// Expected lowering: the quantized bias becomes an mhlo.constant of
// !quant.uniform type, and tf.UniformQuantizedAdd lowers to
// chlo.broadcast_add with broadcast_dimensions selecting the trailing dim.
// CHECK-DAG: %[[LHS:.*]] = mhlo.uniform_quantize %arg0 : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform<i32:f32, 2.000000e+00:4>>
// CHECK-DAG: %[[RHS:.*]] = mhlo.constant()
// CHECK-SAME{LITERAL}: {value = dense<127> : tensor<2xi32>} : () -> tensor<2x!quant.uniform<i32:f32, 2.000000e+00:4>>
// CHECK: chlo.broadcast_add %[[LHS]], %[[RHS]] {broadcast_dimensions = dense<1> : tensor<1xi64>} :
// CHECK-SAME: (tensor<3x2x!quant.uniform<i32:f32, 2.000000e+00:4>>, tensor<2x!quant.uniform<i32:f32, 2.000000e+00:4>>)
// CHECK-SAME: -> tensor<3x2x!quant.uniform<i32:f32, 2.000000e+00:4>>
// Quantize the float input to qint32 using the parameters above.
%0 = "tf.UniformQuantize"(%input, %input_scales, %input_zps) {
quantization_axis = -1 : i64, quantization_min_val = -2147483648 : i64, quantization_max_val = 2147483647 : i64
} : (tensor<3x2xf32>, tensor<f32>, tensor<i32>) -> tensor<3x2x!tf_type.qint32>
// Per-tensor (axis = -1) quantized add of the quantized input and bias;
// min/max cover the full i32 range, so no clipping is expected.
%1 = "tf.UniformQuantizedAdd"(
%0, %bias,
%input_scales, %input_zps,
%bias_scales, %bias_zps,
%output_scales, %output_zps) {
lhs_quantization_axis = -1 : i64,
lhs_quantization_min_val = -2147483648 : i64,
lhs_quantization_max_val = 2147483647 : i64,
rhs_quantization_axis = -1 : i64,
rhs_quantization_min_val = -2147483648 : i64,
rhs_quantization_max_val = 2147483647 : i64,
output_quantization_axis = -1 : i64,
output_quantization_min_val = -2147483648 : i64,
output_quantization_max_val = 2147483647 : i64} : (
tensor<3x2x!tf_type.qint32>, tensor<2x!tf_type.qint32>,
tensor<f32>, tensor<i32>,
tensor<f32>, tensor<i32>,
tensor<f32>, tensor<i32>) -> tensor<3x2x!tf_type.qint32>
func.return
}