pytorch/caffe2/python/layers/batch_softmax_loss.py
Yangxin Zhong ed788ec780 Linearizable Label: Class Weights, Allow Missing Label, and Average by Batch Size (#29707)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29707

In D17885977, Linearizable label (a multi-class classification formulation) was implemented in MTML.

In this diff, we add several items for Linearizable label:

- We support assigning different weights to each class through ```model_def.tasks[i].class_weights```.

  - This option is a dictionary whose keys are the class indices and whose values are the weights for those classes.

  - For example, if a linearizable-label task has 4 classes and its ```class_weights = {"0": 1, "1": 0.1, "2": 0.1, "3": 0.01}```, then in the loss function of this task we assign weight 1 to its first class, weight 0.1 to its second and third classes, and weight 0.01 to its fourth class. The index/order of classes follows the logic of linearizable label.

  - Note that when you assign different weights to different classes, you need to correct the calibration by setting an appropriate ```model_def.tasks[i].calibration.linearizable_class_weight```. The class weights in calibration should be the reciprocals of the class weights in the loss function, so for the example above ```calibration.linearizable_class_weight = {"0": 1, "1": 10, "2": 10, "3": 100}``` (a quick sketch of this reciprocal relationship follows the example job below).

  - Example FBLearner job: f150763093
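
  - As a concrete illustration of the reciprocal relationship (a minimal plain-Python sketch; the helper name is made up and is not part of the dper code base):

    ```python
    # Hypothetical helper: derive the calibration weights as reciprocals of the
    # loss-function class weights, keyed by the same string class indices.
    def calibration_weights_from_class_weights(class_weights):
        return {idx: 1.0 / w for idx, w in class_weights.items()}

    class_weights = {"0": 1, "1": 0.1, "2": 0.1, "3": 0.01}
    print(calibration_weights_from_class_weights(class_weights))
    # {'0': 1.0, '1': 10.0, '2': 10.0, '3': 100.0}
    ```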

- We also support ```model_def.allow_missing_label_with_zero_weight``` for linearizable label, which ignores examples whose first label is missing by assigning them zero weight in the loss function (see the sketch after the example job below).

  - We need to set ```allow_missing_label_with_zero_weight = true``` to enable it.

  - Example FBLearner job: f150763093
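
  - Conceptually (a tiny numpy sketch with made-up per-example weights and a made-up missing-label mask; not the actual dper implementation), the flag amounts to zeroing the weights of the affected examples before they reach the loss:

    ```python
    import numpy as np

    # Made-up per-example weights and a mask marking examples whose first
    # label is missing.
    example_weights = np.array([1.0, 0.5, 2.0, 1.0], dtype=np.float32)
    first_label_missing = np.array([False, True, False, False])

    # allow_missing_label_with_zero_weight = true: such examples contribute
    # nothing to the loss because their weight becomes zero.
    effective_weights = np.where(first_label_missing, np.float32(0.0), example_weights)
    print(effective_weights)  # [1. 0. 2. 1.]
    ```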

- Last but not least, we update the caffe2 operator ```SoftmaxWithLoss``` to support averaging the loss by batch size.

  - We need to set ```model_def.tasks[i].loss.softmaxLoss.average_by_batch_size = true``` to enable it.

  - Previously, the loss was averaged by the weight sum of the examples in the batch; this remains the default behavior (when ```average_by_batch_size = null``` or ```average_by_batch_size = false```). The two normalizations are contrasted in the sketch at the end of this item.

  - Without this new feature, calibration will be incorrect when training a linearizable task with non-equal class weights.

  - Example FBLearner job with ```average_by_batch_size = true```, which produces a correct calibration: f150763093

  - Example FBLearner job with ```average_by_batch_size = null```, which produces an incorrect calibration: f150762990
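
  - To make the two normalizations concrete, a small numpy sketch with made-up per-example losses and weights:

    ```python
    import numpy as np

    # Made-up per-example cross-entropy losses and weights for a batch of 4.
    per_example_loss = np.array([0.2, 1.5, 0.7, 0.4])
    per_example_weight = np.array([1.0, 0.1, 0.1, 0.01])

    weighted_sum = np.sum(per_example_weight * per_example_loss)

    # Default (average_by_batch_size = null/false): normalize by the weight sum.
    loss_by_weight_sum = weighted_sum / np.sum(per_example_weight)  # ~0.3504

    # New option (average_by_batch_size = true): normalize by the batch size.
    loss_by_batch_size = weighted_sum / len(per_example_loss)       # ~0.106
    ```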

Test Plan:
buck test caffe2/caffe2/fb/dper/layer_models/tests:mtml_test_2 -- test_linearizable_label_task_with_class_weights
buck test caffe2/caffe2/fb/dper/layer_models/tests:mtml_test_2 -- test_linearizable_label_task_with_zero_weight
buck test caffe2/caffe2/fb/dper/layer_models/tests:mtml_test_2 -- test_linearizable_label_task_average_by_batch_size

All tests passed.

full canary: https://fburl.com/fblearner/troznfgh

Reviewed By: chenshouyuan

Differential Revision: D18461163

fbshipit-source-id: aaf3df031406ae94f74e2e365b57e47409ef0bfe
2019-11-13 16:52:27 -08:00


## @package batch_softmax_loss
# Module caffe2.python.layers.batch_softmax_loss
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core, schema
from caffe2.python.layers.layers import ModelLayer
import numpy as np


class BatchSoftmaxLoss(ModelLayer):
    def __init__(
        self,
        model,
        input_record,
        name='batch_softmax_loss',
        label_smoothing_matrix=None,
        label_prob=False,
        scale=1.0,
        average_by_batch_size=False,
        **kwargs
    ):
        super(BatchSoftmaxLoss, self).__init__(
            model, name, input_record, **kwargs)

        assert schema.is_schema_subset(
            schema.Struct(
                ('label', schema.Scalar()),
                ('prediction', schema.Scalar()),
            ),
            input_record
        )
        self.label_prob = label_prob
        self.scale = scale
        self.average_by_batch_size = average_by_batch_size

        # label smoothing matrix: a K * K matrix where K is the label
        # cardinality; (i, j) element is the value for label i
        # treated/smoothed as label j
        self.label_smoothing_matrix = label_smoothing_matrix
        if self.label_smoothing_matrix is not None:
            self.initialize_label_smoothing_constants()

        self.output_schema = schema.Struct(
            (
                'softmax', schema.Scalar(
                    input_record.prediction.field_type(),
                    self.get_next_blob_reference('softmax')
                )
            ),
            (
                'loss', schema.Scalar(
                    np.float32, self.get_next_blob_reference('loss')
                )
            ),
        )

    def initialize_label_smoothing_constants(self):
        assert self.label_smoothing_matrix is not None
        self.label_smoothing_matrix = np.array(
            self.label_smoothing_matrix).astype(np.float32)
        assert len(self.label_smoothing_matrix.shape) == 2
        label_dim = self.label_smoothing_matrix.shape[0]
        assert label_dim == self.label_smoothing_matrix.shape[1]

        self.label_smoothing_matrix = self.model.add_global_constant(
            '%s_label_smoothing_matrix' % self.name,
            array=self.label_smoothing_matrix,
            dtype=np.dtype(np.float32),
        )
        self.label_dim = self.model.add_global_constant(
            '%s_label_dim' % self.name,
            array=label_dim,
            dtype=np.dtype(np.int64),
        )
        # default case: label is given NOT as target distribution
        # but when used in label smoothing, the label must be in probabilities
        self.label_prob = True

    def compute_smoothed_label(self, net):
        assert self.label_smoothing_matrix is not None
        label = self.input_record.label()
        original_label_type = self.input_record.label.field_type()
        if original_label_type.base != np.int64:
            int64_label = net.NextScopedBlob('int64_label')
            net.Cast([label], [int64_label], to=core.DataType.INT64)
        else:
            int64_label = label
        one_hot_label = net.NextScopedBlob('one_hot_label')
        smoothed_label = net.NextScopedBlob('smoothed_label')
        net.OneHot([int64_label, self.label_dim], [one_hot_label])
        net.MatMul([one_hot_label, self.label_smoothing_matrix], smoothed_label)
        return smoothed_label

    def add_ops(self, net):
        label = self.input_record.label.field_blobs()
        if self.label_smoothing_matrix is not None:
            label = [self.compute_smoothed_label(net)]
        elif not self.label_prob:
            if self.input_record.label.field_types()[0].base != np.int32:
                label = [
                    net.Cast(label,
                             net.NextScopedBlob('int32_label'),
                             to=core.DataType.INT32)
                ]

        softmax_input = self.input_record.prediction.field_blobs() + label

        if 'weight' in self.input_record:
            weight_blob = self.input_record.weight()
            if self.input_record.weight.field_type().base != np.float32:
                weight_blob = net.Cast(
                    weight_blob,
                    weight_blob + '_float32',
                    to=core.DataType.FLOAT
                )
            softmax_input += [weight_blob]

        net.SoftmaxWithLoss(
            softmax_input,
            self.output_schema.field_blobs(),
            label_prob=self.label_prob,
            scale=self.scale,
            average_by_batch_size=self.average_by_batch_size,
        )
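

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original module). It
# assumes a caffe2 LayerModelHelper-style `model` whose input record provides
# an int32 `label`, float `prediction` logits, and an optional per-example
# `weight`:
#
#     input_record = schema.NewRecord(
#         model.net,
#         schema.Struct(
#             ('label', schema.Scalar((np.int32, tuple()))),
#             ('prediction', schema.Scalar((np.float32, (4,)))),  # 4 classes
#             ('weight', schema.Scalar((np.float32, tuple()))),
#         ),
#     )
#     loss_record = model.BatchSoftmaxLoss(
#         input_record,
#         average_by_batch_size=True,  # flag added in this diff
#     )
#     # loss_record.softmax: per-example class probabilities
#     # loss_record.loss: scalar loss, here averaged by batch size
# ---------------------------------------------------------------------------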