mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 00:20:20 +01:00
[tf.contrib.data] Re-implement IteratorGetNext as an AsyncOpKernel.
This prevents the op from consuming an inter-op thread pool thread when blocked, and fixes a potential deadlock when many IteratorGetNext ops are blocked. Fixes #10369. PiperOrigin-RevId: 157878885
This commit is contained in:
parent
9e25c68ad1
commit
8939b85620
|
|
@ -150,7 +150,8 @@ class MapDatasetTest(test.TestCase):
|
|||
results.append(sess.run(get_next))
|
||||
except errors.OutOfRangeError:
|
||||
return
|
||||
threads = [self.checkedThread(target=iterator_thread) for _ in range(8)]
|
||||
threads = [self.checkedThread(target=iterator_thread)
|
||||
for _ in range(64)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
|
|
|
|||
|
|
@ -5387,6 +5387,7 @@ tf_kernel_library(
|
|||
srcs = ["iterator_ops.cc"],
|
||||
deps = [
|
||||
":dataset",
|
||||
":ops_util",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
|
|
|
|||
|
|
@ -18,7 +18,10 @@ limitations under the License.
|
|||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/resource_op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/lib/random/random.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
|
@ -282,38 +285,54 @@ class OneShotIteratorOp : public OpKernel {
|
|||
IteratorResource* iterator_resource_ = nullptr;
|
||||
};
|
||||
|
||||
class IteratorGetNextOp : public OpKernel {
|
||||
class IteratorGetNextOp : public AsyncOpKernel {
|
||||
public:
|
||||
explicit IteratorGetNextOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
explicit IteratorGetNextOp(OpKernelConstruction* ctx)
|
||||
: AsyncOpKernel(ctx),
|
||||
thread_pool_(new thread::ThreadPool(
|
||||
ctx->env(), ThreadOptions(),
|
||||
strings::StrCat("iterator_get_next_thread_",
|
||||
SanitizeThreadSuffix(def().name())),
|
||||
1 /* num_threads */, false /* low_latency_hint */)) {}
|
||||
|
||||
// TODO(mrry): Convert this to an async op, because
|
||||
// `iterator->GetNext()` could trigger long-running operations
|
||||
// (e.g. a QueueDequeue or a remote read).
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
|
||||
IteratorResource* iterator;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
|
||||
core::ScopedUnref unref_iterator(iterator);
|
||||
|
||||
std::vector<Tensor> components;
|
||||
bool end_of_sequence;
|
||||
// The call to `iterator->GetNext()` may block and depend on an
|
||||
// inter-op thread pool thread, so we issue the call from the
|
||||
// owned thread pool.
|
||||
thread_pool_->Schedule([this, ctx, iterator, done]() {
|
||||
core::ScopedUnref unref_iterator(iterator);
|
||||
|
||||
IteratorContext::Params params;
|
||||
params.env = ctx->env();
|
||||
params.step_id = ctx->step_id();
|
||||
params.resource_manager = ctx->resource_manager();
|
||||
params.runner = *(ctx->runner());
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
std::vector<Tensor> components;
|
||||
bool end_of_sequence;
|
||||
|
||||
OP_REQUIRES_OK(ctx,
|
||||
iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
|
||||
OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence"));
|
||||
IteratorContext::Params params;
|
||||
params.env = ctx->env();
|
||||
params.step_id = ctx->step_id();
|
||||
params.resource_manager = ctx->resource_manager();
|
||||
params.runner = *(ctx->runner());
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
|
||||
for (int i = 0; i < components.size(); ++i) {
|
||||
// TODO(mrry): Check that the shapes match the shape attrs.
|
||||
ctx->set_output(i, components[i]);
|
||||
}
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
|
||||
done);
|
||||
OP_REQUIRES_ASYNC(ctx, !end_of_sequence,
|
||||
errors::OutOfRange("End of sequence"), done);
|
||||
|
||||
for (int i = 0; i < components.size(); ++i) {
|
||||
// TODO(mrry): Check that the shapes match the shape attrs.
|
||||
ctx->set_output(i, components[i]);
|
||||
}
|
||||
|
||||
done();
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<thread::ThreadPool> thread_pool_;
|
||||
};
|
||||
|
||||
class IteratorDisposeOp : public OpKernel {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user