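# Tests for caffe2.python.dataio readers: count-limited and time-limited
# readers (ReaderWithLimit, ReaderWithTimeLimit) and the CachedReader,
# exercised through multi-threaded pipelines.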
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python.dataio import Reader, ReaderWithLimit, ReaderWithTimeLimit
from caffe2.python.dataset import Dataset
from caffe2.python.pipeline import pipe
from caffe2.python.schema import Struct, NewRecord, FeedRecord
from caffe2.python.session import LocalSession
from caffe2.python.task import TaskGroup, final_output, WorkspaceType
from caffe2.python.test_util import TestCase
from caffe2.python.cached_reader import CachedReader
from caffe2.python import core, workspace
from caffe2.python.net_builder import ops

import numpy as np
import os
import shutil
import tempfile
import time


def init_dataset(ws, size=100):
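    """Build a 'src' dataset with a single 'label' column holding 0..size-1,
    feed it into the given workspace, and return the Dataset."""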
    src_init = core.Net('src_init')
    with core.NameScope('src'):
        src_values = Struct(('label', np.array(range(size))))
        src_blobs = NewRecord(src_init, src_values)
        src_ds = Dataset(src_blobs)
        FeedRecord(src_blobs, src_values, ws)
    ws.run(src_init)
    return src_ds


def read_all_data(ws, reader, session):
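    """Drain the given reader into a fresh 'dst' dataset through a
    multi-threaded pipe and return the fetched 'label' values."""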
    dst_init = core.Net('dst_init')
    with core.NameScope('dst'):
        dst_ds = Dataset(reader.schema().clone_schema())
        dst_ds.init_empty(dst_init)
    session.run(dst_init)

    with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
        pipe(reader, dst_ds.writer(), num_runtime_threads=8)
    session.run(tg)

    return ws.blobs[str(dst_ds.content().label())].fetch()


class ReaderWithDelay(Reader):
    """Test reader class that inserts a delay between reading batches."""
    def __init__(self, reader, delay):
        Reader.__init__(self, schema=reader._schema)
        self.reader = reader
        self.delay = delay

    def setup_ex(self, global_init_net, global_finish_net):
        self.reader.setup_ex(global_init_net, global_finish_net)

    def read_ex(self, local_init_net, local_finish_net):
        read_net = core.Net('reader_body')

        def sleep_op(*args, **argd):
            time.sleep(self.delay)

        read_net.Python(sleep_op)([], [])
        return ([read_net], ) + self.reader.read(read_net)


class TestReaderWithLimit(TestCase):
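    """Tests for count-limited, time-limited, and cached readers."""
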
    def test_runtime_threads(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)
        src_ds = init_dataset(ws)
        totals = [None] * 3

        def proc(rec):
            # executed once
            with ops.task_init():
                counter1 = ops.CreateCounter([], ['global_counter'])
                counter2 = ops.CreateCounter([], ['global_counter2'])
                counter3 = ops.CreateCounter([], ['global_counter3'])
            # executed once per thread
            with ops.task_instance_init():
                task_counter = ops.CreateCounter([], ['task_counter'])
            # executed on each iteration
            ops.CountUp(counter1)
            ops.CountUp(task_counter)
            # executed once per thread
            with ops.task_instance_exit():
                with ops.loop(ops.RetrieveCount(task_counter)):
                    ops.CountUp(counter2)
                ops.CountUp(counter3)
            # executed once
            with ops.task_exit():
                totals[0] = final_output(ops.RetrieveCount(counter1))
                totals[1] = final_output(ops.RetrieveCount(counter2))
                totals[2] = final_output(ops.RetrieveCount(counter3))
            return rec

        # Read full data set from original reader
        with TaskGroup() as tg:
            pipe(src_ds.reader(), num_runtime_threads=8, processor=proc)
        session.run(tg)
        self.assertEqual(totals[0].fetch(), 100)
        self.assertEqual(totals[1].fetch(), 100)
        self.assertEqual(totals[2].fetch(), 8)

        # Read with a count-limited reader
        with TaskGroup() as tg:
            q1 = pipe(src_ds.reader(), num_runtime_threads=2)
            q2 = pipe(
                ReaderWithLimit(q1.reader(), num_iter=25),
                num_runtime_threads=3)
            pipe(q2, processor=proc, num_runtime_threads=6)
        session.run(tg)
        self.assertEqual(totals[0].fetch(), 25)
        self.assertEqual(totals[1].fetch(), 25)
        self.assertEqual(totals[2].fetch(), 6)

    def _test_limit_reader_init_shared(self, size):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)

        # Build test dataset
        src_ds = init_dataset(ws, size=size)

        # Create an identically sized empty destination dataset
        dst_init = core.Net('dst_init')
        with core.NameScope('dst'):
            dst_ds = Dataset(src_ds.content().clone_schema())
            dst_ds.init_empty(dst_init)
        ws.run(dst_init)

        return ws, session, src_ds, dst_init, dst_ds

    def _test_limit_reader_shared(self, reader_class, size, expected_read_len,
                                  expected_finish, num_threads, read_delay,
                                  **limiter_args):
        ws, session, src_ds, dst_init, dst_ds = \
            self._test_limit_reader_init_shared(size)

        # Read from the source reader, wrapped in the given limiter class.
        # WorkspaceType.GLOBAL is required because we are fetching
        # reader.data_finished() after the TaskGroup finishes.
        with TaskGroup(workspace_type=WorkspaceType.GLOBAL) as tg:
            if read_delay > 0:
                reader = reader_class(ReaderWithDelay(src_ds.reader(),
                                                      read_delay),
                                      **limiter_args)
            else:
                reader = reader_class(src_ds.reader(), **limiter_args)
            pipe(reader, dst_ds.writer(), num_runtime_threads=num_threads)
        session.run(tg)
        read_len = len(sorted(ws.blobs[str(dst_ds.content().label())].fetch()))
        self.assertEqual(read_len, expected_read_len)
        self.assertEqual(
            sorted(ws.blobs[str(dst_ds.content().label())].fetch()),
            list(range(expected_read_len))
        )
        self.assertEqual(ws.blobs[str(reader.data_finished())].fetch(),
                         expected_finish)

    def test_count_limit_reader_without_limit(self):
        # No iter count specified, should read all records.
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=100,
                                       expected_finish=True,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=None)

    def test_count_limit_reader_with_zero_limit(self):
        # Zero iter count specified, should read 0 records.
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=0,
                                       expected_finish=False,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=0)

    def test_count_limit_reader_with_low_limit(self):
        # Read with limit smaller than size of dataset
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=10,
                                       expected_finish=False,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=10)

    def test_count_limit_reader_with_high_limit(self):
        # Read with limit larger than size of dataset
        self._test_limit_reader_shared(ReaderWithLimit,
                                       size=100,
                                       expected_read_len=100,
                                       expected_finish=True,
                                       num_threads=8,
                                       read_delay=0,
                                       num_iter=110)

    def test_time_limit_reader_without_limit(self):
        # No duration specified, should read all records.
        self._test_limit_reader_shared(ReaderWithTimeLimit,
                                       size=100,
                                       expected_read_len=100,
                                       expected_finish=True,
                                       num_threads=8,
                                       read_delay=0.1,
                                       duration=0)

    def test_time_limit_reader_with_short_limit(self):
        # Read with insufficient time limit
        size = 50
        num_threads = 4
        sleep_duration = 0.25
        duration = 1
        expected_read_len = int(round(num_threads * duration / sleep_duration))
        # Because the time limit check happens before the delay + read op,
        # subtract a little bit of time to ensure we don't get an extra read in
        duration = duration - 0.25 * sleep_duration
        self._test_limit_reader_shared(ReaderWithTimeLimit,
                                       size=size,
                                       expected_read_len=expected_read_len,
                                       expected_finish=False,
                                       num_threads=num_threads,
                                       read_delay=sleep_duration,
                                       duration=duration)

    def test_time_limit_reader_with_long_limit(self):
        # Read with ample time limit
        self._test_limit_reader_shared(ReaderWithTimeLimit,
                                       size=50,
                                       expected_read_len=50,
                                       expected_finish=True,
                                       num_threads=4,
                                       read_delay=0.25,
                                       duration=6)

    def test_cached_reader(self):
        ws = workspace.C.Workspace()
        session = LocalSession(ws)

        def build_source_reader(size):
            src_ds = init_dataset(ws, size)
            return src_ds.reader()

        with tempfile.NamedTemporaryFile(delete=False) as f:
            path = f.name
            f.close()
            os.remove(path)

        # Read data for the first time.
        cached_reader1 = CachedReader(build_source_reader(100))
        init_step = cached_reader1.build_cache(path)
        session.run(init_step)

        data = read_all_data(ws, cached_reader1, session)
        self.assertEqual(sorted(data), list(range(100)))

        # Read data from cache.
        workspace.ResetWorkspace()
        cached_reader2 = CachedReader(build_source_reader(200))
        init_step = cached_reader2.build_cache(path)
        session.run(init_step)

        data = read_all_data(ws, cached_reader2, session)
        self.assertEqual(sorted(data), list(range(100)))

        shutil.rmtree(path)

        # We removed the cache, so we expect to receive data from the
        # original reader.
        workspace.ResetWorkspace()
        cached_reader3 = CachedReader(build_source_reader(300))
        init_step = cached_reader3.build_cache(path)
        session.run(init_step)

        data = read_all_data(ws, cached_reader3, session)
        self.assertEqual(sorted(data), list(range(300)))

        shutil.rmtree(path)
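

# Optional entry point: lets this test module be run directly with the
# standard unittest runner; the project's own test harness may be used instead.
if __name__ == '__main__':
    import unittest
    unittest.main()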