mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: For customers like Ads, Feeds, and MarketPlace, the training data is extremely large. It is unnecessary and costly to go over all of the data to compute meta information. This diff adds a numSample option to preCompute, so users can control how many samples are used when computing meta information. Differential Revision: D4492399 fbshipit-source-id: 7199381d226ee6300a959fc5e116d39984d199fc
337 lines
12 KiB
Python
337 lines
12 KiB
Python
"""
|
|
Defines the base interface for reading and writing operations.
|
|
|
|
Readers/Writers are objects that produce operations that read/write sequences
|
|
of data. Each operation reads or writes a list of BlobReferences.
|
|
|
|
Readers and Writers must be implemented such that read and write operations
|
|
are atomic and thread safe.
|
|
|
|
Examples of possible Readers and Writers:
|
|
HiveReader, HiveWriter,
|
|
QueueReader, QueueWriter,
|
|
DatasetReader, DatasetWriter,
|
|
DBReader, DBWriter,
|
|
|
|
See `dataset.py` for an example of implementation.
|
|
"""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
from caffe2.python import core
|
|
from caffe2.python.schema import Field, Struct, from_blob_list
|
|
import numpy as np
|
|
|
|
|
|
class Reader(object):
    """
    Reader is an abstract class to be implemented in order to provide
    operations capable of iterating through a dataset or stream of data.

    A Reader must implement at least one operation, `read`, which
    adds operations to a net that read the next batch of data. Readers can
    optionally support the `reset` operation, which is useful when multiple
    passes over the data are required.
    """

    def __init__(self, schema=None):
        """Create a reader, optionally bound to a schema.

        Args:
            schema: (optional) a `Field` describing the records produced by
                    this reader. When set, `read_record` and
                    `read_record_ex` wrap the raw blob list into a record
                    via `from_blob_list`.
        """
        if schema is not None:
            assert isinstance(schema, Field)
        self._schema = schema

    def schema(self):
        """
        Return the schema associated with this reader.

        Raises AssertionError if no schema was provided.
        """
        assert self._schema is not None, 'Schema not provided for this reader.'
        return self._schema

    def _set_schema(self, schema):
        # Lets subclasses (e.g. wrapping readers) adopt a schema discovered
        # after construction.
        self._schema = schema

    def setup_ex(self, init_net, finish_net):
        """Nets to be executed once at startup and finish.

        Experimental extension. Don't use yet"""
        pass

    def read_ex(self, local_init_net, local_finish_net):
        """Experimental extension to the interface. Don't use yet.

        Default implementation: create a fresh net, let `read` populate it
        and return `([read_net], should_stop, fields)`.
        """
        read_net = core.Net('reader_body')
        return ([read_net], ) + self.read(read_net)

    def read_record_ex(self, local_init_net, local_finish_net):
        """Experimental extension to the interface. Don't use yet.

        Same as `read_ex`, but `fields` is wrapped into a record when this
        reader has a schema.
        """
        nets, should_stop, fields = self.read_ex(
            local_init_net, local_finish_net)
        if self._schema:
            fields = from_blob_list(self._schema, fields)
        return nets, should_stop, fields

    def read(self, read_net):
        """
        Add operations to read_net that will read the next batch of data
        and return a list of BlobReference representing the blobs that will
        contain the batches produced.

        Operations added to `read_net` must be thread safe and atomic, that is,
        it should be possible to clone `read_net` and run multiple instances of
        it in parallel.

        Args:
            read_net: the net that will be appended with read operations

        Returns:
            A tuple (should_stop, fields), with:

                should_stop: BlobReference pointing to a boolean scalar
                             blob that indicates whether the read operation
                             was successful or whether the end of data has
                             been reached.
                fields: A tuple of BlobReference containing the latest batch
                        of data that was read.
        """
        raise NotImplementedError('Readers must implement `read`.')

    def reset(self, net):
        """Append operations to `net` that will reset the reader.

        This can be used to read the data multiple times.
        Not all readers support this operation.
        """
        raise NotImplementedError('This reader cannot be resetted.')

    def read_record(self, read_net):
        """Like `read`, but wraps `fields` into a record when a schema is set.

        Returns:
            A tuple (should_stop, fields).
        """
        should_stop, fields = self.read(read_net)
        if self._schema:
            fields = from_blob_list(self._schema, fields)
        return should_stop, fields

    def execution_step(self, reader_net_name=None, external_should_stop=None):
        """Create an execution step with a net containing read operators.

        The execution step will contain a `stop_blob` that knows how to stop
        the execution loop when end of data was reached.

        E.g.:

            read_step, fields = reader.execution_step()
            consume_net = core.Net('consume')
            consume_net.Print(fields[0], [])
            p = core.Plan('reader')
            p.AddStep(read_step.AddNet(consume_net))
            core.RunPlan(p)

        Args:

            reader_net_name: (optional) the name of the reader_net to be
                             created. The execution step will
                             be named accordingly. Defaults to 'reader'.
            external_should_stop: (optional) a boolean blob OR-ed with the
                             reader's own stop flag, so the loop also stops
                             when it is set.

        Returns:
            A tuple (read_step, fields), with:

                read_step: A newly created execution step containing a net with
                           read operations. The step will have `stop_blob` set,
                           in order to stop the loop on end of data.
                fields: A tuple of BlobReference containing the latest batch
                        of data that was read.
        """
        # Apply the default once so the net and the step share the same name
        # (previously the step was named 'None_step' when no name was given).
        reader_net_name = reader_net_name or 'reader'
        reader_net = core.Net(reader_net_name)
        should_stop, fields = self.read_record(reader_net)
        if external_should_stop is not None:
            should_stop = reader_net.Or([external_should_stop, should_stop])
        read_step = core.execution_step(
            '{}_step'.format(reader_net_name),
            reader_net,
            should_stop_blob=should_stop)
        return (read_step, fields)
