mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Summary: Segment-based ops require increasing segment ids without gaps; lengths-based ops do not have this requirement. Other pooling methods, e.g. LogMeanExp, do not have lengths-based ops available yet. Differential Revision: D5019165 fbshipit-source-id: ab01a220e10d4ed9fa2162939579d346607f905e
179 lines · 6.8 KiB · Python
## @package sparse_lookup
# Module caffe2.python.layers.sparse_lookup
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, schema
from caffe2.python.layers.layers import (
    get_categorical_limit,
    IdList,
    IdScoreList,
    LayerParameter,
    LayerPsParam,
    ModelLayer,
)
import functools
import math
import numpy as np
import operator


class SparseLookup(ModelLayer):
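    """Embedding lookup over sparse id features.

    Each id in an IdList / IdScoreList input selects a row of a learned
    embedding table of shape [input_dim] + inner_shape; the rows are pooled
    per example with the configured reducer.
    """
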
    _supported_reducers = ['PositionWeighted', 'LogMeanExp', 'LogSumExp', 'Max',
                           'Mean', 'Sum', 'Sqrt']

    def __init__(self, model, input_record, inner_shape, reducer,
                 weight_init=None, weight_optim=None,
                 name='sparse_lookup', **kwargs):
        super(SparseLookup, self).__init__(model, name, input_record, **kwargs)

        if isinstance(inner_shape, int):
            inner_shape = [inner_shape]
        assert isinstance(inner_shape, (list, tuple)), \
            "Unexpected type for inner_shape, expected list or tuple, got {0}".format(
                type(inner_shape))

        # TODO Add some asserts about input type
        assert reducer in self._supported_reducers, \
            "Unsupported reducer: {}".format(reducer)
        self.reducer = reducer

        input_dim = get_categorical_limit(input_record)
        assert input_dim is not None, "Unbounded features are not supported"

        self.output_schema = schema.Scalar(
            (np.float32, inner_shape),
            model.net.NextScopedBlob(name + '_output'),
        )
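
        # Default init: uniform in [-sqrt(1/input_dim), sqrt(1/input_dim)],
        # a standard scaling that keeps initial embeddings small for large
        # vocabularies; an explicit weight_init overrides it.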
        scale = math.sqrt(1.0 / input_dim)
        self.shape = [input_dim] + inner_shape
        self.weight_init = weight_init if weight_init else (
            'UniformFill', {'min': -scale, 'max': scale})

        self.w = model.net.NextScopedBlob(name + "_w")
        if schema.equal_schemas(self.input_record, IdList):
            sparse_key = self.input_record.items()
        elif schema.equal_schemas(self.input_record, IdScoreList):
            sparse_key = self.input_record.keys()
        else:
            raise NotImplementedError()

        if self.input_record.lengths.metadata:
            avg_length = self.input_record.lengths.metadata.expected_value
        else:
            avg_length = None
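
        # Register the embedding table as a layer parameter; sparse_key and
        # average_length describe the access pattern so parameter-server
        # setups can treat the table as a sparsely accessed parameter.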
        self.params.append(
            LayerParameter(
                parameter=self.w,
                initializer=core.CreateOperator(
                    self.weight_init[0],
                    [],
                    self.w,
                    shape=self.shape,
                    **self.weight_init[1]
                ),
                optimizer=weight_optim,
                ps_param=LayerPsParam(
                    sparse_key=sparse_key,
                    average_length=avg_length,
                ),
            ))
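
        # PositionWeighted pooling learns one scalar weight per list position
        # (stored in a table of size input_dim); starting at 1.0 makes the
        # initial pooling an unweighted sum.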
        if reducer == 'PositionWeighted':
            self.pos_w = model.net.NextScopedBlob(name + "_pos_w")
            self.params.append(
                LayerParameter(
                    parameter=self.pos_w,
                    initializer=core.CreateOperator(
                        'ConstantFill',
                        [],
                        self.pos_w,
                        shape=[input_dim, ],
                        value=1.0,
                    ),
                    optimizer=weight_optim,
                ))

    def get_memory_usage(self):
        # Table elements are np.float32, i.e. 4 bytes apiece.
        return functools.reduce(operator.mul, self.shape) * 4

    def get_fp16_compatible_parameters(self):
        # Only the embedding table is marked fp16-compatible; the position
        # weights (if any) stay in fp32.
        return [self.w]
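
    # add_ops emits the lookup + pooling operators; the operator choice
    # depends on both the input schema (IdList vs IdScoreList) and the
    # reducer.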
    def add_ops(self, net):
        if schema.equal_schemas(self.input_record, IdList):
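            # Sum/Mean have fused lengths-based ops (SparseLengthsSum /
            # SparseLengthsMean) that gather and pool in one pass.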
            if self.reducer in ['Sum', 'Mean']:
                net.__getattr__('SparseLengths' + self.reducer)(
                    [
                        self.w,
                        self.input_record.items(),
                        self.input_record.lengths()
                    ],
                    self.output_schema.field_blobs(),
                    engine='fp16'
                )
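            # PositionWeighted: enumerate positions 0..len-1 within each
            # segment, gather the matching position weights, and feed them as
            # per-id scalars to a weighted sum (gradients flow back to the
            # weights via grad_on_weights).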
            elif self.reducer == 'PositionWeighted':
                inc_seq = net.LengthsRangeFill(
                    [self.input_record.lengths()],
                    self.input_record.lengths() + '_seq'
                )
                gather_pos_w = net.Gather(
                    [self.pos_w, inc_seq], self.pos_w + '_gather')

                net.SparseLengthsWeightedSum(
                    [
                        self.w,
                        gather_pos_w,
                        self.input_record.items(),
                        self.input_record.lengths()
                    ],
                    self.output_schema.field_blobs(),
                    grad_on_weights=1,
                    engine='fp16'
                )
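            # Sqrt: LengthsToWeights with power=0.5 assigns every id in a
            # segment the same length-derived weight, normalizing the sum by
            # the square root of the segment length.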
            elif self.reducer == 'Sqrt':
                sqrt_weight = net.LengthsToWeights(
                    [self.input_record.lengths()],
                    [self.input_record.lengths() + '_sqrt'],
                    power=0.5
                )
                net.SparseLengthsWeightedSum(
                    [
                        self.w,
                        sqrt_weight,
                        self.input_record.items(),
                        self.input_record.lengths()
                    ],
                    self.output_schema.field_blobs(),
                    engine='fp16'
                )
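            # Remaining reducers (LogMeanExp, LogSumExp, Max) exist only as
            # segment-based ops, which require increasing, gap-free segment
            # ids (e.g. lengths [2, 3, 1] -> ids [0, 0, 1, 1, 1, 2]);
            # LengthsToSegmentIds produces exactly that, so the gathered rows
            # can be pooled with SortedSegmentRange<Reducer>.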
            else:
                table_rows = net.Gather([self.w, self.input_record.items()])
                segment_ids = net.LengthsToSegmentIds(
                    self.input_record.lengths(),
                    self.input_record.lengths() + '_sid')
                net.__getattr__('SortedSegmentRange' + self.reducer)(
                    [table_rows, segment_ids],
                    self.output_schema.field_blobs(),
                    engine='fp16'
                )
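        # IdScoreList inputs carry an explicit score per id, which maps
        # directly onto the weighted lengths-based ops.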
        elif schema.equal_schemas(self.input_record, IdScoreList):
            if self.reducer in ['Sum', 'Mean']:
                net.__getattr__('SparseLengthsWeighted' + self.reducer)(
                    [
                        self.w,
                        self.input_record.values(),
                        self.input_record.keys(),
                        self.input_record.lengths()
                    ],
                    self.output_schema.field_blobs(),
                    engine='fp16'
                )
            else:
                raise ValueError(
                    "Only Sum and Mean are supported for IdScoreList input. "
                    "Trying to create with {}".format(self.reducer))
        else:
            raise NotImplementedError(
                "Unsupported input type {0}".format(self.input_record))