mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
bug fix to distinguish train/test data
Summary: We often use the same net for training and testing, but we must distinguish their data. My yesterday's diff forgot to include that distinction (it was in the xray sampler before), and this diff adds it. Basically one provides a name for the input source for data_workers, and all the queues and scratch spaces are suffixed with that to separate them. Also specify the caffe2 queue's size to 4, which is empirically found to be sufficient. It was erroneously defined to be a function of batch size, which does not make sense as each *element* in the queue is a batch, and led to out-of-memory issues on the xray trainer. Differential Revision: D4329449 fbshipit-source-id: c994da1c8b0935b8eda2402c118d49b76caa7da8
This commit is contained in:
parent
cb918ac727
commit
e80423f341
|
|
@ -13,7 +13,8 @@ Basic usage is as follows:
|
|||
net,
|
||||
["data", "label"],
|
||||
my_fetch_fun,
|
||||
32
|
||||
32,
|
||||
"train"
|
||||
)
|
||||
...
|
||||
coordinator.start()
|
||||
|
|
@ -21,6 +22,10 @@ Basic usage is as follows:
|
|||
First argument is the Caffe2 net (or model helper), and second argument
|
||||
is list of input blobs that are to be fed.
|
||||
|
||||
Last argument is used to distinguish different sources of data, such as train
|
||||
or test data. This is to ensure the data does not get mixed up, although the
|
||||
nets would share blobs.
|
||||
|
||||
To do the actual data loading, one defines a "fetcher function"
|
||||
that has call signature
|
||||
my_fetch_fun(worker_id, batch_size)
|
||||
|
|
@ -58,7 +63,8 @@ def init_data_input_workers(
|
|||
input_blob_names,
|
||||
fetch_fun,
|
||||
batch_size,
|
||||
num_worker_threads=2
|
||||
num_worker_threads=2,
|
||||
input_source_name="train",
|
||||
):
|
||||
global global_coordinator
|
||||
device_option = scope.CurrentDeviceScope()
|
||||
|
|
@ -72,6 +78,7 @@ def init_data_input_workers(
|
|||
batch_size,
|
||||
device_option,
|
||||
scope.CurrentNameScope(),
|
||||
input_source_name,
|
||||
)
|
||||
|
||||
# Launch fetch worker threads
|
||||
|
|
@ -92,7 +99,7 @@ def init_data_input_workers(
|
|||
|
||||
class DataInputCoordinator(object):
|
||||
def __init__(self, net, input_blob_names, batch_size,
|
||||
device_option, namescope):
|
||||
device_option, namescope, input_source_name):
|
||||
self._net = net
|
||||
self._input_blob_names = input_blob_names
|
||||
self._batch_size = batch_size
|
||||
|
|
@ -102,7 +109,7 @@ class DataInputCoordinator(object):
|
|||
self._namescope = namescope
|
||||
self._active = True
|
||||
self._workers = []
|
||||
|
||||
self._input_source_name = input_source_name
|
||||
self._create_caffe2_queues_and_ops()
|
||||
|
||||
def is_active(self):
|
||||
|
|
@ -181,7 +188,8 @@ class DataInputCoordinator(object):
|
|||
'''
|
||||
Enqueue the correctly sized batch arrays to Caffe2's queue.
|
||||
'''
|
||||
scratch_name = self._namescope + blob_name + "_scratch"
|
||||
scratch_name = self._namescope + blob_name + \
|
||||
"_scratch_" + self._input_source_name
|
||||
blob = core.BlobReference(scratch_name)
|
||||
workspace.FeedBlob(
|
||||
blob,
|
||||
|
|
@ -212,8 +220,8 @@ class DataInputCoordinator(object):
|
|||
return core.ScopedBlobReference(queue_name)
|
||||
|
||||
for blob_name in self._input_blob_names:
|
||||
qname = blob_name + "_c2queue"
|
||||
q = create_queue(qname, num_blobs=1, capacity=self._batch_size * 2)
|
||||
qname = blob_name + "_c2queue" + "_" + self._input_source_name
|
||||
q = create_queue(qname, num_blobs=1, capacity=4)
|
||||
self._queues.append(q)
|
||||
log.info("Created queue: {}".format(q))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user