diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index 7b8ae474941..1963931b497 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -107,6 +107,7 @@ void CreateTPUBridgePipeline(OpPassManager &pm) { } void CreateTPUBridgePipelineV1(OpPassManager &pm) { + pm.addPass(TF::CreateTFShapeInferencePass()); // For V1 compatibility, we process a module where the graph does not have // feeds and fetched. We extract first the TPU computation in a submodule, // where it'll be in a function with args and returned values, much more like