[tf.data][cherrypick] Fix snapshot segfault when using repeat and prefecth

Similar to #49121 on `r2.4`. Needed manual cherrypick due to refactoring
after `r2.5` branch cut
This commit is contained in:
Mihai Maruseac 2021-05-11 17:58:31 -07:00
parent 7dee4eb1c0
commit a13c0ade86
2 changed files with 19 additions and 8 deletions

View File

@ -222,7 +222,7 @@ class SnapshotDatasetV2Op::Dataset::Iterator::Reader
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
DatasetBase* input_ TF_GUARDED_BY(mu_);
DatasetBase* input_ TF_GUARDED_BY(mu_) = nullptr;
std::unique_ptr<InstantiatedCapturedFunction> instantiated_reader_func_
TF_GUARDED_BY(mu_);
@ -468,7 +468,11 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::GetNextInternal(
bool* end_of_sequence) {
mutex_lock l(mu_);
if (iterator_ == nullptr) {
TF_RETURN_IF_ERROR(InitializeIterator(ctx, nullptr));
Status s = InitializeIterator(ctx, nullptr);
if (!s.ok()) {
iterator_.reset();
return s;
}
}
index_++;
return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
@ -547,8 +551,6 @@ SnapshotDatasetV2Op::Dataset::Iterator::Reader::Reader(const Params& params,
int64 start_index)
: DatasetIterator<Dataset>(params), start_index_(start_index) {}
SnapshotDatasetV2Op::Dataset::Iterator::Reader::~Reader() { input_->Unref(); }
Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
IteratorContext* ctx) {
mutex_lock l(mu_);
@ -597,10 +599,6 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
}
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(reader_output[0], &input_));
// We need to take a reference here as we will use the input_ and
// its iterator.
input_->Ref();
return input_->MakeIterator(ctx, this, prefix(), &input_impl_);
}

View File

@ -413,6 +413,19 @@ class SnapshotTest(tf_record_test_base.TFRecordTestBase,
num_runs_per_fingerprint=1,
num_snapshot_shards_per_run=multiprocessing.cpu_count())
@combinations.generate(test_base.default_test_combinations())
def testRepeatAndPrefetch(self):
"""This test reproduces github.com/tensorflow/tensorflow/issues/48903."""
dataset = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32))
dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
dataset = dataset.shuffle(buffer_size=16)
dataset = dataset.batch(16)
dataset = dataset.repeat()
dataset = dataset.prefetch(1)
next_element = self.getNext(dataset)
for _ in range(30):
self.evaluate(next_element())
class LegacySnapshotTest(tf_record_test_base.TFRecordTestBase,
parameterized.TestCase):