## @package bucket_weighted
# Module caffe2.python.layers.bucket_weighted
import logging

import numpy as np

from caffe2.python import core, schema
from caffe2.python.layers.layers import (
    get_categorical_limit,
    ModelLayer,
)
from caffe2.python.layers.tags import Tags

logger = logging.getLogger(__name__)


class BucketWeighted(ModelLayer):
    """Learns one scalar weight per bucket for a sparse ID-list feature.

    The size of the weight table is derived, in order of precedence, from
    `bucket_boundaries` (len(bucket_boundaries) + 1 buckets), from
    `max_score` (one bucket per score value), or from the categorical
    limit of the input record. When `hash_buckets` is True, bucket
    indices are hashed modulo the table size before the weight lookup.
    """

    def __init__(self, model, input_record, max_score=0, bucket_boundaries=None,
                 hash_buckets=True, weight_optim=None, name="bucket_weighted"):
        super(BucketWeighted, self).__init__(model, name, input_record)

        assert isinstance(input_record, schema.List), "Incorrect input type"
        self.bucket_boundaries = bucket_boundaries
        self.hash_buckets = hash_buckets
        # Determine the size of the learned weight table.
        if bucket_boundaries is not None:
            self.shape = len(bucket_boundaries) + 1
        elif max_score > 0:
            self.shape = max_score
        else:
            self.shape = get_categorical_limit(input_record)

        # One learnable weight per bucket, initialized to 1.0.
        self.bucket_w = self.create_param(
            param_name='bucket_w',
            shape=[self.shape, ],
            initializer=('ConstantFill', {'value': 1.0}),
            optimizer=weight_optim)

        self.output_schema = schema.Struct(
            ('bucket_weights',
                schema.Scalar((np.float32, self.shape),
                              self.get_next_blob_reference("bucket_w_gather")))
        )

        self.tags.update({Tags.HANDLE_AS_SPARSE_LAYER})

    def get_memory_usage(self):
        return self.shape

    def add_ops(self, net):
        # Map each input value to an integer bucket index: either by
        # bucketizing against the given boundaries, or by treating the
        # (cast) value itself as the index.
        if self.bucket_boundaries is not None:
            buckets_int = net.Bucketize(
                self.input_record.values(), "buckets_int",
                boundaries=self.bucket_boundaries
            )
        else:
            buckets = self.input_record.values()
            buckets_int = net.Cast(
                buckets, "buckets_int", to=core.DataType.INT32
            )
        if self.hash_buckets:
            # Hash indices into [0, self.shape) so out-of-range ids
            # cannot index past the end of the weight table.
            buckets_int = net.IndexHash(
                buckets_int, "hashed_buckets_int", seed=0, modulo=self.shape
            )
        # Look up the learned weight for each bucket index.
        net.Gather(
            [self.bucket_w, buckets_int],
            self.output_schema.bucket_weights.field_blobs())
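

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): one plausible way to
# instantiate this layer through caffe2's LayerModelHelper, shown purely for
# illustration. The feature name 'id_list' and the model name are
# hypothetical, and this assumes the standard layer-registry flow where
# registered ModelLayer subclasses are callable as methods on the model.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    # Imported here rather than at module level to avoid a circular import
    # (layer_model_helper itself imports the layers package).
    from caffe2.python import layer_model_helper

    # An ID-list input feature: a schema.List of int64 ids, matching the
    # `isinstance(input_record, schema.List)` assertion above.
    input_feature_schema = schema.Struct(
        ('id_list', schema.List(schema.Scalar(np.int64))),
    )
    model = layer_model_helper.LayerModelHelper(
        'bucket_weighted_example',  # hypothetical model name
        input_feature_schema,
        schema.Struct(),            # no trainer-extra schema in this sketch
    )

    # With max_score=100 the weight table has 100 entries, and because
    # hash_buckets defaults to True, ids are hashed into [0, 100).
    bucket_weights = model.BucketWeighted(
        model.input_feature_schema.id_list, max_score=100)
    logger.info("BucketWeighted output schema: %s", bucket_weights)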