diff --git a/caffe2/python/helpers/normalization.py b/caffe2/python/helpers/normalization.py index 7845fe5d42b..168d785fbec 100644 --- a/caffe2/python/helpers/normalization.py +++ b/caffe2/python/helpers/normalization.py @@ -216,29 +216,37 @@ def layer_norm( # The learned multiplicative scale or "gain". scale = model.create_param( param_name='{}_scale'.format(blob_out), - shape=dim_in, - initializer=initializers.Initializer('ConstantFill', value=initial_scale), - tags=ParameterTags.WEIGHT + shape=[dim_in], + initializer=initializers.Initializer( + 'ConstantFill', + value=initial_scale, + ), + tags=ParameterTags.WEIGHT, ) # The learned additive bias or "shift". bias = model.create_param( param_name='{}_bias'.format(blob_out), - shape=dim_in, - initializer=initializers.Initializer('ConstantFill', value=initial_bias), - tags=ParameterTags.BIAS + shape=[dim_in], + initializer=initializers.Initializer( + 'ConstantFill', + value=initial_bias, + ), + tags=ParameterTags.BIAS, ) scaled = model.net.Mul( [normalized, scale], ['{}_scaled'.format(blob_out)], broadcast=1, + axis=axis, ) biased = model.net.Add( [scaled, bias], ['{}_biased'.format(blob_out)], broadcast=1, + axis=axis, ) return biased, mean, stdev diff --git a/caffe2/python/operator_test/layer_norm_op_test.py b/caffe2/python/operator_test/layer_norm_op_test.py index ba7d5159c38..2b443f1b3c0 100644 --- a/caffe2/python/operator_test/layer_norm_op_test.py +++ b/caffe2/python/operator_test/layer_norm_op_test.py @@ -155,7 +155,7 @@ class TestLayerNormOp(hu.HypothesisTestCase): model, 'input', 'output', - dim_in=scale_dim, + dim_in=X.shape[axis], axis=axis, epsilon=1e-4, )