## @package data_workers
# Module caffe2.python.data_workers
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

'''
This module provides a python-land multithreaded data input mechanism
for Caffe2 nets.

Basic usage is as follows:
   coordinator = data_workers.init_data_input_workers(
      net,
      ["data", "label"],
      my_fetch_fun,
      batch_size=32,
      input_source_name="train",
      dont_rebatch=False
   )
   ...
   coordinator.start()

The first argument is the Caffe2 net (or model helper), and the second
argument is the list of input blobs that are to be fed.

The argument 'input_source_name' is used to distinguish between different
sources of data, such as train or test data. This ensures that the data does
not get mixed up, even when the two nets share blobs.

To do the actual data loading, one defines a "fetcher function"
that has the call signature
   my_fetch_fun(worker_id, batch_size)

Optionally, one can define an "init function" that is called once before
the threads start, and has the call signature:
   my_init_fun(data_coordinator, global_coordinator)
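
For instance, an init function might perform one-time, per-source setup
(a sketch only; the helper 'list_files' and its path are hypothetical):

   def my_init_fun(data_coordinator, global_coordinator):
       # runs once, before any fetcher thread of this source starts
       data_coordinator._file_list = list_files("/path/to/images")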

If dont_rebatch is set to True, the data input is not batched into equal
sized chunks; instead, the data provided directly by the fetchers is used.

'batch_columns' can be used to specify which dimension is the batch dimension,
for each of the inputs. Default is 0 for all inputs.
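
For example, if an input blob is laid out as (time, batch, features), its
batch dimension is 1 (a sketch; the blob name "seq_data" is hypothetical):

   coordinator = data_workers.init_data_input_workers(
      net,
      ["seq_data", "label"],
      my_fetch_fun,
      batch_size=32,
      batch_columns=[1, 0],  # batch dim of "seq_data" is axis 1
   )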

The fetcher function returns a list of numpy arrays corresponding to the
different input blobs. In the example above, it would return two arrays, one
for the data blob and another for the labels. These arrays can have an
arbitrary number of elements (i.e. they do not need to match the batch size).
The batch size is provided to the function as a hint only.

For example, a fetcher function could download images from a remote service
or load random images from a directory on a file system.
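
A minimal fetcher that generates random data could look as follows
(a sketch only; the shapes and dtypes are illustrative):

   def my_fetch_fun(worker_id, batch_size):
       data = np.random.rand(batch_size, 10).astype(np.float32)
       labels = np.random.randint(0, 2, size=batch_size).astype(np.int32)
       return [data, labels]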

For a dummy example, see the data_workers_test unit test.

Note that for data_parallel_models, init_data_input_workers will be called
for each GPU. The 'coordinator' returned by the function is the same
each time.
'''

try:
    import Queue
except ImportError:
    # Py3
    import queue as Queue
import logging
import threading
import atexit
import numpy as np
import time
import collections

from caffe2.python import workspace, core, scope, utils
from caffe2.proto import caffe2_pb2

log = logging.getLogger("data_workers")
log.setLevel(logging.INFO)
LOG_INT_SECS = 60


def get_worker_ids(num_workers):
    return list(range(0, num_workers))


def init_data_input_workers(
    net,
    input_blob_names,
    fetch_fun,
    batch_size,
    num_worker_threads=2,
    input_source_name="train",
    max_buffered_batches=800,
    init_fun=None,
    external_loggers=None,
    dont_rebatch=False,
    batch_columns=None,
):
    global global_coordinator
    device_option = scope.CurrentDeviceScope()
    if device_option is None:
        device_option = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU)

    # Create coordinator object
    coordinator = DataInputCoordinator(
        net,
        input_blob_names,
        batch_size,
        device_option,
        scope.CurrentNameScope(),
        input_source_name,
        global_coordinator.get_queue(input_source_name, max_buffered_batches),
        init_fun=init_fun,
        external_loggers=external_loggers,
        dont_rebatch=dont_rebatch,
        batch_columns=batch_columns,
    )

    # Launch fetch worker threads
    worker_ids = [
        global_coordinator.get_new_worker_id()
        for i in range(num_worker_threads)
    ]
    workers = [
        threading.Thread(
            target=fetcher,
            name="data_workers fetcher id {}".format(worker_id),
            args=[coordinator, worker_id, fetch_fun, batch_size, input_blob_names],
        ) for worker_id in worker_ids
    ]

    workers.append(threading.Thread(
        target=enqueuer,
        name="Enqueuer {} {}".format(input_source_name, scope.CurrentNameScope()),
        args=[coordinator]))
    coordinator._workers = workers
    global_coordinator.add(coordinator)

    return global_coordinator


class DataInputCoordinator(object):
    def __init__(self, net, input_blob_names, batch_size,
                 device_option, namescope, input_source_name, queue,
                 init_fun=None, external_loggers=None, dont_rebatch=False,
                 batch_columns=None):
        self._counter = 0
        self._input_blob_names = input_blob_names
        self._batch_size = batch_size
        self._internal_queue = queue
        self._scratch_blobs = set()
        self._queues = []
        self._device_option = device_option
        self._namescope = namescope
        self._active = True
        self._started = False
        self._workers = []
        self._input_source_name = input_source_name
        self._c2_queue_capacity = 4
        self._create_caffe2_queues(net)
        self._create_caffe2_ops(net)
        self._inputs = 0
        self._prev_seconds = 0
        self._last_warning = time.time()
        self._init_fun = init_fun
        self._metrics = collections.defaultdict(lambda: 0)
        self._external_loggers = external_loggers
        self._dont_rebatch = dont_rebatch
        if batch_columns is None:
            batch_columns = [0 for _ in input_blob_names]
        self._batch_columns = batch_columns

    def is_active(self):
        return self._active

    def init(self, global_coordinator):
        if self._init_fun and not self._started:
            self._init_fun(self, global_coordinator)

    def _start(self):
        if self._started:
            return
        self._active = True
        self._started = True
        self._inputs = 0
        self._prev_seconds = time.time()

        for w in self._workers:
            w.daemon = True
            w.start()

    def _stop(self, reason=None):
        try:
            self._active = False
            if reason is not None:
                log.error("Data input failed due to an error: {}".format(reason))

            for q in self._queues:
                workspace.RunOperatorOnce(
                    core.CreateOperator("CloseBlobsQueue", [q], [])
                )
            self._started = False
        finally:
            self._log_inputs_per_interval(0, force=True)

    def _wait_finish(self):
        print("Waiting for workers to finish: {}".format(self._input_source_name))
        for w in self._workers:
            if w != threading.current_thread():
                w.join(5.0)  # don't wait forever, thread may be blocked in i/o
        success = True
        for w in self._workers:
            if w.is_alive():
                print("Worker {} failed to close while waiting".format(w))
                success = False

        # Release memory for the scratch blobs
        if success and len(self._scratch_blobs) > 0:
            utils.ResetBlobs(self._scratch_blobs)

        print("All workers terminated: {}".format(success))
        return success

    def _get(self):
        while self.is_active():
            try:
                return self._internal_queue.get(block=True, timeout=0.5)
            except Queue.Empty:
                continue
        return None

    def put(self, chunk):
        if len(chunk) == 0:
            print("Worker provided zero length input")
            return
        while self.is_active():
            try:
                qsize = self._internal_queue.qsize()
                if qsize < 2 and (time.time() - self._last_warning) > LOG_INT_SECS:
                    print("Warning, data loading lagging behind: " +
                          "queue size={}, name={}".format(
                              qsize, self._input_source_name))
                    self._last_warning = time.time()
                self._counter += 1
                self._internal_queue.put(chunk, block=True, timeout=0.5)
                self._log_inputs_per_interval(chunk[0].shape[0])
                return
            except Queue.Full:
                log.debug("Queue full: stalling fetchers...")
                continue

    def _enqueue_batch_direct(self):
        data = self._get()
        if data is None:
            return
        if self.is_active():
            for b, q, c in zip(self._input_blob_names, self._queues, data):
                self._enqueue(b, q, c)

    def _enqueue_batch(self):
        '''
        This pulls data from the python-side queue and collects them
        into batch-sized pieces, unless dont_rebatch is set to true.
        For example, with batch_size=32, two chunks of 20 rows each are
        concatenated to 40 rows; 32 rows are enqueued and the remaining 8
        are returned to the queue as leftover.
        '''
        if self._dont_rebatch:
            self._enqueue_batch_direct()
            return

        cur_batch = [np.array([]) for d in self._input_blob_names]
        first_batch_col = self._batch_columns[0]

        # Collect data until we have a full batch size
        while (
            cur_batch[0].shape[0] == 0 or
            cur_batch[0].shape[first_batch_col] < self._batch_size
        ) and self.is_active():
            chunk = self._get()
            if chunk is None:
                continue

            for j, chunk_elem in enumerate(chunk):
                if cur_batch[j].shape[0] == 0:
                    cur_batch[j] = chunk_elem.copy()
                else:
                    cur_batch[j] = np.append(
                        cur_batch[j], chunk_elem, axis=self._batch_columns[j]
                    )

        start_time = time.time()
        try:
            # Return data over the batch size back to queue
            if cur_batch[0].shape[0] > 0 and cur_batch[0].shape[
                first_batch_col
            ] > self._batch_size:
                leftover = []
                trimmed_batch = []
                for j, b in enumerate(cur_batch):
                    [c, l] = np.split(
                        b, [self._batch_size], axis=self._batch_columns[j]
                    )
                    leftover.append(l)
                    trimmed_batch.append(c)
                cur_batch = trimmed_batch
                try:
                    self._internal_queue.put(leftover, block=False)
                except Queue.Full:
                    pass

            assert cur_batch[0].shape[first_batch_col] == self._batch_size

            if self.is_active():
                for b, q, c in zip(
                    self._input_blob_names, self._queues, cur_batch
                ):
                    self._enqueue(b, q, c)
        finally:
            self.put_metric('enqueue_time', time.time() - start_time)

    def _enqueue(self, blob_name, queue, data_arr):
        '''
        Enqueue the correctly sized batch arrays to Caffe2's queue.
        '''
        scratch_name = self._namescope + blob_name + \
            "_scratch_" + self._input_source_name
        blob = core.BlobReference(scratch_name)
        status = core.BlobReference(scratch_name + "_status")
        if blob not in self._scratch_blobs:
            self._scratch_blobs.add(blob)
            self._scratch_blobs.add(status)
        workspace.FeedBlob(
            blob,
            data_arr,
            device_option=self._device_option
        )

        op = core.CreateOperator(
            "SafeEnqueueBlobs",
            [queue, blob],
            [blob, status],
            device_option=self._device_option
        )
        workspace.RunOperatorOnce(op)

    def _create_caffe2_queues(self, net):
        '''
        Creates queues on caffe2 side
        '''
        def create_queue(queue_name, num_blobs, capacity):
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "CreateBlobsQueue",
                    [], [queue_name],
                    num_blobs=num_blobs,
                    capacity=capacity))
            return core.ScopedBlobReference(queue_name)

        for blob_name in self._input_blob_names:
            qname = blob_name + "_c2queue" + "_" + self._input_source_name
            q = create_queue(
                qname, num_blobs=1, capacity=self._c2_queue_capacity
            )
            self._queues.append(q)

    def _create_caffe2_ops(self, net):
        '''
        Creates dequeue-ops on caffe2 side
        '''
        for q, blob_name in zip(self._queues, self._input_blob_names):
            # Add operator to the Caffe2 network to dequeue
            net.DequeueBlobs(q, blob_name)

    def _log_inputs_per_interval(self, inputs, force=False):
        self._inputs += inputs
        current_seconds = time.time()
        delta_seconds = current_seconds - self._prev_seconds
        if delta_seconds >= LOG_INT_SECS or force:
            inputs_per_sec = int(self._inputs / delta_seconds)
            qsize = self._internal_queue.qsize()
            print("{}/{}: {} inputs/sec".format(
                self._input_source_name,
                self._namescope,
                inputs_per_sec,
            ))
            print("-- queue: {} batches".format(qsize))
            # log and reset perf metrics
            self.put_metric('inputs_per_sec', inputs_per_sec, False)
            self.put_metric('queue_size', qsize, False)
            self.put_metric('time_elapsed', delta_seconds, False)
            self._log(self._metrics)
            self._reset_metrics()
            self._inputs = 0
            self._prev_seconds = current_seconds

    def _log(self, metrics):
        if not self._external_loggers:
            return
        for logger in self._external_loggers:
            try:
                logger.log(metrics)
            except Exception as e:
                print("Failed to call ExternalLogger: {}".format(e))

    def put_metric(self, key, value, count=True):
        self._metrics[key] += value
        if count:
            count_key = '{}_count'.format(key)
            self._metrics[count_key] += 1

    def _reset_metrics(self):
        self._metrics = collections.defaultdict(lambda: 0)


class GlobalCoordinator(object):
    def __init__(self):
        self._coordinators = []
        self._fetcher_id_seq = 0
        self._worker_ids = []
        self._queues = {}
        self.register_shutdown_handler()

    def add(self, coordinator):
        self._coordinators.append(coordinator)

    def get_new_worker_id(self):
        worker_id = self._fetcher_id_seq
        self._worker_ids.append(worker_id)
        self._fetcher_id_seq += 1
        return worker_id

    def get_worker_ids(self):
        return self._worker_ids

    def get_queue(self, queue_name, max_buffered_batches):
        assert isinstance(max_buffered_batches, int)
        if queue_name not in self._queues:
            self._queues[queue_name] = Queue.Queue(maxsize=max_buffered_batches)
        return self._queues[queue_name]

    def start(self):
        for c in self._coordinators:
            c.init(self)
            c._start()

    def reset_data_input(self, namescope, name, net, batch_size):
        log.info("Reset data input {}, batch size {}".format(name, batch_size))
        for c in self._coordinators:
            if c._input_source_name == name and c._namescope == namescope:
                c._batch_size = batch_size
                c._create_caffe2_ops(net)

    def stop(self):
        all_success = True
        for c in self._coordinators:
            c._stop()
        for c in self._coordinators:
            success = c._wait_finish()
            all_success = all_success and success
        self._coordinators = []
        return all_success

    def stop_coordinator(self, input_source_name):
        '''
        Stop a specific coordinator
        '''
        for c in self._coordinators:
            if c._input_source_name == input_source_name:
                c._stop()
                c._wait_finish()
        self._coordinators = [
            c for c in self._coordinators
            if c._input_source_name != input_source_name
        ]

    def register_shutdown_handler(self):
        def cleanup():
            self.stop()

        atexit.register(cleanup)


global_coordinator = GlobalCoordinator()
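
# Typical lifecycle against the module-level coordinator (a sketch; see the
# module docstring for the full set of arguments):
#
#   coordinator = init_data_input_workers(net, ["data", "label"], my_fetch_fun, 32)
#   coordinator.start()   # runs init functions, then launches the worker threads
#   ...                   # run the net; its DequeueBlobs ops feed the input blobs
#   coordinator.stop()    # closes the queues and joins all worker threads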


def fetcher(coordinator, worker_id, fetch_fun, batch_size, input_blob_names):
    while coordinator.is_active():
        start_time = time.time()
        try:
            input_data = fetch_fun(worker_id, batch_size)
            if input_data is None:
                print("Fetcher function returned None")
                continue

            assert len(input_data) == len(input_blob_names), \
                "Expecting data blob for each input"
            for d in input_data:
                assert isinstance(d, np.ndarray), \
                    "Fetcher function must return a numpy array"
            if not coordinator._dont_rebatch:
                # All inputs must provide the same number of samples along
                # their respective batch dimensions
                j = 1
                for d in input_data[1:]:
                    assert d.shape[coordinator._batch_columns[j]] == \
                        input_data[0].shape[coordinator._batch_columns[0]], \
                        "Each returned input must have equal number of samples"
                    j += 1

            coordinator.put(input_data)
        except Exception as e:
            log.exception("Exception in fetcher {}".format(worker_id))
            coordinator._stop("Exception in fetcher {}: {}".format(
                worker_id, e
            ))
        finally:
            coordinator.put_metric('fetcher_time', time.time() - start_time)


def enqueuer(coordinator):
    while coordinator.is_active():
        coordinator._enqueue_batch()