Summary:
This fixes a race condition in text_file_reader.py.
For example, in `fbcode/caffe2/caffe2/fb/text/stats.py`, in `compute_meta`, we build an execution step `read` structured like this:
```
.
└── step_read
    ├── net_reader
    │   ├── op_TextFileReaderRead
    │   └── op_IsEmpty
    └── net_consume:n
        └── op_Tokenize
```
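For reference, here is a hedged sketch of how a step like this could be wired up with `core.execution_step`. The net contents are placeholders, not the real `stats.py` code, and the `ConstantFill` initialization of `should_stop` is an assumption for illustration:
```
# Hypothetical sketch, not the real stats.py code: wires two placeholder
# nets into one execution step whose should_stop blob is checked between
# nets, as described below.
from caffe2.python import core

init_net = core.Net('init')
# should_stop starts as False; op_IsEmpty overwrites it on each pass.
should_stop = init_net.ConstantFill(
    [], 'should_stop', shape=[], value=False, dtype=core.DataType.BOOL)

net_reader = core.Net('reader')
# ... op_TextFileReaderRead and op_IsEmpty (writing should_stop) go here.

net_consume = core.Net('consume:n')
# ... op_Tokenize goes here.

step_read = core.execution_step(
    'step_read',
    [net_reader, net_consume],
    should_stop_blob=should_stop)  # consulted between nets, not between ops
```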
Note that in `workspace.cc`, `should_stop` is checked between each _step_ and each _net_, not between _ops_.
Suppose we have 2 workers; here is a faulty interleaving of their threads:
- Worker 1 executes TextFileReaderRead
- Worker 2 executes TextFileReaderRead
- Worker 1 executes IsEmpty and sets should_stop to False
- Worker 2 executes IsEmpty and sets should_stop to True
- Worker 1 checks should_stop before running net_consume:n
- Worker 1 stops
- Worker 2 checks should_stop before running net_consume:n
- Worker 2 stops

This is a bug: worker 1 did read data from the file but never ran the processing step (`net_consume:n`) on that data. The sketch below reproduces this interleaving.
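To make the failure concrete, here is a minimal plain-Python simulation of the interleaving above (illustrative scaffolding only, not Caffe2 code; the function names mirror the ops in the tree):
```
# Plain-Python stand-ins for the ops above; not Caffe2 code.
rows = ["last-row"]      # the file has exactly one row left
should_stop = False      # plays the role of the shared should_stop blob
consumed = []

def text_file_reader_read():
    # op_TextFileReaderRead: return a row, or None at end of file
    return rows.pop() if rows else None

def is_empty(data):
    # op_IsEmpty: (over)writes the shared should_stop blob
    global should_stop
    should_stop = data is None

def net_consume(data):
    # net_consume:n, e.g. op_Tokenize
    consumed.append(data)

# The faulty interleaving, step by step:
data1 = text_file_reader_read()   # worker 1 reads the last row
data2 = text_file_reader_read()   # worker 2 hits end of file
is_empty(data1)                   # worker 1: should_stop = False
is_empty(data2)                   # worker 2: should_stop = True, clobbering it

# Both workers now check should_stop before net_consume:n, and both stop:
if not should_stop:
    net_consume(data1)            # never runs
if not should_stop:
    net_consume(data2)            # never runs

print(consumed)                   # [] -- worker 1's row is silently dropped
```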
Reviewed By: dzhulgakov
Differential Revision: D4203729
fbshipit-source-id: eabd94ea995527ec52fa137a8b63c277f7e4dd96
text_file_reader.py:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core
from caffe2.python.dataio import Reader
from caffe2.python.schema import Scalar, Struct, data_type_for_dtype


class TextFileReader(Reader):
    """
    Wrapper around operators for reading from text files.
    """
    def __init__(self, init_net, filename, schema, num_passes=1, batch_size=1):
        """
        Create op for building a TextFileReader instance in the workspace.

        Args:
            init_net   : Net that will be run only once at startup.
            filename   : Path to file to read from.
            schema     : schema.Struct representing the schema of the data.
                         Currently, only Structs of strings are supported.
            num_passes : Number of passes over the data.
            batch_size : Number of rows to read at a time.
        """
        assert isinstance(schema, Struct), 'Schema must be a schema.Struct'
        for name, child in schema.get_children():
            assert isinstance(child, Scalar), (
                'Only scalar fields are supported in TextFileReader.')
        field_types = [
            data_type_for_dtype(dtype) for dtype in schema.field_types()]
        Reader.__init__(self, schema)
        self._reader = init_net.CreateTextFileReader(
            [],
            filename=filename,
            num_passes=num_passes,
            field_types=field_types)
        self._batch_size = batch_size

    def read(self, net):
        """
        Create op for reading a batch of rows.
        """
        blobs = net.TextFileReaderRead(
            [self._reader],
            len(self.schema().field_names()),
            batch_size=self._batch_size)
        if type(blobs) is core.BlobReference:
            blobs = [blobs]

        # The first output is empty iff the reader is exhausted; IsEmpty
        # produces the should_stop blob consulted between nets and steps.
        is_empty = net.IsEmpty(
            [blobs[0]],
            core.ScopedBlobReference(net.NextName('should_stop'))
        )

        return (is_empty, blobs)
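For context, a hedged usage sketch (assumes a working Caffe2 install; the file path and schema are made up, and the `Scalar(dtype=str)` spelling follows my reading of `caffe2.python.schema`, not anything in this commit):
```
# Hypothetical usage of TextFileReader; the path and schema are
# illustrative assumptions, not part of this commit.
from caffe2.python import core, workspace
from caffe2.python.schema import Scalar, Struct

init_net = core.Net('init')
reader = TextFileReader(
    init_net,
    filename='/tmp/sample.txt',          # tab-separated text file (assumed)
    schema=Struct(
        ('label', Scalar(dtype=str)),
        ('text', Scalar(dtype=str)),
    ),
    batch_size=32)

read_net = core.Net('read')
should_stop, fields = reader.read(read_net)  # (is_empty blob, field blobs)

workspace.RunNetOnce(init_net)               # creates the reader instance
workspace.RunNetOnce(read_net)               # reads up to 32 rows
```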