pytorch/caffe2/python/data_workers_test.py
Aapo Kyrola 0b52b3c79d Generalize threaded data input via queues + Everstore input
Summary:
Xray sampler (originally by ajtulloch) and prigoyal's resnet trainer use variants of the same threaded data input. Worker threads put batches into a Python queue. An enqueuer thread drains that queue and pushes the batches into a Caffe2 queue, and the net's DequeueBlobs operator drains the Caffe2 queue.
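
A minimal Python 3 sketch of the pattern described above, using only the standard library. Here `queue.Queue` stands in for both the Python queue and the Caffe2 queue, and a plain loop stands in for DequeueBlobs. `run_pipeline` and `_put_until_stopped` are hypothetical names and are not part of caffe2.

```python
import queue
import threading


def _put_until_stopped(q, item, stop):
    """Put `item` on `q`, giving up if `stop` is set while `q` stays full."""
    while not stop.is_set():
        try:
            q.put(item, timeout=0.1)
            return True
        except queue.Full:
            pass
    return False


def run_pipeline(fetch, num_workers, num_batches, capacity=8):
    """Worker threads -> Python queue -> enqueuer thread -> consumer.

    `fetch(worker_id)` produces one batch. `net_queue` is only a stand-in
    for the Caffe2 queue, and the loop at the end stands in for the
    DequeueBlobs operator.
    """
    python_queue = queue.Queue(maxsize=capacity)
    net_queue = queue.Queue(maxsize=capacity)
    stop = threading.Event()

    def worker(worker_id):
        # Each worker keeps fetching batches until told to stop.
        while not stop.is_set():
            if not _put_until_stopped(python_queue, fetch(worker_id), stop):
                return

    def enqueuer():
        # Drain the Python queue and forward each batch to the "net" queue.
        while not stop.is_set():
            try:
                batch = python_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            _put_until_stopped(net_queue, batch, stop)

    threads = [threading.Thread(target=worker, args=(i,))
               for i in range(num_workers)]
    threads.append(threading.Thread(target=enqueuer))
    for t in threads:
        t.daemon = True
        t.start()

    # Consumer side: take one batch per "net run", blocking until available.
    consumed = [net_queue.get() for _ in range(num_batches)]

    # Shut down: every blocking put/get in the threads uses a timeout,
    # so they all notice the stop event and exit.
    stop.set()
    for t in threads:
        t.join()
    return consumed
```

Bounded queues provide backpressure: when the consumer falls behind, workers wait on `put` instead of buffering without limit. The short timeouts exist only so that threads notice the stop event.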

There is a lot of boilerplate, which is also quite complicated.

This diff attempts to factor that common machinery out into a new module, "data_workers" (the name could be improved). You pass it a function that returns chunks of data (usually data + labels).
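
The test in this file shows one job the module has to do: `dummy_fetcher` returns chunks of 1 to 64 rows, yet every batch the net reads has exactly 32 rows. One way to get that behavior is to buffer incoming rows and cut fixed-size batches from the front. The sketch below assumes a chunk is a list of NumPy arrays that share a first dimension. `Rebatcher` is a hypothetical helper, not the data_workers implementation.

```python
import numpy as np


class Rebatcher(object):
    """Hypothetical helper: turns variable-sized chunks into fixed-size batches.

    Each chunk is a list of arrays (e.g. [data, labels]) with the same
    number of rows; rows of all arrays stay aligned.
    """

    def __init__(self, batch_size):
        self.batch_size = batch_size
        self._pending = None  # arrays holding rows not yet emitted

    def add_chunk(self, chunk):
        """Append a chunk and return every complete batch now available."""
        if self._pending is None:
            self._pending = [np.asarray(a) for a in chunk]
        else:
            self._pending = [np.concatenate([p, np.asarray(a)])
                             for p, a in zip(self._pending, chunk)]
        batches = []
        # Cut as many full batches as the buffer allows; keep the remainder.
        while self._pending[0].shape[0] >= self.batch_size:
            batches.append([p[:self.batch_size] for p in self._pending])
            self._pending = [p[self.batch_size:] for p in self._pending]
        return batches
```

Rows left in the buffer wait for the next chunk, so nothing is dropped. Concatenating on every chunk is simple but copies data, which a production implementation might avoid.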

I also created a module, 'everstore_data_input', that generalizes input of data originating in Everstore and accepts a preprocessing function (image augmentation, for example). See the refactored sampler.py for usage.
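
Everstore is Facebook-internal storage, so this sketch leaves the storage side out. It only shows how a preprocessing function might be attached to a fetcher with the `(fetcher_id, batch_size) -> [data, labels]` shape that `dummy_fetcher` uses in this file. `with_preprocessing` and `random_horizontal_flip` are hypothetical names, not the everstore_data_input API, and the NCHW image layout is an assumption.

```python
import numpy as np


def with_preprocessing(fetcher, preprocess):
    """Hypothetical wrapper: apply `preprocess` to the data of each chunk.

    `fetcher(fetcher_id, batch_size)` returns [data, labels]; the labels
    pass through unchanged.
    """
    def wrapped(fetcher_id, batch_size):
        data, labels = fetcher(fetcher_id, batch_size)
        return [preprocess(data), labels]
    return wrapped


def random_horizontal_flip(images, rng=np.random):
    """Example augmentation: flip each NCHW image left-right with prob. 0.5."""
    flip = rng.rand(images.shape[0]) < 0.5
    out = images.copy()
    # Reverse the width axis only for the selected images.
    out[flip] = out[flip, :, :, ::-1]
    return out
```

Keeping the augmentation a plain function of the data array means the same fetcher can be reused with or without preprocessing.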

Next, we could create a fetcher function for Laser data.

Differential Revision: D4297667

fbshipit-source-id: 8d8a863b177784ae13940730a27dc76cd1dd3dac
2016-12-15 12:01:30 -08:00

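from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np
import unittest

from caffe2.python import workspace, cnn
from caffe2.python import timeout_guard
import caffe2.python.data_workers as data_workers


def dummy_fetcher(fetcher_id, batch_size):
    # Return a chunk with a random number of rows (1 to 64); batch_size is
    # not used, so the chunk size rarely matches the requested batch size.
    n = np.random.randint(64) + 1
    # Start from ones so that scaling each row below yields non-zero,
    # row-specific values.
    data = np.ones((n, 3))
    labels = []
    for j in range(n):
        # All three columns of row j equal j + fetcher_id; the label copies
        # that value so the test can check rows and labels stay aligned.
        data[j, :] *= (j + fetcher_id)
        labels.append(data[j, 0])
    return [data, np.array(labels)]


class DataWorkersTest(unittest.TestCase):
    def testNonParallelModel(self):
        model = cnn.CNNModelHelper(name="test")
        coordinator = data_workers.init_data_input_workers(
            model,
            ["data", "label"],
            dummy_fetcher,
            32,  # batch size
            2,   # number of worker threads
        )
        coordinator.start()
        try:
            workspace.RunNetOnce(model.param_init_net)
            workspace.CreateNet(model.net)
            for _ in range(500):
                # Abort if one run takes longer than 5 seconds, e.g. when
                # the net blocks waiting on an empty input queue.
                with timeout_guard.CompleteInTimeOrDie(5):
                    workspace.RunNet(model.net.Proto().name)

                data = workspace.FetchBlob("data")
                labels = workspace.FetchBlob("label")

                # Every batch has exactly batch_size rows, even though the
                # fetcher returns chunks of arbitrary size.
                self.assertEqual(data.shape[0], labels.shape[0])
                self.assertEqual(data.shape[0], 32)

                # Each label must still match its own data row.
                for j in range(32):
                    self.assertEqual(labels[j], data[j, 0])
                    self.assertEqual(labels[j], data[j, 1])
                    self.assertEqual(labels[j], data[j, 2])
        finally:
            coordinator.stop()


if __name__ == "__main__":
    unittest.main()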