update transformer code for layer_norm() API change

Summary: Quick fix for the unit test broken by D6454290. This was my fault for approving that diff while the tests covering the single call site were broken.

Reviewed By: goldsborough

Differential Revision: D6466566

fbshipit-source-id: 2683be3d6bb184286e64fbde3e572946e39030c7
This commit is contained in:
James Cross 2017-12-01 19:59:20 -08:00 committed by Facebook Github Bot
parent 96cd3743f1
commit 2c190d2f05
2 changed files with 15 additions and 7 deletions

View File

@ -216,29 +216,37 @@ def layer_norm(
# The learned multiplicative scale or "gain". # The learned multiplicative scale or "gain".
scale = model.create_param( scale = model.create_param(
param_name='{}_scale'.format(blob_out), param_name='{}_scale'.format(blob_out),
shape=dim_in, shape=[dim_in],
initializer=initializers.Initializer('ConstantFill', value=initial_scale), initializer=initializers.Initializer(
tags=ParameterTags.WEIGHT 'ConstantFill',
value=initial_scale,
),
tags=ParameterTags.WEIGHT,
) )
# The learned additive bias or "shift". # The learned additive bias or "shift".
bias = model.create_param( bias = model.create_param(
param_name='{}_bias'.format(blob_out), param_name='{}_bias'.format(blob_out),
shape=dim_in, shape=[dim_in],
initializer=initializers.Initializer('ConstantFill', value=initial_bias), initializer=initializers.Initializer(
tags=ParameterTags.BIAS 'ConstantFill',
value=initial_bias,
),
tags=ParameterTags.BIAS,
) )
scaled = model.net.Mul( scaled = model.net.Mul(
[normalized, scale], [normalized, scale],
['{}_scaled'.format(blob_out)], ['{}_scaled'.format(blob_out)],
broadcast=1, broadcast=1,
axis=axis,
) )
biased = model.net.Add( biased = model.net.Add(
[scaled, bias], [scaled, bias],
['{}_biased'.format(blob_out)], ['{}_biased'.format(blob_out)],
broadcast=1, broadcast=1,
axis=axis,
) )
return biased, mean, stdev return biased, mean, stdev

View File

@ -155,7 +155,7 @@ class TestLayerNormOp(hu.HypothesisTestCase):
model, model,
'input', 'input',
'output', 'output',
dim_in=scale_dim, dim_in=X.shape[axis],
axis=axis, axis=axis,
epsilon=1e-4, epsilon=1e-4,
) )