Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Summary: Previously we didn't propagate the 'out-of-data' signal if splits_per_epoch wasn't specified. For now this is a hacky fix (it just reuses ReaderWithLimit). azzolini - any suggestions for a more elegant solution? I could create an extra reader that just exports an "is empty" signal. Overall, I think we need to turn global_queue into a more sustainable unittest that verifies all possible combinations - I'm still not sure it's correct :-\

Reviewed By: xianjiec

Differential Revision: D4665677

fbshipit-source-id: fe44d10ee82c3383145635e67dea1d9b666e061f
66 lines
2.5 KiB
Python
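As a minimal sketch of the behavior this diff fixes (not part of the diff itself): even with num_iter=None, wrapping a reader in ReaderWithLimit exposes a data_finished() blob that flips once the source runs out, so the 'out-of-data' signal propagates even without splits_per_epoch. The names src_ds, dst_ds, ws and session are assumed to be set up exactly as in the test file below.

from caffe2.python.dataio import ReaderWithLimit
from caffe2.python.pipeline import pipe
from caffe2.python.task import TaskGroup

# Sketch only: src_ds, dst_ds, ws and session are assumed to be built
# as in the test body below.
with TaskGroup() as tg:
    # num_iter=None imposes no read limit, but the wrapper still exports
    # the out-of-data signal once the underlying reader is exhausted.
    reader = ReaderWithLimit(src_ds.reader(), num_iter=None)
    pipe(reader, dst_ds.writer(), num_threads=8)
session.run(tg)

# With this diff the signal is propagated even though no limit was given.
assert ws.blobs[str(reader.data_finished())].fetch()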
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python.dataio import ReaderWithLimit
from caffe2.python.dataset import Dataset
from caffe2.python.pipeline import pipe
from caffe2.python.schema import Struct, NewRecord, FeedRecord
from caffe2.python.session import LocalSession
from caffe2.python.task import TaskGroup
from caffe2.python.test_util import TestCase
from caffe2.python import core, workspace

import numpy as np


class TestReaderWithLimit(TestCase):
    def test_reader_with_limit(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)

        """ 1. feed full dataset """
        src_init = core.Net('src_init')
        with core.NameScope('src'):
            src_values = Struct(('label', np.array(range(100))))
            src_blobs = NewRecord(src_init, src_values)
            src_ds = Dataset(src_blobs)
        FeedRecord(src_blobs, src_values, ws)
        ws.run(src_init)

        """ 2. Read with limit smaller than size of dataset """
        dst_init = core.Net('dst_init')
        with core.NameScope('dst'):
            dst_ds = Dataset(src_values.clone_schema())
            dst_ds.init_empty(dst_init)
        ws.run(dst_init)

        with TaskGroup() as tg:
            reader = ReaderWithLimit(src_ds.reader(), num_iter=10)
            pipe(reader, dst_ds.writer(), num_threads=8)
        session.run(tg)

        # Only 10 of the 100 rows were read, so the source is not exhausted
        # and the out-of-data signal must still be False.
        self.assertFalse(ws.blobs[str(reader.data_finished())].fetch())
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(10)))

        """ 3. Read with limit larger than size of dataset """
        ws.run(dst_init)
        with TaskGroup() as tg:
            reader = ReaderWithLimit(src_ds.reader(), num_iter=110)
            pipe(reader, dst_ds.writer(), num_threads=8)
        session.run(tg)
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(100)))
        # The dataset ran out before the limit was hit, so the signal is set.
        self.assertTrue(ws.blobs[str(reader.data_finished())].fetch())

        """ 4. Read without counter """
        ws.run(dst_init)
        with TaskGroup() as tg:
            reader = ReaderWithLimit(src_ds.reader(), num_iter=None)
            pipe(reader, dst_ds.writer(), num_threads=8)
        session.run(tg)
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(100)))
        # Even without a limit, exhaustion of the source is still reported.
        self.assertTrue(ws.blobs[str(reader.data_finished())].fetch())