bug fix to distinguish train/test data

Summary:
We often use the same net for training and testing, but we must keep their data separate. Yesterday's diff forgot to include that distinction (it was previously handled in the xray sampler), and this diff adds it. Basically, one provides a name for the input source to data_workers, and all the queues and scratch spaces are suffixed with that name to keep the sources apart.
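
For example, a minimal sketch of the resulting usage (net, my_train_fetcher, and my_test_fetcher are hypothetical; init_data_input_workers and the input_source_name argument are what this diff adds):

    from caffe2.python import data_workers

    # The same net is fed from two sources; the "train"/"test" names keep
    # their caffe2 queues and scratch blobs separate.
    train_coord = data_workers.init_data_input_workers(
        net, ["data", "label"], my_train_fetcher, batch_size=32,
        input_source_name="train",
    )
    test_coord = data_workers.init_data_input_workers(
        net, ["data", "label"], my_test_fetcher, batch_size=32,
        input_source_name="test",
    )
    train_coord.start()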

Also set the caffe2 queue's capacity to 4, which was empirically found to be sufficient. It was erroneously defined as a function of the batch size, which does not make sense because each *element* in the queue is itself a batch, and this led to out-of-memory issues on the xray trainer.
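
For a sense of scale (illustrative numbers, not from this diff; assuming 3 x 227 x 227 float32 inputs and batch_size = 32):

    one queued batch ≈ 32 * 3 * 227 * 227 * 4 bytes ≈ 19.8 MB
    old capacity: batch_size * 2 = 64 batches ≈ 1.27 GB per input blob
    new capacity: 4 batches ≈ 79 MB per input blob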

Differential Revision: D4329449

fbshipit-source-id: c994da1c8b0935b8eda2402c118d49b76caa7da8
Aapo Kyrola 2016-12-15 09:30:47 -08:00 committed by Bram Wasti
parent cb918ac727
commit e80423f341


@@ -13,7 +13,8 @@ Basic usage is as follows:
       net,
       ["data", "label"],
       my_fetch_fun,
-      32
+      32,
+      input_source_name="train"
    )
    ...
    coordinator.start()
@@ -21,6 +22,10 @@ Basic usage is as follows:
 First argument is the Caffe2 net (or model helper), and second argument
 is list of input blobs that are to be fed.
 
+Last argument is used to distinguish different sources of data, such as train
+or test data. This is to ensure the data does not get mixed up, although the
+nets would share blobs.
+
 To do the actual data loading, one defines a "fetcher function"
 that has call signature
    my_fetch_fun(worker_id, batch_size)
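
As an illustration of the fetcher contract described above (a sketch only: the shapes, dtypes, and the convention that a fetcher returns one numpy array per input blob are assumptions, not shown in this hunk):

    import numpy as np

    def my_fetch_fun(worker_id, batch_size):
        # Hypothetical fetcher: produce one array per input blob
        # ("data" and "label"), each with batch_size rows.
        data = np.random.rand(batch_size, 3, 227, 227).astype(np.float32)
        labels = np.random.randint(0, 1000, size=batch_size).astype(np.int32)
        return [data, labels]
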
@@ -58,7 +63,8 @@ def init_data_input_workers(
     input_blob_names,
     fetch_fun,
     batch_size,
-    num_worker_threads=2
+    num_worker_threads=2,
+    input_source_name="train",
 ):
     global global_coordinator
     device_option = scope.CurrentDeviceScope()
@@ -72,6 +78,7 @@
         batch_size,
         device_option,
         scope.CurrentNameScope(),
+        input_source_name,
     )
 
     # Launch fetch worker threads
@@ -92,7 +99,7 @@
 
 class DataInputCoordinator(object):
     def __init__(self, net, input_blob_names, batch_size,
-                 device_option, namescope):
+                 device_option, namescope, input_source_name):
         self._net = net
         self._input_blob_names = input_blob_names
         self._batch_size = batch_size
@@ -102,7 +109,7 @@
         self._namescope = namescope
         self._active = True
         self._workers = []
+        self._input_source_name = input_source_name
 
         self._create_caffe2_queues_and_ops()
 
     def is_active(self):
@@ -181,7 +188,8 @@
         '''
         Enqueue the correctly sized batch arrays to Caffe2's queue.
         '''
-        scratch_name = self._namescope + blob_name + "_scratch"
+        scratch_name = self._namescope + blob_name + \
+            "_scratch_" + self._input_source_name
         blob = core.BlobReference(scratch_name)
         workspace.FeedBlob(
             blob,
@@ -212,8 +220,8 @@
             return core.ScopedBlobReference(queue_name)
 
         for blob_name in self._input_blob_names:
-            qname = blob_name + "_c2queue"
-            q = create_queue(qname, num_blobs=1, capacity=self._batch_size * 2)
+            qname = blob_name + "_c2queue" + "_" + self._input_source_name
+            q = create_queue(qname, num_blobs=1, capacity=4)
             self._queues.append(q)
             log.info("Created queue: {}".format(q))