pytorch/caffe2/python/layers/batch_normalization.py
Silun Wang 28c1258f18 Scale init for batch-norm and layer-norm (#31983)
Summary:
Per discussion with Fei Tian, we need to add a `scale_init_value` to scale down the output of normalization such as batch-norm and layer-norm.

Currently we have `sparse_normalization_options` to normalize embedding pooling output. By default, scale = 1.0; we found it works better to set scale between 0.025 and 0.1 (https://fb.quip.com/MiKUAibEaYhH).
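
As a rough illustration of what a smaller scale init does, here is a minimal NumPy sketch (not the Caffe2 implementation). The helper `batch_norm_forward`, the `eps` value, and the random input are assumptions made for this example only. It shows that the per-feature standard deviation of the normalized output tracks the initial scale value.

```python
import numpy as np


def batch_norm_forward(x, scale, bias, eps=1e-5):
    """Hypothetical helper: normalize each feature column, then scale and shift."""
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mean) / np.sqrt(var + eps)
    return scale * x_hat + bias


rng = np.random.default_rng(0)
# Made-up input with a large mean and spread, standing in for pooled embeddings.
x = rng.normal(loc=3.0, scale=10.0, size=(256, 4)).astype(np.float32)

for scale_init_value in (1.0, 0.05):
    scale = np.full(4, scale_init_value, dtype=np.float32)
    bias = np.zeros(4, dtype=np.float32)
    y = batch_norm_forward(x, scale, bias)
    # The output spread is roughly equal to the initial scale value.
    print(scale_init_value, float(y.std(axis=0).mean()))
```

With scale initialized to 0.05 the normalized output starts out about 20x smaller than with the default of 1.0, which is the effect `scale_init_value` is meant to control at initialization.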

Besides, I am removing the tags from normalizers, because it makes more sense to compute norm ops in the distributed trainers than on ps.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31983

Test Plan:
Testing LN and BN after sum-pooling --
baseline f160348514
LN: f160348609
BN: f160348710

{F226106518}

Layer norm after sum-pooling fwd_net https://fburl.com/sa4j207n
Layer norm after dot-prod fwd_net https://fburl.com/twggwyvb

## Unit Tests
Testing normalization after pooling
```
buck test caffe2/caffe2/fb/dper/layer_models/tests/split_1:sparse_nn_test_4 -- test_sparse_pooling_batch_normalization
buck test caffe2/caffe2/fb/dper/layer_models/tests/split_1:sparse_nn_test_4 -- test_dense_sparse_pooling_batch_normalization
buck test caffe2/caffe2/fb/dper/layer_models/tests/split_1:sparse_nn_test_4 -- test_sparse_pooling_layer_normalization
buck test caffe2/caffe2/fb/dper/layer_models/tests/split_1:sparse_nn_test_4 -- test_dense_sparse_pooling_layer_normalization
```

Testing normalization after dot-prod
```
buck test caffe2/caffe2/fb/dper/layer_models/tests/split_1:sparse_nn_test -- test_last_layer_use_batch_norm
buck test caffe2/caffe2/fb/dper/layer_models/tests/split_1:sparse_nn_test -- test_last_layer_use_layer_norm
```

Differential Revision: D19277618

Pulled By: SilunWang

fbshipit-source-id: ea323e33e3647ba55d2e808ef09d94ad7b45b934
2020-01-10 11:55:56 -08:00

109 lines
3.9 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import schema
from caffe2.python.layers.layers import ModelLayer

import numpy as np


class BatchNormalization(ModelLayer):
    def __init__(
        self,
        model,
        input_record,
        name='batch_normalization',
        scale_optim=None,
        bias_optim=None,
        momentum=0.9,
        order='NCHW',
        scale_init_value=1.0,
        **kwargs
    ):
        super(BatchNormalization, self).__init__(
            model, name, input_record, **kwargs)

        assert isinstance(input_record, schema.Scalar), "Incorrect input type"

        # Shape of a single example (the batch dimension is not included).
        self.input_shape = input_record.field_type().shape

        if len(self.input_shape) == 3:
            # 4D input once batched: pick the channel dimension by layout.
            if order == "NCHW":
                input_dims = self.input_shape[0]
            elif order == "NHWC":
                input_dims = self.input_shape[2]
            else:
                raise ValueError("Please specify a correct order")
        else:
            assert len(self.input_shape) == 1, (
                "This layer supports only 4D or 2D tensors")
            input_dims = self.input_shape[0]

        self.output_schema = schema.Scalar(
            (np.float32, self.input_shape),
            self.get_next_blob_reference('output')
        )

        self.momentum = momentum
        self.order = order

        # Learnable scale, initialized to scale_init_value (1.0 by default).
        self.scale = self.create_param(param_name='scale',
                                       shape=[input_dims],
                                       initializer=('ConstantFill', {'value': scale_init_value}),
                                       optimizer=scale_optim)
        # Learnable bias, initialized to zero.
        self.bias = self.create_param(param_name='bias',
                                      shape=[input_dims],
                                      initializer=('ConstantFill', {'value': 0.0}),
                                      optimizer=bias_optim)
        # Running statistics are not updated by an optimizer.
        self.rm = self.create_param(param_name='running_mean',
                                    shape=[input_dims],
                                    initializer=('ConstantFill', {'value': 0.0}),
                                    optimizer=model.NoOptim)
        self.riv = self.create_param(param_name='running_inv_var',
                                     shape=[input_dims],
                                     initializer=('ConstantFill', {'value': 1.0}),
                                     optimizer=model.NoOptim)

    def _add_ops(self, net, is_test, out_blob=None):
        original_input_blob = self.input_record.field_blobs()
        input_blob = net.NextScopedBlob('expand_input')

        if len(self.input_shape) == 1:
            # 2D input: add two trailing singleton dims so SpatialBN sees 4D.
            input_blob = net.ExpandDims(original_input_blob,
                                        dims=[2, 3])
        else:
            input_blob = original_input_blob[0]

        if out_blob is None:
            bn_output = self.output_schema.field_blobs()
        else:
            bn_output = out_blob

        if is_test:
            output_blobs = bn_output
        else:
            # Training also outputs the running and saved batch statistics.
            output_blobs = bn_output + [self.rm, self.riv,
                                        net.NextScopedBlob('bn_saved_mean'),
                                        net.NextScopedBlob('bn_saved_iv')]

        net.SpatialBN([input_blob, self.scale,
                       self.bias, self.rm, self.riv],
                      output_blobs,
                      momentum=self.momentum,
                      is_test=is_test,
                      order=self.order)

        if len(self.input_shape) == 1:
            # Drop the singleton dims added above to restore the 2D shape.
            net.Squeeze(bn_output,
                        bn_output,
                        dims=[2, 3])

    def add_train_ops(self, net):
        self._add_ops(net, is_test=False)

    def add_eval_ops(self, net):
        self._add_ops(net, is_test=True)

    def add_ops(self, net):
        self.add_eval_ops(net)
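
The 2D path in `_add_ops` expands the input to 4D, runs the spatial batch-norm op, and squeezes the result back. The NumPy sketch below mirrors only that shape handling. It is not the Caffe2 operator: `per_channel_normalize`, the `eps` value, and the use of batch mean and variance are assumptions made for illustration.

```python
import numpy as np


def per_channel_normalize(x_nchw, scale, bias, mean, var, eps=1e-5):
    """Hypothetical stand-in: normalize an (N, C, H, W) array per channel."""
    s = (1, -1, 1, 1)  # broadcast per-channel vectors over N, H and W
    x_hat = (x_nchw - mean.reshape(s)) / np.sqrt(var.reshape(s) + eps)
    return scale.reshape(s) * x_hat + bias.reshape(s)


x = np.arange(6, dtype=np.float32).reshape(3, 2)  # 2D input: (N=3, C=2)
x4d = x[:, :, None, None]                         # analogous to ExpandDims(dims=[2, 3])
scale = np.full(2, 0.1, dtype=np.float32)         # e.g. scale_init_value=0.1
bias = np.zeros(2, dtype=np.float32)
y4d = per_channel_normalize(x4d, scale, bias, x.mean(axis=0), x.var(axis=0))
y = np.squeeze(y4d, axis=(2, 3))                  # analogous to Squeeze(dims=[2, 3])
print(x4d.shape, y4d.shape, y.shape)
```

The round trip leaves the output with the same 2D shape as the input, which is why the layer can reuse `self.input_shape` for its output schema.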