diff --git a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc index d7c04962b85..adcdcea9ff5 100644 --- a/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc @@ -201,8 +201,6 @@ class SnapshotDatasetV2Op::Dataset::Iterator::Reader explicit Reader(const Params& params, int64 start_index); - ~Reader() override; - Status Initialize(IteratorContext* ctx) override; Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, @@ -222,7 +220,7 @@ class SnapshotDatasetV2Op::Dataset::Iterator::Reader std::unique_ptr input_impl_ TF_GUARDED_BY(mu_); - DatasetBase* input_ TF_GUARDED_BY(mu_); + DatasetBase* input_ TF_GUARDED_BY(mu_) = nullptr; std::unique_ptr instantiated_reader_func_ TF_GUARDED_BY(mu_); @@ -468,7 +466,11 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::GetNextInternal( bool* end_of_sequence) { mutex_lock l(mu_); if (iterator_ == nullptr) { - TF_RETURN_IF_ERROR(InitializeIterator(ctx, nullptr)); + Status s = InitializeIterator(ctx, nullptr); + if (!s.ok()) { + iterator_.reset(); + return s; + } } index_++; return iterator_->GetNext(ctx, out_tensors, end_of_sequence); @@ -547,8 +549,6 @@ SnapshotDatasetV2Op::Dataset::Iterator::Reader::Reader(const Params& params, int64 start_index) : DatasetIterator(params), start_index_(start_index) {} -SnapshotDatasetV2Op::Dataset::Iterator::Reader::~Reader() { input_->Unref(); } - Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize( IteratorContext* ctx) { mutex_lock l(mu_); @@ -597,10 +597,6 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize( } TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(reader_output[0], &input_)); - // We need to take a reference here as we will use the input_ and - // its iterator. - input_->Ref(); - return input_->MakeIterator(ctx, this, prefix(), &input_impl_); } diff --git a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py index fe6db1eb860..f720966df99 100644 --- a/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/snapshot_test.py @@ -413,6 +413,19 @@ class SnapshotTest(tf_record_test_base.TFRecordTestBase, num_runs_per_fingerprint=1, num_snapshot_shards_per_run=multiprocessing.cpu_count()) + @combinations.generate(test_base.default_test_combinations()) + def testRepeatAndPrefetch(self): + """This test reproduces github.com/tensorflow/tensorflow/issues/48903.""" + dataset = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32)) + dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir)) + dataset = dataset.shuffle(buffer_size=16) + dataset = dataset.batch(16) + dataset = dataset.repeat() + dataset = dataset.prefetch(1) + next_element = self.getNext(dataset) + for _ in range(30): + self.evaluate(next_element()) + class LegacySnapshotTest(tf_record_test_base.TFRecordTestBase, parameterized.TestCase):