Enable pass to hoist writes to replicate invariant resources outside tf_device.replicate op.

Enabling this pass in the TF-XLA MLIR bridge reduces the memory requirement on host as it does not need to allocate buffer memory for output from all device replicas.

PiperOrigin-RevId: 372858954
Change-Id: Ib94d09efe8468b4293e715f84cbdc8964af4155a
This commit is contained in:
Prakalp Srivastava 2021-05-09 22:18:48 -07:00 committed by TensorFlower Gardener
parent 0daf99950c
commit 6bd9c93fc9

View File

@ -158,6 +158,8 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
pm.addPass(createSymbolDCEPass());
pm.addNestedPass<FuncOp>(TFDevice::CreateReplicateInvariantOpHoistingPass());
pm.addNestedPass<FuncOp>(CreateTPUMergeVariablesWithExecutePass());
pm.addNestedPass<FuncOp>(
TF::CreateHoistReplicateInvariantResourceWritesPass());
pm.addNestedPass<FuncOp>(CreateTPUColocateCompositeResourceOps());
pm.addPass(CreateTPUVariableReformattingPass());
pm.addPass(TF::CreateTFRegionControlFlowToFunctional());