mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
[tf.contrib.data] Re-implement IteratorGetNext as an AsyncOpKernel.
This prevents the op from consuming an inter-op thread pool thread when blocked, and fixes a potential deadlock when many IteratorGetNext ops are blocked.

Fixes #10369.

PiperOrigin-RevId: 157878885
This commit is contained in:
parent
9e25c68ad1
commit
8939b85620
|
|
@ -150,7 +150,8 @@ class MapDatasetTest(test.TestCase):
|
||||||
results.append(sess.run(get_next))
|
results.append(sess.run(get_next))
|
||||||
except errors.OutOfRangeError:
|
except errors.OutOfRangeError:
|
||||||
return
|
return
|
||||||
threads = [self.checkedThread(target=iterator_thread) for _ in range(8)]
|
threads = [self.checkedThread(target=iterator_thread)
|
||||||
|
for _ in range(64)]
|
||||||
for t in threads:
|
for t in threads:
|
||||||
t.start()
|
t.start()
|
||||||
for t in threads:
|
for t in threads:
|
||||||
|
|
|
||||||
|
|
@ -5387,6 +5387,7 @@ tf_kernel_library(
|
||||||
srcs = ["iterator_ops.cc"],
|
srcs = ["iterator_ops.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":dataset",
|
":dataset",
|
||||||
|
":ops_util",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
"//tensorflow/core:dataset_ops_op_lib",
|
"//tensorflow/core:dataset_ops_op_lib",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,10 @@ limitations under the License.
|
||||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/resource_op_kernel.h"
|
#include "tensorflow/core/framework/resource_op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/kernels/ops_util.h"
|
||||||
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
#include "tensorflow/core/lib/random/random.h"
|
#include "tensorflow/core/lib/random/random.h"
|
||||||
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
|
@ -282,17 +285,25 @@ class OneShotIteratorOp : public OpKernel {
|
||||||
IteratorResource* iterator_resource_ = nullptr;
|
IteratorResource* iterator_resource_ = nullptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
class IteratorGetNextOp : public OpKernel {
|
class IteratorGetNextOp : public AsyncOpKernel {
|
||||||
public:
|
public:
|
||||||
explicit IteratorGetNextOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
explicit IteratorGetNextOp(OpKernelConstruction* ctx)
|
||||||
|
: AsyncOpKernel(ctx),
|
||||||
|
thread_pool_(new thread::ThreadPool(
|
||||||
|
ctx->env(), ThreadOptions(),
|
||||||
|
strings::StrCat("iterator_get_next_thread_",
|
||||||
|
SanitizeThreadSuffix(def().name())),
|
||||||
|
1 /* num_threads */, false /* low_latency_hint */)) {}
|
||||||
|
|
||||||
// TODO(mrry): Convert this to an async op, because
|
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
|
||||||
// `iterator->GetNext()` could trigger long-running operations
|
|
||||||
// (e.g. a QueueDequeue or a remote read).
|
|
||||||
void Compute(OpKernelContext* ctx) override {
|
|
||||||
IteratorResource* iterator;
|
IteratorResource* iterator;
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK(ctx,
|
||||||
LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
|
LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
|
||||||
|
|
||||||
|
// The call to `iterator->GetNext()` may block and depend on an
|
||||||
|
// inter-op thread pool thread, so we issue the call from the
|
||||||
|
// owned thread pool.
|
||||||
|
thread_pool_->Schedule([this, ctx, iterator, done]() {
|
||||||
core::ScopedUnref unref_iterator(iterator);
|
core::ScopedUnref unref_iterator(iterator);
|
||||||
|
|
||||||
std::vector<Tensor> components;
|
std::vector<Tensor> components;
|
||||||
|
|
@ -305,15 +316,23 @@ class IteratorGetNextOp : public OpKernel {
|
||||||
params.runner = *(ctx->runner());
|
params.runner = *(ctx->runner());
|
||||||
IteratorContext iter_ctx(std::move(params));
|
IteratorContext iter_ctx(std::move(params));
|
||||||
|
|
||||||
OP_REQUIRES_OK(ctx,
|
OP_REQUIRES_OK_ASYNC(
|
||||||
iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
|
ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
|
||||||
OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence"));
|
done);
|
||||||
|
OP_REQUIRES_ASYNC(ctx, !end_of_sequence,
|
||||||
|
errors::OutOfRange("End of sequence"), done);
|
||||||
|
|
||||||
for (int i = 0; i < components.size(); ++i) {
|
for (int i = 0; i < components.size(); ++i) {
|
||||||
// TODO(mrry): Check that the shapes match the shape attrs.
|
// TODO(mrry): Check that the shapes match the shape attrs.
|
||||||
ctx->set_output(i, components[i]);
|
ctx->set_output(i, components[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
done();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<thread::ThreadPool> thread_pool_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class IteratorDisposeOp : public OpKernel {
|
class IteratorDisposeOp : public OpKernel {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user