[tf.contrib.data] Re-implement IteratorGetNext as an AsyncOpKernel.

This prevents the op from consuming an inter-op thread pool thread
when blocked, and fixes a potential deadlock when many IteratorGetNext
ops are blocked. Fixes #10369.

PiperOrigin-RevId: 157878885
This commit is contained in:
Derek Murray 2017-06-02 14:50:27 -07:00 committed by TensorFlower Gardener
parent 9e25c68ad1
commit 8939b85620
3 changed files with 44 additions and 23 deletions

View File

@@ -150,7 +150,8 @@ class MapDatasetTest(test.TestCase):
results.append(sess.run(get_next))
except errors.OutOfRangeError:
return
threads = [self.checkedThread(target=iterator_thread) for _ in range(8)]
threads = [self.checkedThread(target=iterator_thread)
for _ in range(64)]
for t in threads:
t.start()
for t in threads:

View File

@@ -5387,6 +5387,7 @@ tf_kernel_library(
srcs = ["iterator_ops.cc"],
deps = [
":dataset",
":ops_util",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",

View File

@@ -18,7 +18,10 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
@@ -282,38 +285,54 @@ class OneShotIteratorOp : public OpKernel {
IteratorResource* iterator_resource_ = nullptr;
};
class IteratorGetNextOp : public OpKernel {
class IteratorGetNextOp : public AsyncOpKernel {
public:
explicit IteratorGetNextOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
explicit IteratorGetNextOp(OpKernelConstruction* ctx)
: AsyncOpKernel(ctx),
thread_pool_(new thread::ThreadPool(
ctx->env(), ThreadOptions(),
strings::StrCat("iterator_get_next_thread_",
SanitizeThreadSuffix(def().name())),
1 /* num_threads */, false /* low_latency_hint */)) {}
// TODO(mrry): Convert this to an async op, because
// `iterator->GetNext()` could trigger long-running operations
// (e.g. a QueueDequeue or a remote read).
void Compute(OpKernelContext* ctx) override {
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
IteratorResource* iterator;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
core::ScopedUnref unref_iterator(iterator);
std::vector<Tensor> components;
bool end_of_sequence;
// The call to `iterator->GetNext()` may block and depend on an
// inter-op thread pool thread, so we issue the call from the
// owned thread pool.
thread_pool_->Schedule([this, ctx, iterator, done]() {
core::ScopedUnref unref_iterator(iterator);
IteratorContext::Params params;
params.env = ctx->env();
params.step_id = ctx->step_id();
params.resource_manager = ctx->resource_manager();
params.runner = *(ctx->runner());
IteratorContext iter_ctx(std::move(params));
std::vector<Tensor> components;
bool end_of_sequence;
OP_REQUIRES_OK(ctx,
iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence"));
IteratorContext::Params params;
params.env = ctx->env();
params.step_id = ctx->step_id();
params.resource_manager = ctx->resource_manager();
params.runner = *(ctx->runner());
IteratorContext iter_ctx(std::move(params));
for (int i = 0; i < components.size(); ++i) {
// TODO(mrry): Check that the shapes match the shape attrs.
ctx->set_output(i, components[i]);
}
OP_REQUIRES_OK_ASYNC(
ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
done);
OP_REQUIRES_ASYNC(ctx, !end_of_sequence,
errors::OutOfRange("End of sequence"), done);
for (int i = 0; i < components.size(); ++i) {
// TODO(mrry): Check that the shapes match the shape attrs.
ctx->set_output(i, components[i]);
}
done();
});
}
private:
std::unique_ptr<thread::ThreadPool> thread_pool_;
};
class IteratorDisposeOp : public OpKernel {