Clean up Embedding op side effect handling

- now we have two Embedding side effects, read and write
- now dependencies between EnqueueTPUEmbedding ops with same device ordinal are
  properly modeled
- we now finally don't have any Embedding-specific code left in side effect
  analysis
- introduced new `TF_MustExecute` trait that avoids pruning of an op; this is
  useful for side-effecting ops that don't produce any output and don't have
  dependencies to/from other ops
- for ops that just used `TF_TPUEmbeddingSideEffect` to avoid pruning, use new
  `TF_MustExecute` trait instead
- in contrast to the old `TF_TPUEmbeddingSideEffect`, `TF_MustExecute` avoids
  pruning independent of reachability (see new graph pruning test)

PiperOrigin-RevId: 413175982
Change-Id: I7b65c7a0e8a17b8a1683a0e01d1fd0614f7ac95a
This commit is contained in:
Michael Gester 2021-11-30 09:50:34 -08:00 committed by TensorFlower Gardener
parent 5c64255f32
commit ae7626fbc4
12 changed files with 260 additions and 98 deletions

View File

@ -1248,6 +1248,7 @@ cc_library(
":tensorflow_analysis",
":tensorflow_ops",
":tensorflow_optimize_inc_gen",
":tensorflow_side_effects",
":tensorflow_types",
":tf_data_optimization",
":tf_legalize_hlo",

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
@ -155,11 +156,7 @@ SideEffects GetSideEffectsFromEffectInstance(
const MemoryEffects::EffectInstance& effect_instance, Operation* op) {
mlir::SideEffects::Effect* effect = effect_instance.getEffect();
SideEffects side_effects;
if (llvm::isa<ResourceEffects::TPUEmbedding>(effect_instance.getResource())) {
// TODO(mgester) This hack can be removed once b/196857154 is fixed.
// See definition of `TF_TPUEmbeddingSideEffect` for more details.
side_effects.SetRead();
} else if (isa<MemoryEffects::Allocate>(effect)) {
if (isa<MemoryEffects::Allocate>(effect)) {
side_effects.SetAlloc();
} else if (isa<MemoryEffects::Free>(effect)) {
side_effects.SetFree();
@ -356,11 +353,20 @@ class OpSideEffectCollector {
// We handle value-based side effects for which we can use resource
// alias analysis at a different place, skip here.
if (ShouldUseResourceAliasAnalysis(effect)) continue;
if (llvm::isa<ResourceEffects::MustExecute>(effect.getResource()))
// This fake resource only exists to prevent certain ops from being
// considered dead or pruned; ignore it for side effect analysis.
continue;
// Add side effects for op resource ID.
int64_t instance_id = -1;
SideEffects side_effects(GetSideEffectsFromEffectInstance(effect, op));
if (auto resource_instance_op =
dyn_cast<GetResourceInstanceInterface>(op)) {
instance_id = resource_instance_op.GetResourceInstanceId();
}
ResourceId resource_id =
GetOpResourceId(effect.getResource()->getResourceID());
GetOpResourceId(effect.getResource()->getResourceID(), instance_id);
side_effects.SetResourceId(resource_id);
UpdateSideEffectsByResourceId(side_effects,
side_effects_by_resource_id);
@ -368,10 +374,11 @@ class OpSideEffectCollector {
}
}
// Get internal op resource ID from MLIR type ID.
ResourceId GetOpResourceId(TypeID type_id) {
// Get internal op resource ID from MLIR type ID and instance ID.
ResourceId GetOpResourceId(TypeID type_id, int64_t instance_id) {
auto emplace_result =
type_id_to_op_resource_id_.try_emplace(type_id, next_op_resource_id_);
type_instance_ids_to_op_resource_id_.try_emplace(
std::make_pair(type_id, instance_id), next_op_resource_id_);
// Increment the ID counter if we have encountered a new op-based resource.
if (emplace_result.second) ++next_op_resource_id_;
return emplace_result.first->getSecond();
@ -385,9 +392,10 @@ class OpSideEffectCollector {
// Next available ID for op-based resources (resources not handled by resource
// alias analysis).
ResourceId next_op_resource_id_ = kMaxResourceId + 1;
// Maps MLIR type IDs for resource types to internal IDs for op-based
// resources. Also see comment above.
llvm::SmallDenseMap<TypeID, ResourceId> type_id_to_op_resource_id_;
// Maps (type ID, instance ID) pairs to internal IDs for op-based resources.
// Also see comment above.
llvm::SmallDenseMap<std::pair<TypeID, int64_t>, ResourceId>
type_instance_ids_to_op_resource_id_;
// Used for faster callable resolution.
SymbolTableCollection symbol_table_collection_;
// Collect all op-based side effects here.

View File

@ -4065,7 +4065,7 @@ This operation creates a tensor of `shape` and `dtype`.
let hasFolder = 1;
}
def TF_EnqueueTPUEmbeddingArbitraryTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingArbitraryTensorBatch", [SameVariadicOperandSize, TF_TPUEmbeddingSideEffect]> {
def TF_EnqueueTPUEmbeddingArbitraryTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingArbitraryTensorBatch", [DeclareOpInterfaceMethods<TF_GetResourceInstanceInterface>, SameVariadicOperandSize, TF_TPUEmbeddingWriteEffect]> {
let summary = [{
Eases the porting of code that uses tf.nn.embedding_lookup_sparse().
}];
@ -4117,7 +4117,7 @@ in TPUEmbeddingConfiguration is used, otherwise mode_override is used.}]>:$mode_
TF_DerivedOperandTypeAttr T3 = TF_DerivedOperandTypeAttr<2>;
}
def TF_EnqueueTPUEmbeddingBatchOp : TF_Op<"EnqueueTPUEmbeddingBatch", [TF_TPUEmbeddingSideEffect]> {
def TF_EnqueueTPUEmbeddingBatchOp : TF_Op<"EnqueueTPUEmbeddingBatch", [DeclareOpInterfaceMethods<TF_GetResourceInstanceInterface>, TF_TPUEmbeddingWriteEffect]> {
let summary = [{
An op that enqueues a list of input batch tensors to TPUEmbedding.
}];
@ -4141,7 +4141,7 @@ in TPUEmbeddingConfiguration is used, otherwise mode_override is used.}]>:$mode_
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
}
def TF_EnqueueTPUEmbeddingIntegerBatchOp : TF_Op<"EnqueueTPUEmbeddingIntegerBatch", [TF_TPUEmbeddingSideEffect]> {
def TF_EnqueueTPUEmbeddingIntegerBatchOp : TF_Op<"EnqueueTPUEmbeddingIntegerBatch", [DeclareOpInterfaceMethods<TF_GetResourceInstanceInterface>, TF_TPUEmbeddingWriteEffect]> {
let summary = [{
An op that enqueues a list of input batch tensors to TPUEmbedding.
}];
@ -4162,7 +4162,7 @@ in TPUEmbeddingConfiguration is used, otherwise mode_override is used.}]>:$mode_
TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>;
}
def TF_EnqueueTPUEmbeddingRaggedTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingRaggedTensorBatch", [SameVariadicOperandSize, TF_TPUEmbeddingSideEffect]> {
def TF_EnqueueTPUEmbeddingRaggedTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingRaggedTensorBatch", [DeclareOpInterfaceMethods<TF_GetResourceInstanceInterface>, SameVariadicOperandSize, TF_TPUEmbeddingWriteEffect]> {
let summary = "Eases the porting of code that uses tf.nn.embedding_lookup().";
let description = [{
@ -4207,7 +4207,7 @@ in TPUEmbeddingConfiguration is used, otherwise mode_override is used.}]>:$mode_
TF_DerivedOperandTypeAttr T3 = TF_DerivedOperandTypeAttr<2>;
}
def TF_EnqueueTPUEmbeddingSparseBatchOp : TF_Op<"EnqueueTPUEmbeddingSparseBatch", [SameVariadicOperandSize, TF_TPUEmbeddingSideEffect]> {
def TF_EnqueueTPUEmbeddingSparseBatchOp : TF_Op<"EnqueueTPUEmbeddingSparseBatch", [DeclareOpInterfaceMethods<TF_GetResourceInstanceInterface>, SameVariadicOperandSize, TF_TPUEmbeddingWriteEffect]> {
let summary = [{
An op that enqueues TPUEmbedding input indices from a SparseTensor.
}];
@ -4250,7 +4250,7 @@ in TPUEmbeddingConfiguration is used, otherwise mode_override is used.}]>:$mode_
TF_DerivedOperandTypeAttr T3 = TF_DerivedOperandTypeAttr<2>;
}
def TF_EnqueueTPUEmbeddingSparseTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingSparseTensorBatch", [SameVariadicOperandSize, TF_TPUEmbeddingSideEffect]> {
def TF_EnqueueTPUEmbeddingSparseTensorBatchOp : TF_Op<"EnqueueTPUEmbeddingSparseTensorBatch", [DeclareOpInterfaceMethods<TF_GetResourceInstanceInterface>, SameVariadicOperandSize, TF_TPUEmbeddingWriteEffect]> {
let summary = [{
Eases the porting of code that uses tf.nn.embedding_lookup_sparse().
}];
@ -6948,7 +6948,7 @@ idx ==> [1, 3, 5]
TF_DerivedResultTypeAttr out_idx = TF_DerivedResultTypeAttr<1>;
}
def TF_LoadTPUEmbeddingADAMParametersOp : TF_Op<"LoadTPUEmbeddingADAMParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_LoadTPUEmbeddingADAMParametersOp : TF_Op<"LoadTPUEmbeddingADAMParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Load ADAM embedding parameters.";
let description = [{
@ -6974,7 +6974,7 @@ executed.
let results = (outs);
}
def TF_LoadTPUEmbeddingADAMParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingADAMParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
def TF_LoadTPUEmbeddingADAMParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingADAMParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "";
let arguments = (ins
@ -6993,7 +6993,7 @@ def TF_LoadTPUEmbeddingADAMParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingA
let results = (outs);
}
def TF_LoadTPUEmbeddingAdadeltaParametersOp : TF_Op<"LoadTPUEmbeddingAdadeltaParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_LoadTPUEmbeddingAdadeltaParametersOp : TF_Op<"LoadTPUEmbeddingAdadeltaParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Load Adadelta embedding parameters.";
let description = [{
@ -7019,7 +7019,7 @@ executed.
let results = (outs);
}
def TF_LoadTPUEmbeddingAdadeltaParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingAdadeltaParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
def TF_LoadTPUEmbeddingAdadeltaParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingAdadeltaParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "";
let arguments = (ins
@ -7038,7 +7038,7 @@ def TF_LoadTPUEmbeddingAdadeltaParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbedd
let results = (outs);
}
def TF_LoadTPUEmbeddingAdagradParametersOp : TF_Op<"LoadTPUEmbeddingAdagradParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_LoadTPUEmbeddingAdagradParametersOp : TF_Op<"LoadTPUEmbeddingAdagradParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Load Adagrad embedding parameters.";
let description = [{
@ -7063,7 +7063,7 @@ executed.
let results = (outs);
}
def TF_LoadTPUEmbeddingAdagradParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingAdagradParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
def TF_LoadTPUEmbeddingAdagradParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingAdagradParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "";
let arguments = (ins
@ -7081,7 +7081,7 @@ def TF_LoadTPUEmbeddingAdagradParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddi
let results = (outs);
}
def TF_LoadTPUEmbeddingCenteredRMSPropParametersOp : TF_Op<"LoadTPUEmbeddingCenteredRMSPropParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_LoadTPUEmbeddingCenteredRMSPropParametersOp : TF_Op<"LoadTPUEmbeddingCenteredRMSPropParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Load centered RMSProp embedding parameters.";
let description = [{
@ -7108,7 +7108,7 @@ executed.
let results = (outs);
}
def TF_LoadTPUEmbeddingFTRLParametersOp : TF_Op<"LoadTPUEmbeddingFTRLParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_LoadTPUEmbeddingFTRLParametersOp : TF_Op<"LoadTPUEmbeddingFTRLParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Load FTRL embedding parameters.";
let description = [{
@ -7134,7 +7134,7 @@ executed.
let results = (outs);
}
def TF_LoadTPUEmbeddingFTRLParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingFTRLParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
def TF_LoadTPUEmbeddingFTRLParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingFTRLParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "";
let arguments = (ins
@ -7153,7 +7153,7 @@ def TF_LoadTPUEmbeddingFTRLParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingF
let results = (outs);
}
def TF_LoadTPUEmbeddingMDLAdagradLightParametersOp : TF_Op<"LoadTPUEmbeddingMDLAdagradLightParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_LoadTPUEmbeddingMDLAdagradLightParametersOp : TF_Op<"LoadTPUEmbeddingMDLAdagradLightParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Load MDL Adagrad Light embedding parameters.";
let description = [{
@ -7180,7 +7180,7 @@ executed.
let results = (outs);
}
def TF_LoadTPUEmbeddingMomentumParametersOp : TF_Op<"LoadTPUEmbeddingMomentumParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_LoadTPUEmbeddingMomentumParametersOp : TF_Op<"LoadTPUEmbeddingMomentumParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Load Momentum embedding parameters.";
let description = [{
@ -7205,7 +7205,7 @@ executed.
let results = (outs);
}
def TF_LoadTPUEmbeddingMomentumParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingMomentumParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
def TF_LoadTPUEmbeddingMomentumParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingMomentumParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "";
let arguments = (ins
@ -7223,7 +7223,7 @@ def TF_LoadTPUEmbeddingMomentumParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbedd
let results = (outs);
}
def TF_LoadTPUEmbeddingProximalAdagradParametersOp : TF_Op<"LoadTPUEmbeddingProximalAdagradParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_LoadTPUEmbeddingProximalAdagradParametersOp : TF_Op<"LoadTPUEmbeddingProximalAdagradParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Load proximal Adagrad embedding parameters.";
let description = [{
@ -7248,7 +7248,7 @@ executed.
let results = (outs);
}
def TF_LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
def TF_LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingProximalAdagradParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "";
let arguments = (ins
@ -7266,7 +7266,7 @@ def TF_LoadTPUEmbeddingProximalAdagradParametersGradAccumDebugOp : TF_Op<"LoadTP
let results = (outs);
}
def TF_LoadTPUEmbeddingProximalYogiParametersOp : TF_Op<"LoadTPUEmbeddingProximalYogiParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_LoadTPUEmbeddingProximalYogiParametersOp : TF_Op<"LoadTPUEmbeddingProximalYogiParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "";
let arguments = (ins
@ -7284,7 +7284,7 @@ def TF_LoadTPUEmbeddingProximalYogiParametersOp : TF_Op<"LoadTPUEmbeddingProxima
let results = (outs);
}
def TF_LoadTPUEmbeddingProximalYogiParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingProximalYogiParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
def TF_LoadTPUEmbeddingProximalYogiParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingProximalYogiParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "";
let arguments = (ins
@ -7303,7 +7303,7 @@ def TF_LoadTPUEmbeddingProximalYogiParametersGradAccumDebugOp : TF_Op<"LoadTPUEm
let results = (outs);
}
def TF_LoadTPUEmbeddingRMSPropParametersOp : TF_Op<"LoadTPUEmbeddingRMSPropParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_LoadTPUEmbeddingRMSPropParametersOp : TF_Op<"LoadTPUEmbeddingRMSPropParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Load RMSProp embedding parameters.";
let description = [{
@ -7329,7 +7329,7 @@ executed.
let results = (outs);
}
def TF_LoadTPUEmbeddingRMSPropParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingRMSPropParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
def TF_LoadTPUEmbeddingRMSPropParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingRMSPropParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "";
let arguments = (ins
@ -7348,7 +7348,7 @@ def TF_LoadTPUEmbeddingRMSPropParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddi
let results = (outs);
}
def TF_LoadTPUEmbeddingStochasticGradientDescentParametersOp : TF_Op<"LoadTPUEmbeddingStochasticGradientDescentParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_LoadTPUEmbeddingStochasticGradientDescentParametersOp : TF_Op<"LoadTPUEmbeddingStochasticGradientDescentParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Load SGD embedding parameters.";
let description = [{
@ -7372,7 +7372,7 @@ executed.
let results = (outs);
}
def TF_LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
def TF_LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugOp : TF_Op<"LoadTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "";
let arguments = (ins
@ -11548,7 +11548,7 @@ def TF_RecvOp : TF_Op<"Recv", []> {
TF_DerivedResultTypeAttr tensor_type = TF_DerivedResultTypeAttr<0>;
}
def TF_RecvTPUEmbeddingActivationsOp : TF_Op<"RecvTPUEmbeddingActivations", [TF_TPUEmbeddingSideEffect]> {
def TF_RecvTPUEmbeddingActivationsOp : TF_Op<"RecvTPUEmbeddingActivations", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "An op that receives embedding activations on the TPU.";
let description = [{
@ -12939,7 +12939,7 @@ checkpoint directly.}]>:$tensors
TF_DerivedResultTypeListAttr dtypes = TF_DerivedResultTypeListAttr<0>;
}
def TF_RetrieveTPUEmbeddingADAMParametersOp : TF_Op<"RetrieveTPUEmbeddingADAMParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_RetrieveTPUEmbeddingADAMParametersOp : TF_Op<"RetrieveTPUEmbeddingADAMParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Retrieve ADAM embedding parameters.";
let description = [{
@ -12964,7 +12964,7 @@ used to retrieve updated parameters before saving a checkpoint.
);
}
def TF_RetrieveTPUEmbeddingADAMParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingADAMParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
def TF_RetrieveTPUEmbeddingADAMParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingADAMParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "";
let arguments = (ins
@ -12983,7 +12983,7 @@ def TF_RetrieveTPUEmbeddingADAMParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEm
);
}
def TF_RetrieveTPUEmbeddingAdadeltaParametersOp : TF_Op<"RetrieveTPUEmbeddingAdadeltaParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_RetrieveTPUEmbeddingAdadeltaParametersOp : TF_Op<"RetrieveTPUEmbeddingAdadeltaParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Retrieve Adadelta embedding parameters.";
let description = [{
@ -13008,7 +13008,7 @@ used to retrieve updated parameters before saving a checkpoint.
);
}
def TF_RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
def TF_RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "";
let arguments = (ins
@ -13027,7 +13027,7 @@ def TF_RetrieveTPUEmbeddingAdadeltaParametersGradAccumDebugOp : TF_Op<"RetrieveT
);
}
def TF_RetrieveTPUEmbeddingAdagradParametersOp : TF_Op<"RetrieveTPUEmbeddingAdagradParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_RetrieveTPUEmbeddingAdagradParametersOp : TF_Op<"RetrieveTPUEmbeddingAdagradParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Retrieve Adagrad embedding parameters.";
let description = [{
@ -13051,7 +13051,7 @@ used to retrieve updated parameters before saving a checkpoint.
);
}
def TF_RetrieveTPUEmbeddingAdagradParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingAdagradParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
def TF_RetrieveTPUEmbeddingAdagradParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingAdagradParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "";
let arguments = (ins
@ -13069,7 +13069,7 @@ def TF_RetrieveTPUEmbeddingAdagradParametersGradAccumDebugOp : TF_Op<"RetrieveTP
);
}
def TF_RetrieveTPUEmbeddingCenteredRMSPropParametersOp : TF_Op<"RetrieveTPUEmbeddingCenteredRMSPropParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_RetrieveTPUEmbeddingCenteredRMSPropParametersOp : TF_Op<"RetrieveTPUEmbeddingCenteredRMSPropParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Retrieve centered RMSProp embedding parameters.";
let description = [{
@ -13095,7 +13095,7 @@ used to retrieve updated parameters before saving a checkpoint.
);
}
def TF_RetrieveTPUEmbeddingFTRLParametersOp : TF_Op<"RetrieveTPUEmbeddingFTRLParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_RetrieveTPUEmbeddingFTRLParametersOp : TF_Op<"RetrieveTPUEmbeddingFTRLParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Retrieve FTRL embedding parameters.";
let description = [{
@ -13120,7 +13120,7 @@ used to retrieve updated parameters before saving a checkpoint.
);
}
def TF_RetrieveTPUEmbeddingFTRLParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingFTRLParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
def TF_RetrieveTPUEmbeddingFTRLParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingFTRLParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "";
let arguments = (ins
@ -13139,7 +13139,7 @@ def TF_RetrieveTPUEmbeddingFTRLParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEm
);
}
def TF_RetrieveTPUEmbeddingMDLAdagradLightParametersOp : TF_Op<"RetrieveTPUEmbeddingMDLAdagradLightParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_RetrieveTPUEmbeddingMDLAdagradLightParametersOp : TF_Op<"RetrieveTPUEmbeddingMDLAdagradLightParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Retrieve MDL Adagrad Light embedding parameters.";
let description = [{
@ -13165,7 +13165,7 @@ used to retrieve updated parameters before saving a checkpoint.
);
}
def TF_RetrieveTPUEmbeddingMomentumParametersOp : TF_Op<"RetrieveTPUEmbeddingMomentumParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_RetrieveTPUEmbeddingMomentumParametersOp : TF_Op<"RetrieveTPUEmbeddingMomentumParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Retrieve Momentum embedding parameters.";
let description = [{
@ -13189,7 +13189,7 @@ used to retrieve updated parameters before saving a checkpoint.
);
}
def TF_RetrieveTPUEmbeddingMomentumParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingMomentumParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
def TF_RetrieveTPUEmbeddingMomentumParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingMomentumParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "";
let arguments = (ins
@ -13207,7 +13207,7 @@ def TF_RetrieveTPUEmbeddingMomentumParametersGradAccumDebugOp : TF_Op<"RetrieveT
);
}
def TF_RetrieveTPUEmbeddingProximalAdagradParametersOp : TF_Op<"RetrieveTPUEmbeddingProximalAdagradParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_RetrieveTPUEmbeddingProximalAdagradParametersOp : TF_Op<"RetrieveTPUEmbeddingProximalAdagradParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Retrieve proximal Adagrad embedding parameters.";
let description = [{
@ -13231,7 +13231,7 @@ used to retrieve updated parameters before saving a checkpoint.
);
}
def TF_RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
def TF_RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "";
let arguments = (ins
@ -13249,7 +13249,7 @@ def TF_RetrieveTPUEmbeddingProximalAdagradParametersGradAccumDebugOp : TF_Op<"Re
);
}
def TF_RetrieveTPUEmbeddingProximalYogiParametersOp : TF_Op<"RetrieveTPUEmbeddingProximalYogiParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_RetrieveTPUEmbeddingProximalYogiParametersOp : TF_Op<"RetrieveTPUEmbeddingProximalYogiParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "";
let arguments = (ins
@ -13267,7 +13267,7 @@ def TF_RetrieveTPUEmbeddingProximalYogiParametersOp : TF_Op<"RetrieveTPUEmbeddin
);
}
def TF_RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
def TF_RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "";
let arguments = (ins
@ -13286,7 +13286,7 @@ def TF_RetrieveTPUEmbeddingProximalYogiParametersGradAccumDebugOp : TF_Op<"Retri
);
}
def TF_RetrieveTPUEmbeddingRMSPropParametersOp : TF_Op<"RetrieveTPUEmbeddingRMSPropParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_RetrieveTPUEmbeddingRMSPropParametersOp : TF_Op<"RetrieveTPUEmbeddingRMSPropParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Retrieve RMSProp embedding parameters.";
let description = [{
@ -13311,7 +13311,7 @@ used to retrieve updated parameters before saving a checkpoint.
);
}
def TF_RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
def TF_RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingRMSPropParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "";
let arguments = (ins
@ -13330,7 +13330,7 @@ def TF_RetrieveTPUEmbeddingRMSPropParametersGradAccumDebugOp : TF_Op<"RetrieveTP
);
}
def TF_RetrieveTPUEmbeddingStochasticGradientDescentParametersOp : TF_Op<"RetrieveTPUEmbeddingStochasticGradientDescentParameters", [TF_TPUEmbeddingSideEffect]> {
def TF_RetrieveTPUEmbeddingStochasticGradientDescentParametersOp : TF_Op<"RetrieveTPUEmbeddingStochasticGradientDescentParameters", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Retrieve SGD embedding parameters.";
let description = [{
@ -13353,7 +13353,7 @@ used to retrieve updated parameters before saving a checkpoint.
);
}
def TF_RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug", [TF_TPUEmbeddingSideEffect]> {
def TF_RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebugOp : TF_Op<"RetrieveTPUEmbeddingStochasticGradientDescentParametersGradAccumDebug", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "";
let arguments = (ins
@ -14407,7 +14407,7 @@ def TF_SendOp : TF_Op<"Send", []> {
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_SendTPUEmbeddingGradientsOp : TF_Op<"SendTPUEmbeddingGradients", [AttrSizedOperandSegments, TF_TPUEmbeddingSideEffect]> {
def TF_SendTPUEmbeddingGradientsOp : TF_Op<"SendTPUEmbeddingGradients", [AttrSizedOperandSegments, TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "Performs gradient updates of embedding tables.";
let arguments = (ins
@ -20173,7 +20173,7 @@ def TF__ListToArrayOp : TF_Op<"_ListToArray", [NoSideEffect]> {
TF_DerivedResultTypeAttr T = TF_DerivedResultTypeAttr<0>;
}
def TF__RecvTPUEmbeddingActivationsOp : TF_Op<"_RecvTPUEmbeddingActivations", [TF_TPUEmbeddingSideEffect]> {
def TF__RecvTPUEmbeddingActivationsOp : TF_Op<"_RecvTPUEmbeddingActivations", [TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "An op that receives embedding activations on the TPU.";
let description = [{
@ -20225,7 +20225,7 @@ of the embedding lookup operation.
);
}
def TF__SendTPUEmbeddingGradientsOp : TF_Op<"_SendTPUEmbeddingGradients", [AttrSizedOperandSegments, TF_TPUEmbeddingSideEffect]> {
def TF__SendTPUEmbeddingGradientsOp : TF_Op<"_SendTPUEmbeddingGradients", [AttrSizedOperandSegments, TF_MustExecute, TF_TPUEmbeddingReadEffect]> {
let summary = "An op that performs gradient updates of embedding tables.";
let description = [{

View File

@ -165,6 +165,8 @@ def TF_TPUEmbeddingResource : TF_ResourceBase<"TPUEmbedding">;
def TF_GeneratorOpResource : TF_ResourceBase<"GeneratorOp">;
def TF_SendRecvResource : TF_ResourceBase<"SendRecv">;
def TF_TPUCompileExecuteResource : TF_ResourceBase<"TPUCompileExecute">;
// Fake resource, see `TF_MustExecute` below.
def TF_MustExecuteResource : TF_ResourceBase<"MustExecute">;
// Value-based side effects
//
@ -214,17 +216,21 @@ def TF_DatasetIteratorFree : MemFree<TF_DatasetIteratorResource>;
// effecting ops. Note that for `read` effects ops might be pruned if nothing
// depends on them.
def TF_GeneratorOpSideEffect : MemoryEffects<[MemWrite<TF_GeneratorOpResource>]>;
// Note: We actually want a `read` effect here but then some ops with this
// effect are considered dead and are deleted which is not desired (see
// b/195782952).
// Therefore, we use a `write` effect + special handling in side effect
// analysis. Once we have proper dependencies that avoid deletion (see
// b/196857154), or once MLIR supports a trait to mark an op as not dead, this
// hack can be removed.
def TF_TPUEmbeddingSideEffect : MemoryEffects<[MemWrite<TF_TPUEmbeddingResource>]>;
def TF_TPUEmbeddingWriteEffect : MemoryEffects<[MemWrite<TF_TPUEmbeddingResource>]>;
def TF_TPUEmbeddingReadEffect : MemoryEffects<[MemRead<TF_TPUEmbeddingResource>]>;
def TF_SendRecvSideEffect : MemoryEffects<[MemWrite<TF_SendRecvResource>]>;
def TF_TPUCompileExecuteSideEffect : MemoryEffects<[MemWrite<TF_TPUCompileExecuteResource>]>;
// Trait for enforcing that a side-effecting op is executed, even if it would be
// considered dead by MLIR (see b/195782952).
// The trait is implemented as a write effect for a fake resource which is
// ignored by side effect analysis, so it does not affect execution order
// constraints and control dependencies at all (for example, multiple ops with
// this trait do not have to execute in order).
def TF_MustExecute : MemoryEffects<[MemWrite<TF_MustExecuteResource>]>;
//===----------------------------------------------------------------------===//
// TensorFlow op definitions
//===----------------------------------------------------------------------===//

View File

@ -131,4 +131,23 @@ def TF_ResourceHandleAllocatorInterface : OpInterface<"ResourceHandleAllocatorIn
];
}
// Interface for ops whose op-based side effect targets exactly one resource
// instance that can be identified by an integer (for example, a device
// ordinal). Side effect analysis uses the returned ID to distinguish
// otherwise-identical op-based resources, so only ops with the same instance
// ID are ordered with respect to each other.
def TF_GetResourceInstanceInterface : OpInterface<"GetResourceInstanceInterface"> {
let description = [{Returns an integer corresponding to the resource instance
accessed by this op}];
let methods = [
InterfaceMethod<
/*desc=*/[{Returns an integer corresponding to the resource instance
accessed by this op. The implementation must guarantee that the
mapping between resource instances and integers is bijective,
i.e., two op instances should return the same integer if and
only if they access the same resource. The interface should
only be used for ops that access exactly one resource.}],
/*retTy=*/"int64_t",
/*methodName=*/"GetResourceInstanceId",
/*args=*/(ins)
>,
];
}
#endif // TF_OP_INTERFACES

View File

@ -1564,6 +1564,13 @@ def TF__InternalTestNonResourceValueSideEffects_ : TF_Op<"_InternalTestNonResour
let results = (outs);
}
// Internal test-only op that carries the `TF_MustExecute` trait and nothing
// else (no operands, no results). Used to verify that ops with this trait are
// kept alive by graph pruning even when they are otherwise dead.
def TF__InternalTestMustExecuteTrait_ : TF_Op<"_InternalTestMustExecuteTrait_", [TF_MustExecute]> {
let summary = "Internal op for testing only";
let arguments = (ins);
let results = (outs);
}
def TF_SetStaticDimensionBoundsOp : TF_Op<"SetStaticDimensionBounds", []> {
let summary = "Op used to indicate to the compiler and runtime the static bounds of a tensor.";
let description = [{

View File

@ -2299,6 +2299,37 @@ static LogicalResult Verify(EmptyTensorListOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// EnqueueTPUEmbedding ops
//===----------------------------------------------------------------------===//

// For EnqueueTPUEmbedding ops the device ordinal corresponds to the resource
// instance.
//
// These implement `GetResourceInstanceInterface`: two enqueue ops report the
// same instance ID iff they target the same device ordinal, so side effect
// analysis creates dependencies between enqueues on the same device while
// keeping enqueues on different devices independent.

int64_t EnqueueTPUEmbeddingArbitraryTensorBatchOp::GetResourceInstanceId() {
  return device_ordinal();
}

int64_t EnqueueTPUEmbeddingBatchOp::GetResourceInstanceId() {
  return device_ordinal();
}

int64_t EnqueueTPUEmbeddingIntegerBatchOp::GetResourceInstanceId() {
  return device_ordinal();
}

int64_t EnqueueTPUEmbeddingRaggedTensorBatchOp::GetResourceInstanceId() {
  return device_ordinal();
}

int64_t EnqueueTPUEmbeddingSparseBatchOp::GetResourceInstanceId() {
  return device_ordinal();
}

int64_t EnqueueTPUEmbeddingSparseTensorBatchOp::GetResourceInstanceId() {
  return device_ordinal();
}
//===----------------------------------------------------------------------===//
// EnsureShapeOp
//===----------------------------------------------------------------------===//

View File

@ -2171,12 +2171,7 @@ SummaryWriterOp::GetResourceHandleValueAndIdList(
void TPUExecuteOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
effects.reserve(args().size() + 2);
// There may be some TPU Embedding ops in the computation, so this effect is
// added conservatively.
effects.emplace_back(MemoryEffects::Write::get(),
ResourceEffects::TPUEmbedding::get());
effects.reserve(args().size() + 1);
effects.emplace_back(MemoryEffects::Write::get(),
ResourceEffects::TPUCompileExecute::get());
@ -2239,12 +2234,7 @@ static LogicalResult Verify(TPUExecuteAndUpdateVariablesOp op) {
void TPUExecuteAndUpdateVariablesOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
effects.reserve(device_var_reads_indices().size() + 2);
// There may be some TPU Embedding ops in the computation, so this effect is
// added conservatively.
effects.emplace_back(MemoryEffects::Write::get(),
ResourceEffects::TPUEmbedding::get());
effects.reserve(device_var_reads_indices().size() + 1);
effects.emplace_back(MemoryEffects::Write::get(),
ResourceEffects::TPUCompileExecute::get());
auto resource_handles = llvm::make_filter_range(args(), [](Value value) {

View File

@ -77,6 +77,10 @@ struct TPUCompileExecute
StringRef getName() final { return "<TPUCompileExecute>"; }
};
// Resource backing the `TF_MustExecute` effect. An op that writes to this
// resource must never be pruned, even when it is unreachable from a fetch
// (see GraphPruningPass::ShouldPreserveOp).
struct MustExecute : public ::mlir::SideEffects::Resource::Base<MustExecute> {
  StringRef getName() final { return "<MustExecute>"; }
};
} // namespace ResourceEffects
} // namespace TF
} // namespace mlir

View File

@ -185,3 +185,22 @@ func @main() attributes {tf.entry_function = {control_outputs = "", inputs = "",
}
return
}
// -----
// Check that an op with must-execute effect is not pruned, even if it is
// unreachable.
func @must_execute_op() -> () {
  // CHECK: tf_executor.graph
  // CHECK: tf_executor.island
  // CHECK: tf._InternalTestMustExecuteTrait_
  tf_executor.graph {
    // The island's control result is never forwarded to the fetch, so this op
    // is unreachable; only the must-execute effect keeps it from being pruned.
    %1 = tf_executor.island {
      "tf._InternalTestMustExecuteTrait_"() : () -> ()
      tf_executor.yield
    }
    tf_executor.fetch
  }
  return
}

View File

@ -1494,9 +1494,9 @@ func @side_effecting_ops_with_different_resources_and_allocations(
// -----
// Tests that we create a dependency for op instances with
// `TPUEmbeddingSideEffect` with same device ordinal.
// NOTE(review): this span is reconstructed from a garbled diff rendering that
// merged the pre-change test (`@embedding_effect_ops`, no device ordinal)
// with the post-change test; the stale old lines are dropped.
func @embedding_effect_same_device(
  // expected-remark@above {{ID: 7}}
  %arg0: tensor<!tf_type.string>) {
  tf_executor.graph {
    %island = tf_executor.island {
        // expected-remark@above {{ID: 3}}
        // expected-remark@above {{Successors: {4}}}
        "tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0) {table_ids = [1, 2], device_ordinal = 1} : (tensor<!tf_type.string>) -> ()
        // expected-remark@above {{ID: 0}}
        // expected-remark@above {{Successors: {1}}}
        "tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0) {table_ids = [1, 2], device_ordinal = 1} : (tensor<!tf_type.string>) -> ()
        // expected-remark@above {{ID: 1}}
        // expected-remark@above {{Predecessors: {0}}}
        // expected-remark@above {{Successors: {2}}}
        tf_executor.yield
        // expected-remark@above {{ID: 2}}
        // expected-remark@above {{Predecessors: {1}}}
    }
    tf_executor.fetch %island : !tf_executor.control
    // expected-remark@above {{ID: 4}}
    // expected-remark@above {{Predecessors: {3}}}
  }
  return
  // expected-remark@above {{ID: 6}}
  // expected-remark@above {{Sinks: {5}}}
}
// -----
// Tests that we treat different op instances with `TPUEmbeddingSideEffect` as
// independent if they have different device ordinals.
func @embedding_effect_different_devices(
// expected-remark@above {{ID: 7}}
%arg0: tensor<!tf_type.string>) {
tf_executor.graph {
// expected-remark@above {{ID: 5}}
%island = tf_executor.island {
// expected-remark@above {{ID: 3}}
// expected-remark@above {{Successors: {4}}}
"tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0) {table_ids = [1, 2], device_ordinal = 1} : (tensor<!tf_type.string>) -> ()
// expected-remark@above {{ID: 0}}
// expected-remark@above {{Successors: {2}}}
"tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0){table_ids = [1, 2]} : (tensor<!tf_type.string>) -> ()
"tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0) {table_ids = [1, 2], device_ordinal = 2} : (tensor<!tf_type.string>) -> ()
// expected-remark@above {{ID: 1}}
// expected-remark@above {{Successors: {2}}}
tf_executor.yield
@ -1561,6 +1593,42 @@ func @mixed_embedding_and_unknown_effects(
// -----
// Tests that we don't create dependencies between `EnqueueTPUEmbedding` ops
// and other embedding ops that don't have a device ordinal.
func @mixed_embedding_and_unknown_effects(
// expected-remark@above {{ID: 8}}
%arg0: tensor<!tf_type.string>,
%arg1: tensor<8xf32>,
%arg2: tensor<8xf32>) {
tf_executor.graph {
// expected-remark@above {{ID: 6}}
%island = tf_executor.island {
// expected-remark@above {{ID: 4}}
// expected-remark@above {{Successors: {5}}}
"tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0){table_ids = [1, 2], device_ordinal = 1} : (tensor<!tf_type.string>) -> ()
// expected-remark@above {{ID: 0}}
// expected-remark@above {{Successors: {3}}}
"tf.LoadTPUEmbeddingAdagradParameters"(%arg1, %arg2) {config = "", num_shards = 1 : i64, shard_id = 0 : i64, table_id = -1 : i64, table_name = "table1"} : (tensor<8xf32>, tensor<8xf32>) -> ()
// expected-remark@above {{ID: 1}}
// expected-remark@above {{Successors: {3}}}
"tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0){table_ids = [1, 2], device_ordinal = 2} : (tensor<!tf_type.string>) -> ()
// expected-remark@above {{ID: 2}}
// expected-remark@above {{Successors: {3}}}
tf_executor.yield
// expected-remark@above {{ID: 3}}
// expected-remark@above {{Predecessors: {0,1,2}}}
}
tf_executor.fetch %island : !tf_executor.control
// expected-remark@above {{ID: 5}}
// expected-remark@above {{Predecessors: {4}}}
}
return
// expected-remark@above {{ID: 7}}
// expected-remark@above {{Sinks: {6}}}
}
// -----
// Tests that we create a dependency between two ops with the same op-based
// write effect.
func @same_op_based_write_effect(
@ -1602,13 +1670,13 @@ func @different_op_based_side_effects(
%island = tf_executor.island {
// expected-remark@above {{ID: 4}}
// expected-remark@above {{Successors: {5}}}
"tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0){table_ids = [1, 2]} : (tensor<!tf_type.string>) -> ()
"tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0){table_ids = [1, 2], device_ordinal = 1} : (tensor<!tf_type.string>) -> ()
// expected-remark@above {{ID: 0}}
// expected-remark@above {{Successors: {3}}}
%0 = "tf.GeneratorDataset"(%arg0, %arg0, %arg0) {device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0", finalize_func = @__func_a, init_func = @__func_b, next_func = @__func_c, next_func.experimental_ints_on_device = true, operand_segment_sizes = dense<[1, 1, 1]> : vector<3xi32>, output_shapes = [#tf_type.shape<>], output_types = [!tf_type.string], metadata = ""} : (tensor<!tf_type.string>, tensor<!tf_type.string>, tensor<!tf_type.string>) -> tensor<!tf_type.variant>
// expected-remark@above {{ID: 1}}
// expected-remark@above {{Successors: {3}}}
"tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0){table_ids = [1, 2]} : (tensor<!tf_type.string>) -> ()
"tf.EnqueueTPUEmbeddingRaggedTensorBatch"(%arg0){table_ids = [1, 2], device_ordinal = 5} : (tensor<!tf_type.string>) -> ()
// expected-remark@above {{ID: 2}}
// expected-remark@above {{Successors: {3}}}
tf_executor.yield
@ -1700,10 +1768,7 @@ func @send_recv_effect(
// -----
// Tests that we create a dependency between ops with
// `TF_TPUCompileExecuteSideEffect`. Note that this test also shows a case where
// we could improve pruning of control dependencies (see b/201013649): The
// dependency between the first `tf.TPUExecute` and the `tf_executor.yield` is
// redundant.
// `TF_TPUCompileExecuteSideEffect`.
func @tpu_compile_execute_effect(
// expected-remark@above {{ID: 7}}
%arg0: tensor<!tf_type.string>,
@ -1715,14 +1780,14 @@ func @tpu_compile_execute_effect(
// expected-remark@above {{Successors: {4}}}
"tf.TPUExecute"(%arg0, %arg0) : (tensor<!tf_type.string>, tensor<!tf_type.string>) -> ()
// expected-remark@above {{ID: 0}}
// expected-remark@above {{Successors: {1,2}}}
// expected-remark@above {{Successors: {1}}}
"tf.TPUExecute"(%arg1, %arg1) : (tensor<!tf_type.string>, tensor<!tf_type.string>) -> ()
// expected-remark@above {{ID: 1}}
// expected-remark@above {{Predecessors: {0}}}
// expected-remark@above {{Successors: {2}}}
tf_executor.yield
// expected-remark@above {{ID: 2}}
// expected-remark@above {{Predecessors: {0,1}}}
// expected-remark@above {{Predecessors: {1}}}
}
tf_executor.fetch %island : !tf_executor.control
// expected-remark@above {{ID: 4}}

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
@ -118,10 +119,21 @@ void GraphPruningPass::runOnFunction() {
getFunction().walk([this](tf_executor::GraphOp graph) { PruneGraph(graph); });
}
// An op should be preserved if either its identifier is contained in
// `ops_to_preserve_ids_` or if it has a `MustExecute` effect.
//
// Note: the rendered diff merged the old one-line body (an unconditional
// `return ops_to_preserve_ids_.contains(...)`) ahead of the new body, leaving
// the effect check dead code; this is the intended post-change implementation.
bool GraphPruningPass::ShouldPreserveOp(Operation* op) {
  // Fast path: ops explicitly registered for preservation by name.
  if (ops_to_preserve_ids_.contains(op->getName().getIdentifier())) return true;

  // Ops carrying a `MustExecute` resource effect must not be pruned, even if
  // they are unreachable from a fetch (see the graph pruning test).
  llvm::SmallVector<MemoryEffects::EffectInstance, 4> effects;
  auto interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (interface) interface.getEffects(effects);

  for (const auto& effect : effects) {
    if (llvm::isa<TF::ResourceEffects::MustExecute>(effect.getResource())) {
      return true;
    }
  }
  return false;
}
// An island should be preserved if any of its inner ops should be preserved.