pytorch/caffe2/python/dataio_test.py
Xianjie Chen d0621a2449 NextScopedBlob with well-defined behavior and respect namescope
Summary:
Remove the use of `NextName` in the layer model helper, so that the same function returning a `model_helper` constructs an identical `Net` when run under the same NameScope.

`NextScopedBlob` should only take effect when there is a real name conflict; otherwise it returns a ScopedBlobReference.

This is critical for parameter blobs. In the long run, we need to be able to specify parameter blobs more explicitly (kennyhorror is working on this). This solution works in the short term for, e.g., two-tower sparse NN models.

Reviewed By: kennyhorror

Differential Revision: D4555423

fbshipit-source-id: 2c4b99a61392e5d51aa878f7346466a8f14be187
2017-02-16 17:16:36 -08:00

56 lines
2.1 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python.dataio import ReaderWithLimit
from caffe2.python.dataset import Dataset
from caffe2.python.pipeline import pipe
from caffe2.python.schema import Struct, NewRecord, FeedRecord
from caffe2.python.session import LocalSession
from caffe2.python.task import TaskGroup
from caffe2.python.test_util import TestCase
from caffe2.python import core, workspace
import numpy as np
class TestReaderWithLimit(TestCase):
    """Exercise ``ReaderWithLimit`` by piping a 100-record dataset through it."""

    def test_reader_with_limit(self):
        """Pipe a source dataset into a destination dataset under two limits.

        With ``num_iter=10`` the test expects the destination to hold the
        labels 0..9 and ``reader.data_finished()`` to fetch as false.  With
        ``num_iter=110`` (more than the 100 source records) it expects all
        labels 0..99 and ``reader.data_finished()`` to fetch as true.
        """
        ws = workspace.C.Workspace()
        session = LocalSession(ws)

        # 1. Feed the full dataset: labels 0..99 under the 'src' name scope.
        src_init = core.Net('src_init')
        with core.NameScope('src'):
            src_values = Struct(('label', np.array(range(100))))
            src_blobs = NewRecord(src_init, src_values)
            src_ds = Dataset(src_blobs)
            FeedRecord(src_blobs, src_values, ws)
        ws.run(src_init)

        # 2. Read with a limit smaller than the size of the dataset.
        dst_init = core.Net('dst_init')
        with core.NameScope('dst'):
            # The destination shares the source schema and starts out empty.
            dst_ds = Dataset(src_values.clone_schema())
            dst_ds.init_empty(dst_init)
        ws.run(dst_init)

        with TaskGroup() as tg:
            reader = ReaderWithLimit(src_ds.reader(), num_iter=10)
            pipe(reader, dst_ds.writer(), num_threads=8)
        session.run(tg)
        self.assertFalse(ws.blobs[str(reader.data_finished())].fetch())
        # The fetched labels are sorted before comparing; presumably the
        # 8 pipe threads do not guarantee write order -- TODO confirm.
        # `sorted` returns a list, so compare against a list: on Python 3 a
        # bare `range` object never compares equal to a list.
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(10)))

        # 3. Read with a limit larger than the size of the dataset.
        # NOTE(review): re-running dst_init presumably resets the destination
        # to empty before the second pass -- confirm against init_empty.
        ws.run(dst_init)
        with TaskGroup() as tg:
            reader = ReaderWithLimit(src_ds.reader(), num_iter=110)
            pipe(reader, dst_ds.writer(), num_threads=8)
        session.run(tg)
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(100)))
        self.assertTrue(ws.blobs[str(reader.data_finished())].fetch())