pytorch/caffe2/python/layers/position_weighted.py
Wael Abdelghani c291c97494 Add integration test for pos_w
Summary: Title

Reviewed By: azzolini

Differential Revision: D5197307

fbshipit-source-id: 425bf8e7c5068ea544e5b2709b6bb27eef140bf3
2017-06-08 18:04:53 -07:00

61 lines
2.0 KiB
Python

## @package position_weighted
# Module caffe2.python.layers.position_weighted
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core, schema
from caffe2.python.layers.layers import (
LayerParameter,
ModelLayer,
)
from caffe2.python.layers.tags import Tags
import numpy as np
class PositionWeighted(ModelLayer):
    """Learn one weight per position inside each list of a list-style input.

    A 1-D parameter ``pos_w`` (initialised to all ones) holds one weight per
    position. In ``add_ops`` every item of every input list is assigned
    ``pos_w[i]``, where ``i`` is the item's 0-based position inside its own
    list. The gathered weights are exposed as the ``position_weights`` field
    of the output schema.
    """

    def __init__(self, model, input_record, weight_optim=None,
                 name="position_weights", max_length=2000):
        """Create the position-weight parameter and the output schema.

        Args:
            model: Layer model owning this layer; ``model.net`` is used to
                allocate scoped blob names.
            input_record: Record exposing a ``lengths`` field (presumably a
                ``schema.List``); only ``lengths()`` is read, in ``add_ops``.
            weight_optim: Optimizer for the ``pos_w`` parameter; passed to
                ``LayerParameter`` unchanged (``None`` by default).
            name: Layer name; also the prefix of the blobs this layer creates.
            max_length: Number of entries in ``pos_w``. A list with more items
                than this would make ``Gather`` index past the end of
                ``pos_w``. Defaults to 2000, the value previously hard-coded
                here, so existing callers are unaffected.

        Raises:
            ValueError: If ``max_length`` is not positive.
        """
        # Validate before the base constructor runs so a bad argument does
        # not leave a half-registered layer behind.
        if max_length <= 0:
            raise ValueError(
                "max_length must be positive, got {}".format(max_length))
        super(PositionWeighted, self).__init__(model, name, input_record)

        # TODO: Replace the fixed default with a correct estimate once the
        # list cardinality can be computed from run_meta.
        self.shape = max_length

        self.pos_w = model.net.NextScopedBlob(name + "_pos_w")
        self.params.append(
            LayerParameter(
                parameter=self.pos_w,
                # Every position starts with weight 1.0.
                initializer=core.CreateOperator(
                    'ConstantFill',
                    [],
                    self.pos_w,
                    shape=[self.shape, ],
                    value=1.0,
                ),
                optimizer=weight_optim,
            ))

        self.output_schema = schema.Struct(
            ('position_weights',
             schema.Scalar((np.float32, self.shape),
                           model.net.NextScopedBlob(name + "_pos_w_gather")))
        )

        self.tags.update({Tags.HANDLE_AS_SPARSE_LAYER, Tags.GRADIENT_FROM_PS})

    def get_memory_usage(self):
        """Return the number of entries in ``pos_w``.

        This is an element count (``self.shape``); no per-element byte size
        is applied.
        """
        return self.shape

    def add_ops(self, net):
        """Append the ops computing ``position_weights`` to ``net``.

        ``LengthsRangeFill`` expands the list lengths into per-item position
        indices that restart at 0 for each list (lengths ``[2, 3]`` become
        ``[0, 1, 0, 1, 2]``). ``Gather`` then reads ``pos_w`` at those indices
        into the ``position_weights`` output blob.

        Args:
            net: Net the operators are appended to.
        """
        positions = net.LengthsRangeFill(
            [self.input_record.lengths()],
            self.input_record.lengths() + '_pos_w_seq'
        )
        net.Gather(
            [self.pos_w, positions],
            self.output_schema.position_weights.field_blobs())