Merge pull request #49124 from tensorflow/mm-cherrypick-tf-data-segfault-fix-to-r2.5

[tf.data][cherrypick] Fix snapshot segfault when using repeat and prefetch
This commit is contained in:
Mihai Maruseac 2021-05-12 06:26:41 -07:00 committed by GitHub
commit a4dfb8d1a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 10 deletions

View File

@ -201,8 +201,6 @@ class SnapshotDatasetV2Op::Dataset::Iterator::Reader
explicit Reader(const Params& params, int64 start_index); explicit Reader(const Params& params, int64 start_index);
~Reader() override;
Status Initialize(IteratorContext* ctx) override; Status Initialize(IteratorContext* ctx) override;
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors, Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
@ -222,7 +220,7 @@ class SnapshotDatasetV2Op::Dataset::Iterator::Reader
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_); std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
DatasetBase* input_ TF_GUARDED_BY(mu_); DatasetBase* input_ TF_GUARDED_BY(mu_) = nullptr;
std::unique_ptr<InstantiatedCapturedFunction> instantiated_reader_func_ std::unique_ptr<InstantiatedCapturedFunction> instantiated_reader_func_
TF_GUARDED_BY(mu_); TF_GUARDED_BY(mu_);
@ -468,7 +466,11 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::GetNextInternal(
bool* end_of_sequence) { bool* end_of_sequence) {
mutex_lock l(mu_); mutex_lock l(mu_);
if (iterator_ == nullptr) { if (iterator_ == nullptr) {
TF_RETURN_IF_ERROR(InitializeIterator(ctx, nullptr)); Status s = InitializeIterator(ctx, nullptr);
if (!s.ok()) {
iterator_.reset();
return s;
}
} }
index_++; index_++;
return iterator_->GetNext(ctx, out_tensors, end_of_sequence); return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
@ -547,8 +549,6 @@ SnapshotDatasetV2Op::Dataset::Iterator::Reader::Reader(const Params& params,
int64 start_index) int64 start_index)
: DatasetIterator<Dataset>(params), start_index_(start_index) {} : DatasetIterator<Dataset>(params), start_index_(start_index) {}
SnapshotDatasetV2Op::Dataset::Iterator::Reader::~Reader() { input_->Unref(); }
Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize( Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
IteratorContext* ctx) { IteratorContext* ctx) {
mutex_lock l(mu_); mutex_lock l(mu_);
@ -597,10 +597,6 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
} }
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(reader_output[0], &input_)); TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(reader_output[0], &input_));
// We need to take a reference here as we will use the input_ and
// its iterator.
input_->Ref();
return input_->MakeIterator(ctx, this, prefix(), &input_impl_); return input_->MakeIterator(ctx, this, prefix(), &input_impl_);
} }

View File

@ -413,6 +413,19 @@ class SnapshotTest(tf_record_test_base.TFRecordTestBase,
num_runs_per_fingerprint=1, num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=multiprocessing.cpu_count()) num_snapshot_shards_per_run=multiprocessing.cpu_count())
@combinations.generate(test_base.default_test_combinations())
def testRepeatAndPrefetch(self):
"""This test reproduces github.com/tensorflow/tensorflow/issues/48903."""
dataset = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32))
dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
dataset = dataset.shuffle(buffer_size=16)
dataset = dataset.batch(16)
dataset = dataset.repeat()
dataset = dataset.prefetch(1)
next_element = self.getNext(dataset)
for _ in range(30):
self.evaluate(next_element())
class LegacySnapshotTest(tf_record_test_base.TFRecordTestBase, class LegacySnapshotTest(tf_record_test_base.TFRecordTestBase,
parameterized.TestCase): parameterized.TestCase):