|
|
|
|
|
|
class Writer(object):
    """
    Abstract base for objects that feed a data stream or a dataset.

    Subclasses provide two operations:
      * `write`  - append operations to a net that write the next batch of
                   data;
      * `commit` - append operations to a net signalling that no more data
                   will be written.
    """

    # Structure of the last `Field` handed to `write_record` /
    # `write_record_ex`; None until one is seen.
    _schema = None

    def schema(self):
        """Return the recorded schema of this writer (None if unknown)."""
        return self._schema

    def write(self, writer_net, fields):
        """Append to `writer_net` the operations writing the next batch.

        Operations added to the net must be thread-safe and unique, i.e.
        several writers must be able to write to the dataset in parallel.

        Args:
            writer_net: the net receiving the write operations.
            fields: a tuple of BlobReference containing the batch of data to
                    write.
        """
        raise NotImplementedError('Writers must implement write.')

    def _record_to_blobs(self, fields):
        # A structured record is flattened into its blobs and its structure
        # remembered as this writer's schema; anything else passes through.
        if not isinstance(fields, Field):
            return fields
        self._schema = fields
        return fields.field_blobs()

    def write_record(self, writer_net, fields):
        """Like `write`, but also accepts a `Field` record."""
        blobs = self._record_to_blobs(fields)
        self.write(writer_net, blobs)

    def setup_ex(self, init_net, finish_net):
        """Experimental, don't use yet"""
        self.commit(finish_net)

    def write_ex(self, fields, local_init_net, local_finish_net, stop_blob):
        """Experimental extension to the interface. Don't use yet"""
        body = core.Net('write_net')
        self.write(body, fields)
        return [body]

    def write_record_ex(
            self, fields, local_init_net, local_finish_net, stop_blob=None):
        """Experimental extension to the interface. Don't use yet."""
        blobs = self._record_to_blobs(fields)
        # Reserve a fresh blob name when the caller supplies no stop blob.
        if stop_blob is None:
            stop_blob = local_init_net.NextName("dequeue_status")
        nets = self.write_ex(blobs, local_init_net, local_finish_net, stop_blob)
        return (nets, stop_blob)

    def commit(self, finish_net):
        """Append to `finish_net` the operations signalling end of data.

        Every writer must provide this; the base version adds nothing, which
        is a valid no-op for writers that need no finalization.
        """
        pass
|
|
|
|
|
|
class ReaderBuilder(object):
    """ Allow usage of a reader in distributed fashion.

    Every method is a hook for subclasses; the base implementation only
    raises NotImplementedError.
    """

    def schema(self):
        """Hook: schema of the readers built. Not implemented in the base."""
        raise NotImplementedError()

    def enqueue_splits(self, net, split_queue):
        """Hook taking `net` and `split_queue`. Not implemented in the base.

        Presumably enqueues data splits into `split_queue` — confirm against
        concrete builders.
        """
        raise NotImplementedError()

    def splits(self, net):
        """Hook taking `net`. Not implemented in the base."""
        raise NotImplementedError()

    def new_reader(self, split_queue):
        """Hook returning a reader for `split_queue`. Not implemented here."""
        raise NotImplementedError()
