mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Summary: Title Reviewed By: azzolini Differential Revision: D5197307 fbshipit-source-id: 425bf8e7c5068ea544e5b2709b6bb27eef140bf3
61 lines
2.0 KiB
Python
## @package position_weighted
|
|
# Module caffe2.python.layers.position_weighted
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from caffe2.python import core, schema
|
|
from caffe2.python.layers.layers import (
|
|
LayerParameter,
|
|
ModelLayer,
|
|
)
|
|
|
|
from caffe2.python.layers.tags import Tags
|
|
import numpy as np
|
|
|
|
|
|
class PositionWeighted(ModelLayer):
    """Looks up a learnable weight for every position inside each list.

    ``add_ops`` builds a per-element position sequence from the input
    record's ``lengths()`` blob (via ``LengthsRangeFill``). It then gathers
    entries of a 1-D weight parameter with that sequence. The parameter
    holds ``max_length`` entries, all initialized to 1.0.

    Args:
        model: layer model helper; its ``net`` names the blobs created here.
        input_record: list-like schema record. Only ``lengths()`` is read,
            in ``add_ops``.
        weight_optim: optimizer attached to the weight parameter. It is
            forwarded unchanged to ``LayerParameter``; ``None`` by default.
        name: prefix for this layer and for the blobs it creates.
        max_length: number of position weights to allocate (default 2000).
            Positions at or beyond this value have no weight; presumably the
            ``Gather`` op rejects such indices at run time — TODO confirm.

    Raises:
        ValueError: if ``max_length`` is not positive.
    """

    def __init__(self, model, input_record, weight_optim=None,
                 name="position_weights", max_length=2000):
        super(PositionWeighted, self).__init__(model, name, input_record)

        # TODO: derive the table size from the cardinality computed in
        # run_meta instead of relying on a caller-supplied (or default) guess.
        if max_length <= 0:
            raise ValueError(
                "max_length must be positive, got {}".format(max_length))
        self.shape = max_length

        self.pos_w = model.net.NextScopedBlob(name + "_pos_w")
        self.params.append(
            LayerParameter(
                parameter=self.pos_w,
                # Every position weight starts at 1.0.
                initializer=core.CreateOperator(
                    'ConstantFill',
                    [],
                    self.pos_w,
                    shape=[self.shape],
                    value=1.0,
                ),
                optimizer=weight_optim,
            ))

        # NOTE(review): Gather over the 1-D pos_w blob yields one float per
        # input element, yet each element is declared here with shape
        # (self.shape,). Confirm this declaration is intended.
        self.output_schema = schema.Struct(
            ('position_weights',
             schema.Scalar((np.float32, self.shape),
                           model.net.NextScopedBlob(name + "_pos_w_gather"))),
        )

        self.tags.update({Tags.HANDLE_AS_SPARSE_LAYER, Tags.GRADIENT_FROM_PS})

    def get_memory_usage(self):
        """Return the number of position weights allocated (an entry count,
        not bytes)."""
        return self.shape

    def add_ops(self, net):
        """Append the ops computing ``position_weights`` to ``net``."""
        lengths = self.input_record.lengths()
        # LengthsRangeFill is expected to expand lengths such as [2, 3] into
        # per-list offsets [0, 1, 0, 1, 2] (Caffe2 operator docs — confirm).
        inc_seq = net.LengthsRangeFill([lengths], lengths + '_pos_w_seq')
        # Fetch the weight stored at each offset.
        net.Gather(
            [self.pos_w, inc_seq],
            self.output_schema.position_weights.field_blobs())
|