mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
LastNWindowCollector
Summary: Layer for LastNWindowCollector op. We need this since it's an in-place operator. Reviewed By: chocjy Differential Revision: D4981772 fbshipit-source-id: ec85dbf247d0944db422ad396771fa9308650883
This commit is contained in:
parent
b229b7ff11
commit
211eae127c
67
caffe2/python/layers/last_n_window_collector.py
Normal file
67
caffe2/python/layers/last_n_window_collector.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
## @package last_n_window_collector
|
||||
# Module caffe2.python.layers.last_n_window_collector
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from caffe2.python import core, schema
|
||||
from caffe2.python.layers.layers import (
|
||||
LayerParameter,
|
||||
ModelLayer,
|
||||
)
|
||||
|
||||
|
||||
class LastNWindowCollector(ModelLayer):
|
||||
"""
|
||||
Collect last-N samples from input record. If you have complex data,
|
||||
use PackRecords to pack it before using this layer.
|
||||
|
||||
This layer is not thread safe.
|
||||
"""
|
||||
|
||||
def __init__(self, model, input_record, num_to_collect,
|
||||
name='last_n_window_collector', **kwargs):
|
||||
super(LastNWindowCollector, self).__init__(
|
||||
model, name, input_record, **kwargs)
|
||||
assert num_to_collect > 0
|
||||
self.num_to_collect = num_to_collect
|
||||
assert isinstance(input_record, schema.Scalar), \
|
||||
"Got {!r}".format(input_record)
|
||||
|
||||
self.last_n = model.net.NextScopedBlob(self.name + "_last_n")
|
||||
self.next_blob = model.net.NextScopedBlob(self.name + "_next")
|
||||
|
||||
self.params.append(LayerParameter(
|
||||
parameter=self.last_n,
|
||||
initializer=core.CreateOperator(
|
||||
'ConstantFill', [], self.last_n, shape=[0]
|
||||
),
|
||||
optimizer=model.NoOptim,
|
||||
))
|
||||
self.params.append(LayerParameter(
|
||||
parameter=self.next_blob,
|
||||
initializer=core.CreateOperator(
|
||||
'ConstantFill',
|
||||
[],
|
||||
self.next_blob,
|
||||
shape=[],
|
||||
value=0,
|
||||
dtype=core.DataType.INT32,
|
||||
),
|
||||
optimizer=model.NoOptim,
|
||||
))
|
||||
|
||||
self.output_schema = schema.from_blob_list(
|
||||
input_record, [model.net.NextScopedBlob(name + "_output")])
|
||||
|
||||
def add_ops(self, net):
|
||||
net.LastNWindowCollector(
|
||||
[self.last_n, self.next_blob, self.input_record()],
|
||||
[self.last_n, self.next_blob],
|
||||
num_to_collect=self.num_to_collect,
|
||||
)
|
||||
# Copy to make sure DAG of record is not broken.
|
||||
# Also, the output of this is likely going through a pipeline, which
|
||||
# will move data and require us to copy anyway.
|
||||
net.Copy(self.last_n, self.output_schema())
|
||||
|
|
@ -3,9 +3,14 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import hypothesis.strategies as st
|
||||
import numpy as np
|
||||
import numpy.testing as npt
|
||||
|
||||
from hypothesis import given
|
||||
|
||||
import caffe2.python.hypothesis_test_util as hu
|
||||
|
||||
from caffe2.python import (
|
||||
layer_model_instantiator,
|
||||
schema,
|
||||
|
|
@ -190,6 +195,19 @@ class TestLayers(LayersTestCase):
|
|||
('loss', schema.Scalar(np.float32)),
|
||||
), loss)
|
||||
|
||||
@given(
|
||||
X=hu.arrays(dims=[5, 2]),
|
||||
num_to_collect=st.integers(min_value=1, max_value=10),
|
||||
)
|
||||
def testLastNWindowCollector(self, X, num_to_collect):
|
||||
input_record = self.new_record(schema.Scalar(np.float32))
|
||||
schema.FeedRecord(input_record, [X])
|
||||
last_n = self.model.LastNWindowCollector(input_record, num_to_collect)
|
||||
self.run_train_net_forward_only()
|
||||
output_record = schema.FetchRecord(last_n)
|
||||
start = max(0, 5 - num_to_collect)
|
||||
npt.assert_array_equal(X[start:], output_record())
|
||||
|
||||
def testUniformSampling(self):
|
||||
input_record = self.new_record(schema.Scalar(np.int32))
|
||||
input_array = np.array([3, 10, 11, 15, 20, 99], dtype=np.int32)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user