|
|
|
|
|
|
class Pipe(object):
    """Base class for pipes.

    Keeps an optional schema and counts how many readers and writers were
    created. Concrete pipes implement `reader` and `writer`.
    """

    def __init__(self, schema=None, obj_key=None):
        """
        Args:
            schema: (optional) initial schema of the pipe.
            obj_key: (optional) when set, every net passed to `_new_writer` /
                     `_new_reader` gets this pipe attached under that key via
                     `add_attribute`.
        """
        self._num_writers = 0
        self._num_readers = 0
        self._schema = schema
        self._obj_key = obj_key

    def schema(self):
        """Return the pipe schema (possibly None)."""
        return self._schema

    def setup(self, global_init_net):
        """Hook for global initialization; the base does nothing."""
        pass

    def reader(self):
        """Create a reader for this pipe (subclass responsibility)."""
        raise NotImplementedError()

    def writer(self):
        """Create a writer for this pipe (subclass responsibility)."""
        raise NotImplementedError()

    def num_readers(self):
        """Number of readers registered via `_new_reader`."""
        return self._num_readers

    def num_writers(self):
        """Number of writers registered via `_new_writer`."""
        return self._num_writers

    def _attach_to(self, net):
        # Expose this pipe on `net` only when an object key was configured.
        if self._obj_key is None:
            return
        net.add_attribute(self._obj_key, self)

    def _new_writer(self, writer_schema, writer_init_net):
        # The first non-None writer schema is adopted if none is known yet.
        if self._schema is None and writer_schema is not None:
            self._schema = writer_schema
        self._num_writers += 1
        self._attach_to(writer_init_net)

    def _new_reader(self, reader_init_net):
        self._num_readers += 1
        self._attach_to(reader_init_net)
|
|
|
|
|
|
class CounterReader(Reader):
    """ Reader that produces increasing integers. """

    def __init__(self):
        Reader.__init__(self, schema=Struct(('iter', np.int64)))
        # Both blobs are created by `setup_ex`.
        self.counter = None
        self.should_stop = None

    def setup_ex(self, global_init_net, global_finish_net):
        """Create the counter and a constant False stop flag.

        Nothing is added once `self.counter` has been set.
        """
        if self.counter is not None:
            return
        self.counter = global_init_net.CreateCounter([], init_count=0)
        self.should_stop = global_init_net.ConstantFill(
            [], shape=[], dtype=core.DataType.BOOL, value=False)

    def read_ex(self, local_init_net, local_finish_net):
        """Return a single net that counts up by one per read.

        The returned stop flag is the constant created in `setup_ex`, so this
        reader never signals end of data by itself.
        """
        net = core.Net('limited_reader_counter')
        current = net.CountUp([self.counter], 1)
        return [net], self.should_stop, [current]
|
|
|
|
|
|
class ReaderWithLimit(Reader):
    """ Reader that stops after `num_iter` calls. """

    def __init__(self, reader, num_iter=1):
        """Wrap `reader` so that reading stops after `num_iter` batches.

        Args:
            reader: the underlying Reader. It is driven through its
                    experimental `setup_ex` / `read_ex` interface.
            num_iter: maximum number of reads before stopping.
        """
        Reader.__init__(self, schema=reader._schema)
        self.reader = reader
        self.num_iter = num_iter
        # Only blob names are reserved here; the blobs are created in
        # `setup_ex`. The temporary net serves solely to allocate names.
        net = core.Net('reader_with_limit')
        self._data_finished = net.AddExternalInput(
            net.NextName('data_finished'))
        self.counter = net.AddExternalInput(net.NextName('counter'))

    def setup_ex(self, global_init_net, global_finish_net):
        """Create the countdown counter and the `data_finished` flag.

        Also runs the underlying reader's `setup_ex`.
        """
        global_init_net.CreateCounter(
            [], [self.counter], init_count=int(self.num_iter))
        self.reader.setup_ex(global_init_net, global_finish_net)
        global_init_net.ConstantFill(
            [], [self._data_finished],
            shape=[], value=False, dtype=core.DataType.BOOL)

    def read_ex(self, local_init_net, local_finish_net):
        """Build the nets for one limited read.

        Returns:
            (nets, should_stop, fields), where nets is
            [count_net] + <underlying reader nets> + [check_done_net].
        """
        # 1. Check whether we reached the number of iterations.
        count_net = core.Net('limited_reader_counter')
        should_stop = count_net.CountDown([self.counter], 1)

        # 2. Call the original reader.
        nets, local_data_finished, fields = self.reader.read_ex(
            local_init_net, local_finish_net)
        self._set_schema(self.reader._schema)

        # 3. Check whether the original reader is done: `should_stop` is
        # overwritten with the underlying reader's stop flag, which is also
        # recorded in `self._data_finished`.
        check_done_net = core.Net('limited_reader_post')
        check_done_net.Copy(local_data_finished, should_stop)
        check_done_net.Copy([local_data_finished], [self._data_finished])

        # This relies on `should_stop` being checked after each net.
        return [count_net] + nets + [check_done_net], should_stop, fields

    def data_finished(self):
        """
        Return a blob that can be checked after the end of the reading task.

        It contains a scalar boolean (initialized to False in `setup_ex`)
        indicating whether the underlying reader has been exhausted (True)
        or whether we stopped because we reached the limit of iterations
        (False).
        """
        return self._data_finished
|
|
|
|
|
|
def CountUntil(num_iter):
    """Return a counting reader limited to `num_iter` reads.

    Combines a `CounterReader` with a `ReaderWithLimit`.
    """
    counter_reader = CounterReader()
    return ReaderWithLimit(counter_reader, num_iter)
|