mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[tf.data][cherrypick] Fix snapshot segfault when using repeat and prefecth
Similar to #49121 on `r2.4`. Needed manual cherrypick due to refactoring after `r2.5` branch cut
This commit is contained in:
parent
7dee4eb1c0
commit
a13c0ade86
|
|
@ -222,7 +222,7 @@ class SnapshotDatasetV2Op::Dataset::Iterator::Reader
|
|||
|
||||
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
|
||||
|
||||
DatasetBase* input_ TF_GUARDED_BY(mu_);
|
||||
DatasetBase* input_ TF_GUARDED_BY(mu_) = nullptr;
|
||||
|
||||
std::unique_ptr<InstantiatedCapturedFunction> instantiated_reader_func_
|
||||
TF_GUARDED_BY(mu_);
|
||||
|
|
@ -468,7 +468,11 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::GetNextInternal(
|
|||
bool* end_of_sequence) {
|
||||
mutex_lock l(mu_);
|
||||
if (iterator_ == nullptr) {
|
||||
TF_RETURN_IF_ERROR(InitializeIterator(ctx, nullptr));
|
||||
Status s = InitializeIterator(ctx, nullptr);
|
||||
if (!s.ok()) {
|
||||
iterator_.reset();
|
||||
return s;
|
||||
}
|
||||
}
|
||||
index_++;
|
||||
return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
|
||||
|
|
@ -547,8 +551,6 @@ SnapshotDatasetV2Op::Dataset::Iterator::Reader::Reader(const Params& params,
|
|||
int64 start_index)
|
||||
: DatasetIterator<Dataset>(params), start_index_(start_index) {}
|
||||
|
||||
SnapshotDatasetV2Op::Dataset::Iterator::Reader::~Reader() { input_->Unref(); }
|
||||
|
||||
Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
|
||||
IteratorContext* ctx) {
|
||||
mutex_lock l(mu_);
|
||||
|
|
@ -597,10 +599,6 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
|
|||
}
|
||||
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(reader_output[0], &input_));
|
||||
|
||||
// We need to take a reference here as we will use the input_ and
|
||||
// its iterator.
|
||||
input_->Ref();
|
||||
|
||||
return input_->MakeIterator(ctx, this, prefix(), &input_impl_);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -413,6 +413,19 @@ class SnapshotTest(tf_record_test_base.TFRecordTestBase,
|
|||
num_runs_per_fingerprint=1,
|
||||
num_snapshot_shards_per_run=multiprocessing.cpu_count())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testRepeatAndPrefetch(self):
|
||||
"""This test reproduces github.com/tensorflow/tensorflow/issues/48903."""
|
||||
dataset = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32))
|
||||
dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
|
||||
dataset = dataset.shuffle(buffer_size=16)
|
||||
dataset = dataset.batch(16)
|
||||
dataset = dataset.repeat()
|
||||
dataset = dataset.prefetch(1)
|
||||
next_element = self.getNext(dataset)
|
||||
for _ in range(30):
|
||||
self.evaluate(next_element())
|
||||
|
||||
|
||||
class LegacySnapshotTest(tf_record_test_base.TFRecordTestBase,
|
||||
parameterized.TestCase):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user