LastNWindowCollector

Summary: Layer for LastNWindowCollector op. We need this since it's an in-place operator.

Reviewed By: chocjy

Differential Revision: D4981772

fbshipit-source-id: ec85dbf247d0944db422ad396771fa9308650883
This commit is contained in:
Kittipat Virochsiri 2017-05-04 17:19:40 -07:00 committed by Facebook Github Bot
parent b229b7ff11
commit 211eae127c
2 changed files with 85 additions and 0 deletions

View File

@@ -0,0 +1,67 @@
## @package last_n_window_collector
# Module caffe2.python.layers.last_n_window_collector
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core, schema
from caffe2.python.layers.layers import (
LayerParameter,
ModelLayer,
)
class LastNWindowCollector(ModelLayer):
    """
    Collect last-N samples from input record. If you have complex data,
    use PackRecords to pack it before using this layer.

    This layer is not thread safe.
    """

    def __init__(self, model, input_record, num_to_collect,
                 name='last_n_window_collector', **kwargs):
        """Set up the collection buffer and cursor state as model parameters.

        model: the LayerModelHelper this layer is attached to.
        input_record: schema.Scalar record whose samples are collected.
        num_to_collect: positive int, maximum number of trailing samples kept.
        name: base name used to scope the internal blobs.
        """
        super(LastNWindowCollector, self).__init__(
            model, name, input_record, **kwargs)
        assert num_to_collect > 0
        self.num_to_collect = num_to_collect
        # Only flat (Scalar) records are supported; complex records must be
        # packed first (see class docstring).
        assert isinstance(input_record, schema.Scalar), \
            "Got {!r}".format(input_record)
        # Buffer blob holding the collected samples; grows up to
        # num_to_collect rows (initialized empty via shape=[0]).
        self.last_n = model.net.NextScopedBlob(self.name + "_last_n")
        # Int32 scalar cursor, initialized to 0 — internal bookkeeping state
        # for the LastNWindowCollector op (presumably the next write index
        # into the buffer; confirm against the op's implementation).
        self.next_blob = model.net.NextScopedBlob(self.name + "_next")
        # Register the buffer as a non-optimized parameter (NoOptim): it is
        # persistent state carried with the model, not a learned weight.
        self.params.append(LayerParameter(
            parameter=self.last_n,
            initializer=core.CreateOperator(
                'ConstantFill', [], self.last_n, shape=[0]
            ),
            optimizer=model.NoOptim,
        ))
        # Same for the cursor state.
        self.params.append(LayerParameter(
            parameter=self.next_blob,
            initializer=core.CreateOperator(
                'ConstantFill',
                [],
                self.next_blob,
                shape=[],
                value=0,
                dtype=core.DataType.INT32,
            ),
            optimizer=model.NoOptim,
        ))
        # Output mirrors the input record's schema, bound to a fresh blob.
        self.output_schema = schema.from_blob_list(
            input_record, [model.net.NextScopedBlob(name + "_output")])

    def add_ops(self, net):
        """Emit the in-place collection op plus a copy to the output blob."""
        # LastNWindowCollector mutates last_n / next_blob in place: they
        # appear as both inputs and outputs.
        net.LastNWindowCollector(
            [self.last_n, self.next_blob, self.input_record()],
            [self.last_n, self.next_blob],
            num_to_collect=self.num_to_collect,
        )
        # Copy to make sure DAG of record is not broken.
        # Also, the output of this is likely going through a pipeline, which
        # will move data and require us to copy anyway.
        net.Copy(self.last_n, self.output_schema())

View File

@@ -3,9 +3,14 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import hypothesis.strategies as st
import numpy as np
import numpy.testing as npt
from hypothesis import given
import caffe2.python.hypothesis_test_util as hu
from caffe2.python import (
layer_model_instantiator,
schema,
@@ -190,6 +195,19 @@ class TestLayers(LayersTestCase):
('loss', schema.Scalar(np.float32)),
), loss)
@given(
X=hu.arrays(dims=[5, 2]),
num_to_collect=st.integers(min_value=1, max_value=10),
)
def testLastNWindowCollector(self, X, num_to_collect):
input_record = self.new_record(schema.Scalar(np.float32))
schema.FeedRecord(input_record, [X])
last_n = self.model.LastNWindowCollector(input_record, num_to_collect)
self.run_train_net_forward_only()
output_record = schema.FetchRecord(last_n)
start = max(0, 5 - num_to_collect)
npt.assert_array_equal(X[start:], output_record())
def testUniformSampling(self):
input_record = self.new_record(schema.Scalar(np.int32))
input_array = np.array([3, 10, 11, 15, 20, 99], dtype=np.int32)