pytorch/caffe2/python/dataio_test.py
Artem Volkhin 4102a79da4 Explicitly set should_stop_blob to False in pipeline init
Summary: This diff fixes an issue with running the same reader in the same workspace multiple times. In order to achieve correct behavior of the execution step, we have to explicitly initialize should_stop_blob to False.

Reviewed By: kennyhorror

Differential Revision: D5224410

fbshipit-source-id: 4ad2740e187b62b0a1f5612ea3eef223dcc8a799
2017-06-11 02:33:42 -07:00

79 lines
2.8 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python.dataio import ReaderWithLimit
from caffe2.python.dataset import Dataset
from caffe2.python.pipeline import pipe
from caffe2.python.schema import Struct, NewRecord, FeedRecord
from caffe2.python.session import LocalSession
from caffe2.python.task import TaskGroup
from caffe2.python.test_util import TestCase
from caffe2.python import core, workspace
import numpy as np
class TestReaderWithLimit(TestCase):
    """Tests for ReaderWithLimit: reading with a smaller/larger/absent
    iteration limit, and re-running the same reader in the same workspace."""

    def test_reader_with_limit(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)

        # 1. Feed the full source dataset (labels 0..99) into the workspace.
        src_init = core.Net('src_init')
        with core.NameScope('src'):
            src_values = Struct(('label', np.array(list(range(100)))))
            src_blobs = NewRecord(src_init, src_values)
            src_ds = Dataset(src_blobs)
            FeedRecord(src_blobs, src_values, ws)
        ws.run(src_init)

        # 2. Read with a limit smaller than the size of the dataset:
        # only the first `num_iter` rows should arrive, and the reader
        # must NOT report that the underlying data was exhausted.
        dst_init = core.Net('dst_init')
        with core.NameScope('dst'):
            dst_ds = Dataset(src_values.clone_schema())
            dst_ds.init_empty(dst_init)
        ws.run(dst_init)

        with TaskGroup() as tg:
            reader = ReaderWithLimit(src_ds.reader(), num_iter=10)
            pipe(reader, dst_ds.writer(), num_threads=8)
        session.run(tg)
        self.assertFalse(ws.blobs[str(reader.data_finished())].fetch())
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(10))
        )

        # 3. Read with a limit larger than the size of the dataset:
        # the whole dataset is read and data_finished() becomes True.
        ws.run(dst_init)  # reset the destination dataset
        with TaskGroup() as tg:
            reader = ReaderWithLimit(src_ds.reader(), num_iter=110)
            pipe(reader, dst_ds.writer(), num_threads=8)
        session.run(tg)
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(100))
        )
        self.assertTrue(ws.blobs[str(reader.data_finished())].fetch())

        # 4. Read without a counter (num_iter=None): reads to exhaustion.
        ws.run(dst_init)  # reset the destination dataset
        with TaskGroup() as tg:
            reader = ReaderWithLimit(src_ds.reader(), num_iter=None)
            pipe(reader, dst_ds.writer(), num_threads=8)
        session.run(tg)
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(100))
        )
        self.assertTrue(ws.blobs[str(reader.data_finished())].fetch())

        # 5. Re-run the same task group without resetting the workspace:
        # the reader must restart cleanly (should_stop_blob re-initialized
        # to False), so the destination accumulates a second full copy.
        session.run(tg)
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            sorted(list(range(100)) * 2)
        )