# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################

# @package label_smooth
# Module caffe2.python.layers.label_smooth

from caffe2.python import core, schema
from caffe2.python.layers.layers import ModelLayer

import numpy as np


class LabelSmooth(ModelLayer):
    """Layer that maps labels to smoothed label values.

    The mode is chosen from the number of elements in ``smooth_matrix``:

    * 2 elements -- binary mode. The label is cast to float32 (if needed) and
      fed to ``StumpFunc`` with ``threshold=0.5``, ``low_value=smooth_matrix[0]``
      and ``high_value=smooth_matrix[1]``.
    * N*N elements (N >= 1) -- categorical mode. The matrix is reshaped to
      (N, N) and stored as a global constant. The label is cast to int64 (if
      needed), one-hot encoded to width N with ``OneHot``, and multiplied by
      the matrix with ``MatMul``. Each output row is therefore the matrix row
      selected by that label.

    Args:
        model: model passed through to ``ModelLayer``.
        label: schema field holding the label; ``label()`` is its blob and
            ``label.field_type().base`` its dtype.
        smooth_matrix: array-like of numbers; flattened and cast to float32.
        name: layer name, also used to prefix the global constants this
            layer registers.
        **kwargs: forwarded to ``ModelLayer``.

    Raises:
        ValueError: if ``smooth_matrix`` has neither 2 elements nor a
            positive perfect-square number of elements.
    """

    def __init__(
        self, model, label, smooth_matrix, name='label_smooth', **kwargs
    ):
        super(LabelSmooth, self).__init__(model, name, label, **kwargs)
        self.label = label
        # Work on a flat float32 copy; set_dim() infers the mode from its size.
        smooth_matrix = np.array(smooth_matrix).astype(np.float32).flatten()
        self.set_dim(smooth_matrix)
        self.set_smooth_matrix(smooth_matrix)
        # NOTE(review): in binary mode dim == 1, so the declared output is
        # (batch, 1) while StumpFunc presumably emits the label's own shape
        # -- confirm the two agree for callers.
        self.output_schema = schema.Scalar(
            (np.float32, (self.dim, )),
            self.get_next_blob_reference('smoothed_label')
        )

    def set_dim(self, smooth_matrix):
        """Infer the mode and output width from the size of ``smooth_matrix``.

        Sets ``self.binary_prob_label`` (True iff there are exactly 2
        elements) and ``self.dim``: 1 in binary mode, otherwise N for an N*N
        matrix.

        Raises:
            ValueError: if the size is neither 2 nor a positive perfect square.
        """
        num_elements = smooth_matrix.size
        self.binary_prob_label = (num_elements == 2)
        if self.binary_prob_label:
            self.dim = 1
            return
        # Exact integer check. The previous `np.sqrt(n) ** 2 == n` assert
        # depended on float rounding, accepted an empty matrix (dim 0), and
        # disappeared entirely under `python -O`.
        dim = int(round(np.sqrt(num_elements)))
        if num_elements == 0 or dim * dim != num_elements:
            raise ValueError(
                'smooth_matrix must have 2 elements (binary) or N*N elements '
                'with N >= 1 (categorical); got %d' % num_elements
            )
        self.dim = dim

    def set_smooth_matrix(self, smooth_matrix):
        """Store the smoothing values in the form add_ops() consumes.

        Categorical mode registers two global constants on the model: the
        (dim, dim) float32 matrix and the int64 value ``dim`` (the OneHot
        width). Binary mode keeps the flat 2-element numpy array as-is.
        """
        if not self.binary_prob_label:
            self.smooth_matrix = self.model.add_global_constant(
                '%s_label_smooth_matrix' % self.name,
                array=smooth_matrix.reshape((self.dim, self.dim)),
                dtype=np.dtype(np.float32),
            )
            self.len = self.model.add_global_constant(
                '%s_label_dim' % self.name,
                array=self.dim,
                dtype=np.dtype(np.int64),
            )
        else:
            self.smooth_matrix = smooth_matrix

    def add_ops_for_binary_prob_label(self, net):
        """Emit ops mapping a (float) label to one of two smoothed values."""
        if self.label.field_type().base != np.float32:
            float32_label = net.NextScopedBlob('float32_label')
            net.Cast([self.label()], [float32_label], to=core.DataType.FLOAT)
        else:
            float32_label = self.label()
        net.StumpFunc(
            float32_label,
            self.output_schema(),
            threshold=0.5,
            # Pass plain Python floats so operator-argument serialization
            # does not depend on numpy-scalar handling; the value is the same.
            low_value=float(self.smooth_matrix[0]),
            high_value=float(self.smooth_matrix[1]),
        )

    def add_ops_for_categorical_label(self, net):
        """Emit ops selecting the smoothing-matrix row for each class label."""
        if self.label.field_type().base != np.int64:
            int64_label = net.NextScopedBlob('int64_label')
            net.Cast([self.label()], [int64_label], to=core.DataType.INT64)
        else:
            int64_label = self.label()
        one_hot_label = net.NextScopedBlob('one_hot_label')
        # OneHot width comes from the int64 `dim` constant registered in
        # set_smooth_matrix().
        net.OneHot([int64_label, self.len], [one_hot_label])
        net.MatMul([one_hot_label, self.smooth_matrix], self.output_schema())

    def add_ops(self, net):
        """Dispatch to the binary or categorical op builder."""
        if self.binary_prob_label:
            self.add_ops_for_binary_prob_label(net)
        else:
            self.add_ops_for_categorical_label(net)