pytorch/caffe2/python/helpers/normalization.py
Yinghai Lu ef8f556212
[Caffe2] Changes done inside Facebook (#6378)
* fix unit test for sqrt op

From the error logging:

[idx, grad, grad_estimate] are:
[[ 146.            0.5           0.45776367]
 [ 147.            0.5           0.45776367]

The gradient == 0.5 is correct, which means the SqrtOp and its gradient are doing the right job. (Because y = sqrt(x) and loss = y^2/2 = x/2, we get d(loss)/dx = 1/2 = 0.5.)

The test failed because of a numerical problem in grad_estimate (in the unit test). This is likely because the step_size is small and float precision is limited (when there are multiple elements in the tensor, we compute the loss as sum(y^2)).

This diff:
- increases the step size, and also moves the test cases further away from 0 (where the gradient of sqrt(x) is not well defined), to be safe :)
- also cleans up and merges the test cases for in-place vs. non-in-place

Tested with:

`CAFFE2_HYPOTHESIS_PROFILE=debug ai_bt caffe2/caffe2/python/operator_test:elementwise_ops_test -- "test_sqrt"`
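
For intuition, here is a minimal NumPy sketch (not the actual Hypothesis test) of the central-difference estimate that the gradient checker performs; with float32 sums and a small step size the estimate drifts away from the analytic 0.5, which is exactly what the larger step size and test points away from 0 address:

```python
import numpy as np

def sqrt_grad_estimate(x, step_size):
    # loss = sum(sqrt(x)**2) / 2 = sum(x) / 2, so the analytic gradient is 0.5
    loss = lambda v: np.sum(np.sqrt(v) ** 2) / 2.0
    est = np.empty_like(x)
    for i in range(x.size):
        xp, xm = x.copy(), x.copy()
        xp.flat[i] += step_size
        xm.flat[i] -= step_size
        est.flat[i] = (loss(xp) - loss(xm)) / (2.0 * step_size)
    return est

x = np.random.uniform(1.0, 2.0, size=200).astype(np.float32)
print(sqrt_grad_estimate(x, step_size=1e-2))  # should be close to 0.5 everywhere
```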

* CompositeReader & CompositeReaderBuilder

A new type of reader that glues multiple readers together.

* Back out "Revert D7394363: [GanH]: Log D Trick for Cross Entropy with Sigmoid"

Original commit changeset: 9325a4356dbe

* [dai][WIP] convert params to int8 on ps before sending to trainer

Add float->uint8 conversion in addition to float->fp16 conversion in model_saver.
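
For illustration only, a common affine float->uint8 scheme sketched in NumPy; the exact scale/offset layout used by model_saver is not shown here, so treat the details below as an assumption:

```python
import numpy as np

def float_to_uint8(w):
    # Map [w.min(), w.max()] onto [0, 255]; keep scale/offset for dequantization.
    w_min, w_max = float(w.min()), float(w.max())
    scale = (w_max - w_min) / 255.0 or 1.0
    q = np.round((w - w_min) / scale).astype(np.uint8)
    return q, scale, w_min  # dequantize as q * scale + w_min
```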

* [easy] improve unit test for sparse length sum ops

as desc.

#accept2ship

* Update GitHub upstream to 771fcb3455

* move sparse hash unique ops to OSS and add unit tests

- move the SparseHash version to OSS, since 'sparsehash' is already a dependency of Caffe2 OSS: https://fburl.com/arssw4n1
- the 'SparseHash' engine is also used in OSS, so the SparseHash version should live in OSS to reduce confusion: https://fburl.com/o5ea7ah2

- fix the CUDA UniqueOp for the case when batch is empty.
- add unit test

* group_norm_op for caffe2

This is the cuda op for Group Normalization (GN): https://arxiv.org/abs/1803.08494

This code implements GN in one op that computes Y = gamma * (X - mu) / sigma + beta, together with its gradients. It is expected to have minimal memory consumption (similar to the BN op), avoiding the extra blobs that would be created if GN were implemented as several ops (e.g., reshape, norm_mean/std, affine_channel).
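
For reference, a NumPy sketch of the forward computation the op fuses (the helper name and NCHW layout below are for illustration; the real implementation is a single CUDA op):

```python
import numpy as np

def group_norm_ref(X, gamma, beta, num_groups, eps=1e-5):
    # X: (N, C, H, W); gamma, beta: (C,); C must be divisible by num_groups
    N, C, H, W = X.shape
    G = num_groups
    Xg = X.reshape(N, G, (C // G) * H * W)
    mu = Xg.mean(axis=2, keepdims=True)
    sigma = np.sqrt(Xg.var(axis=2, keepdims=True) + eps)
    Y = ((Xg - mu) / sigma).reshape(N, C, H, W)
    return Y * gamma.reshape(1, C, 1, 1) + beta.reshape(1, C, 1, 1)
```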

* Resubmit D7405233: disappeared in D7464958

The OSS publish caused the op to go missing -- however, the test was still there.

* [c2] add sparse hash engine for cuda unique op

The SparseHash version of UniqueOp copies the input tensor to the CPU, uses a sparse hash map to compute the unique output, and then copies the result back to the GPU.
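
Conceptually (a plain-Python sketch, not the actual engine code), the hash-map pass computes the equivalent of:

```python
def unique_with_remapping(values):
    # First-occurrence dedup via a hash map -- the same idea the SparseHash
    # engine applies on the CPU before copying the result back to the GPU.
    index_of, unique, remapping = {}, [], []
    for v in values:
        if v not in index_of:
            index_of[v] = len(unique)
            unique.append(v)
        remapping.append(index_of[v])
    return unique, remapping

print(unique_with_remapping([3, 1, 3, 2, 1]))  # ([3, 1, 2], [0, 1, 0, 2, 1])
```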

* [dper][gpu] enable unit testing gpu trainer for sparse nn

Debug the GPU trainer using mock data in a unit test.

This makes it easier to develop the GPU trainer for new models.

* Reuse Gloo context for Synchronize() calls

Previously we were creating (and leaking) the Gloo context on each call to Synchronize(). Now we only run the common world op and create the barrier net once, then run the barrier net on each Synchronize() call. Since the timeout is associated with the Gloo context, assert that the timeout is fixed instead of trying to handle the complexity of multiple timeouts (and associated contexts).
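
The shape of the change, sketched with hypothetical names (the real code builds the Gloo common world op and barrier net through Caffe2 nets):

```python
class Synchronizer(object):
    """Sketch: build the synchronization context once, then reuse it."""

    def __init__(self, timeout, build_fn, run_fn):
        self._timeout = timeout
        self._build_fn = build_fn   # creates common world + barrier net
        self._run_fn = run_fn       # runs the already-created barrier net
        self._context = None

    def synchronize(self, timeout):
        # The timeout lives in the cached context, so it must stay fixed.
        assert timeout == self._timeout, "Synchronize() timeout must not change"
        if self._context is None:
            self._context = self._build_fn(self._timeout)
        self._run_fn(self._context)
```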

* [GanH/WGAN][1/n]: add FC param clipping

as titled

* [mobile] minimizing changes between caffe2_benchmark and speed_benchmark

* [GanH]: enable diagnose within model

Avoid having to find blob names; instead, enable diagnosis directly inside the model.

* Add `net_transformer_fun` option to DPM

This callback allows for various transformations to be made to the
model after gradient operators have been added. The immediate motivation for
this is to allow transformations such as "checkpoint-and-recompute" which
allow trading off memory for additional compute.

Adding several callbacks like this has made DPM's API less than ideal at this
stage. However, I could not find any reasonable alternative.

* [DT] [33/n] Compile flow task groups

Task groups need to be compiled in order to pickle the object in FBLearner. I also changed the Job's compile function, since creating a new object is not necessary.

* Initial commit for sparse_normalize vectorization and benchmark

* [GanH]: LB Calibration for JSD

as titled

* Tracing event in async executor

Adding event tracing through the TRACE_EVENT macro in the async executor.

* [Resubmit] D7409751 Resetting book-keeping blobs when the reservoir is reset

D7409751 got lost in D7464958

* Visualizing realtime weight values

We want to visualize the weight values as the optimizer iterates. This diff supports visualizing the weights at a specified index.
Currently, we assume the blob is 2-dimensional.

* [GanH][Easy]: Fix Homotopy Weighting

Apparently, there was a bug in the homotopy weight (alpha, beta) update.

* [c2] move sparse hash unique op out of OSS

so that OSS does not need to depend on the Google sparse hash map.

* Get rid of std::round as it's not supported on Android

* Revert changes on setup.py

* Skip flaky test in Dataio

* fix

## @package normalization
# Module caffe2.python.helpers.normalization
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import scope
from caffe2.python.modeling.parameter_info import ParameterTags
from caffe2.proto import caffe2_pb2
from caffe2.python.modeling import initializers


def lrn(model, blob_in, blob_out, order="NCHW", use_cudnn=False, **kwargs):
    """LRN"""
    dev = kwargs['device_option'] if 'device_option' in kwargs \
        else scope.CurrentDeviceScope()
    is_cpu = dev is None or dev.device_type == caffe2_pb2.CPU
    if use_cudnn and (not is_cpu):
        kwargs['engine'] = 'CUDNN'
        blobs_out = blob_out
    else:
        blobs_out = [blob_out, "_" + blob_out + "_scale"]
    lrn = model.net.LRN(
        blob_in,
        blobs_out,
        order=order,
        **kwargs
    )
    if use_cudnn and (not is_cpu):
        return lrn
    else:
        return lrn[0]

def softmax(model, blob_in, blob_out=None, use_cudnn=False, **kwargs):
    """Softmax."""
    if use_cudnn:
        kwargs['engine'] = 'CUDNN'
    if blob_out is not None:
        return model.net.Softmax(blob_in, blob_out, **kwargs)
    else:
        return model.net.Softmax(blob_in, **kwargs)

def instance_norm(model, blob_in, blob_out, dim_in, order="NCHW", **kwargs):
    blob_out = blob_out or model.net.NextName()
    # Input: input, scale, bias
    # Output: output, saved_mean, saved_inv_std
    # scale: initialize with ones
    # bias: initialize with zeros

    def init_blob(value, suffix):
        return model.param_init_net.ConstantFill(
            [], blob_out + "_" + suffix, shape=[dim_in], value=value)
    scale, bias = init_blob(1.0, "s"), init_blob(0.0, "b")

    model.AddParameter(scale, ParameterTags.WEIGHT)
    model.AddParameter(bias, ParameterTags.BIAS)
    blob_outs = [blob_out, blob_out + "_sm", blob_out + "_siv"]
    if 'is_test' in kwargs and kwargs['is_test']:
        blob_outputs = model.net.InstanceNorm(
            [blob_in, scale, bias], [blob_out],
            order=order, **kwargs)
        return blob_outputs
    else:
        blob_outputs = model.net.InstanceNorm(
            [blob_in, scale, bias], blob_outs,
            order=order, **kwargs)
        # Return the output
        return blob_outputs[0]

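# Usage sketch (illustrative only, not part of the original module): with a
# ModelHelper-style `model` and a 4D NCHW input blob "data" that has 32
# channels, the helper above would typically be called as
#
#     out = instance_norm(model, "data", "in1", dim_in=32, is_test=False)
#
# which creates the "in1_s"/"in1_b" parameters and returns the normalized blob.

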
def spatial_bn(model, blob_in, blob_out, dim_in,
               init_scale=1., init_bias=0.,
               ScaleInitializer=None, BiasInitializer=None,
               RunningMeanInitializer=None, RunningVarianceInitializer=None,
               order="NCHW", **kwargs):
    blob_out = blob_out or model.net.NextName()
    # Input: input, scale, bias, est_mean, est_inv_var
    # Output: output, running_mean, running_inv_var, saved_mean,
    #         saved_inv_var
    # scale: initialize with init_scale (default 1.)
    # bias: initialize with init_bias (default 0.)
    # est mean: zero
    # est var: ones

    if model.init_params:
        scale_init = ("ConstantFill", {'value': init_scale})
        bias_init = ("ConstantFill", {'value': init_bias})
        rm_init = ("ConstantFill", {'value': 0.0})
        riv_init = ("ConstantFill", {'value': 1.0})

        ScaleInitializer = initializers.update_initializer(
            ScaleInitializer, scale_init, ("ConstantFill", {})
        )
        BiasInitializer = initializers.update_initializer(
            BiasInitializer, bias_init, ("ConstantFill", {})
        )
        RunningMeanInitializer = initializers.update_initializer(
            RunningMeanInitializer, rm_init, ("ConstantFill", {})
        )
        RunningVarianceInitializer = initializers.update_initializer(
            RunningVarianceInitializer, riv_init, ("ConstantFill", {})
        )
    else:
        ScaleInitializer = initializers.ExternalInitializer()
        BiasInitializer = initializers.ExternalInitializer()
        RunningMeanInitializer = initializers.ExternalInitializer()
        RunningVarianceInitializer = initializers.ExternalInitializer()

    scale = model.create_param(
        param_name=blob_out + '_s',
        shape=[dim_in],
        initializer=ScaleInitializer,
        tags=ParameterTags.WEIGHT
    )

    bias = model.create_param(
        param_name=blob_out + '_b',
        shape=[dim_in],
        initializer=BiasInitializer,
        tags=ParameterTags.BIAS
    )

    running_mean = model.create_param(
        param_name=blob_out + '_rm',
        shape=[dim_in],
        initializer=RunningMeanInitializer,
        tags=ParameterTags.COMPUTED_PARAM
    )

    running_inv_var = model.create_param(
        param_name=blob_out + '_riv',
        shape=[dim_in],
        initializer=RunningVarianceInitializer,
        tags=ParameterTags.COMPUTED_PARAM
    )

    blob_outs = [blob_out, running_mean, running_inv_var,
                 blob_out + "_sm", blob_out + "_siv"]
    if 'is_test' in kwargs and kwargs['is_test']:
        blob_outputs = model.net.SpatialBN(
            [blob_in, scale, bias, blob_outs[1], blob_outs[2]], [blob_out],
            order=order, **kwargs)
        return blob_outputs
    else:
        blob_outputs = model.net.SpatialBN(
            [blob_in, scale, bias, blob_outs[1], blob_outs[2]], blob_outs,
            order=order, **kwargs)
        # Return the output
        return blob_outputs[0]

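# Usage sketch (illustrative only): for a conv feature map "conv1" with 64
# channels, spatial_bn would typically be invoked as
#
#     bn1 = spatial_bn(model, "conv1", "bn1", dim_in=64, epsilon=1e-5,
#                      is_test=False)
#
# creating the learnable "bn1_s"/"bn1_b" and the computed running statistics
# "bn1_rm"/"bn1_riv"; with is_test=True only the output blob is produced.

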
def spatial_gn(model, blob_in, blob_out, dim_in,
               init_scale=1., init_bias=0.,
               ScaleInitializer=None, BiasInitializer=None,
               RunningMeanInitializer=None, RunningVarianceInitializer=None,
               order="NCHW", **kwargs):
    '''
    Group normalizes the input, cf. https://arxiv.org/abs/1803.08494.
    '''
    blob_out = blob_out or model.net.NextName()
    # Input: input, scale, bias
    # Output: output, group_mean, group_std
    # scale: initialize with init_scale (default 1.)
    #     [recommendation: set init_scale = 0. in the last layer for each res block]
    # bias: initialize with init_bias (default 0.)

    if model.init_params:
        scale_init = ("ConstantFill", {'value': init_scale})
        bias_init = ("ConstantFill", {'value': init_bias})

        ScaleInitializer = initializers.update_initializer(
            ScaleInitializer, scale_init, ("ConstantFill", {})
        )
        BiasInitializer = initializers.update_initializer(
            BiasInitializer, bias_init, ("ConstantFill", {})
        )
    else:
        ScaleInitializer = initializers.ExternalInitializer()
        BiasInitializer = initializers.ExternalInitializer()

    scale = model.create_param(
        param_name=blob_out + '_s',
        shape=[dim_in],
        initializer=ScaleInitializer,
        tags=ParameterTags.WEIGHT
    )

    bias = model.create_param(
        param_name=blob_out + '_b',
        shape=[dim_in],
        initializer=BiasInitializer,
        tags=ParameterTags.BIAS
    )

    blob_outs = [blob_out,
                 blob_out + "_mean", blob_out + "_std"]

    blob_outputs = model.net.GroupNorm(
        [blob_in, scale, bias],
        blob_outs,
        **kwargs)
    # Return the output
    return blob_outputs[0]

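# Usage sketch (illustrative only): the group count is forwarded through
# **kwargs to the GroupNorm op (the argument name `group` below is an
# assumption), e.g.
#
#     gn1 = spatial_gn(model, "conv1", "gn1", dim_in=64, group=32)
#
# Per the note above, init_scale=0. is recommended for the last layer of each
# residual block.

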
def layer_norm(
    model,
    blob_in,
    blob_out,
    dim_in,
    axis=1,
    epsilon=1e-4,
    initial_scale=1.0,
    initial_bias=0.0,
):
    '''
    Layer normalizes the input, cf. https://arxiv.org/pdf/1607.06450.pdf.

    Args:
        blob_in: The input blob to layer normalize.
        blob_out: The layer normalized output blob.
        dim_in: The dimension of the scale and bias. For example, if blob_in is
            a 2D design matrix and axis is 1, this would be the number of
            columns.
        axis: (optional) The axis to normalize. Typically the feature axis.
            Defaults to 1.
        epsilon: (optional) A small value used for numerical stability in
            calculation. Defaults to 1e-4.
        initial_scale: (optional) The initial value for the learned scale
            parameter. Defaults to 1.0.
        initial_bias: (optional) The initial value for the learned bias
            parameter of the layerwise standard deviation. Defaults to 0.0.

    Returns:
        A 3-tuple consisting of:
            - The layer normalized input blob.
            - The mean of the input blob across the given axis.
            - The standard deviation of the input blob across the given axis.
    '''

    # The LayerNorm operator only performs the layerwise z-shift, without
    # scaling and shifting by the learned scale and bias parameters. We have
    # to do that separately below.
    normalized, mean, stdev = model.net.LayerNorm(
        [blob_in],
        [blob_out, blob_out + "_mean", blob_out + "_stdev"],
        axis=axis,
        epsilon=epsilon,
    )

    # The learned multiplicative scale or "gain".
    scale = model.create_param(
        param_name='{}_scale'.format(blob_out),
        shape=[dim_in],
        initializer=initializers.Initializer(
            'ConstantFill',
            value=initial_scale,
        ),
        tags=ParameterTags.WEIGHT,
    )

    # The learned additive bias or "shift".
    bias = model.create_param(
        param_name='{}_bias'.format(blob_out),
        shape=[dim_in],
        initializer=initializers.Initializer(
            'ConstantFill',
            value=initial_bias,
        ),
        tags=ParameterTags.BIAS,
    )

    scaled = model.net.Mul(
        [normalized, scale],
        ['{}_scaled'.format(blob_out)],
        broadcast=1,
        axis=axis,
    )

    biased = model.net.Add(
        [scaled, bias],
        ['{}_biased'.format(blob_out)],
        broadcast=1,
        axis=axis,
    )

    return biased, mean, stdev
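

# Usage sketch (illustrative only): for a 2D blob "fc1" with 256 columns,
#
#     out, mean, stdev = layer_norm(model, "fc1", "fc1_ln", dim_in=256)
#
# returns the scaled-and-shifted blob ("fc1_ln_biased") plus the per-row mean
# and standard deviation; the learned gain and bias are applied by the
# broadcasted Mul/Add above rather than inside the LayerNorm op itself.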