mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
update transformer code for layer_norm() API change
Summary: Quick fix for unit test broken by D6454290. This is my fault for approving while the tests covering the single callsite were broken. Reviewed By: goldsborough Differential Revision: D6466566 fbshipit-source-id: 2683be3d6bb184286e64fbde3e572946e39030c7
This commit is contained in:
parent
96cd3743f1
commit
2c190d2f05
|
|
@ -216,29 +216,37 @@ def layer_norm(
|
||||||
# The learned multiplicative scale or "gain".
|
# The learned multiplicative scale or "gain".
|
||||||
scale = model.create_param(
|
scale = model.create_param(
|
||||||
param_name='{}_scale'.format(blob_out),
|
param_name='{}_scale'.format(blob_out),
|
||||||
shape=dim_in,
|
shape=[dim_in],
|
||||||
initializer=initializers.Initializer('ConstantFill', value=initial_scale),
|
initializer=initializers.Initializer(
|
||||||
tags=ParameterTags.WEIGHT
|
'ConstantFill',
|
||||||
|
value=initial_scale,
|
||||||
|
),
|
||||||
|
tags=ParameterTags.WEIGHT,
|
||||||
)
|
)
|
||||||
|
|
||||||
# The learned additive bias or "shift".
|
# The learned additive bias or "shift".
|
||||||
bias = model.create_param(
|
bias = model.create_param(
|
||||||
param_name='{}_bias'.format(blob_out),
|
param_name='{}_bias'.format(blob_out),
|
||||||
shape=dim_in,
|
shape=[dim_in],
|
||||||
initializer=initializers.Initializer('ConstantFill', value=initial_bias),
|
initializer=initializers.Initializer(
|
||||||
tags=ParameterTags.BIAS
|
'ConstantFill',
|
||||||
|
value=initial_bias,
|
||||||
|
),
|
||||||
|
tags=ParameterTags.BIAS,
|
||||||
)
|
)
|
||||||
|
|
||||||
scaled = model.net.Mul(
|
scaled = model.net.Mul(
|
||||||
[normalized, scale],
|
[normalized, scale],
|
||||||
['{}_scaled'.format(blob_out)],
|
['{}_scaled'.format(blob_out)],
|
||||||
broadcast=1,
|
broadcast=1,
|
||||||
|
axis=axis,
|
||||||
)
|
)
|
||||||
|
|
||||||
biased = model.net.Add(
|
biased = model.net.Add(
|
||||||
[scaled, bias],
|
[scaled, bias],
|
||||||
['{}_biased'.format(blob_out)],
|
['{}_biased'.format(blob_out)],
|
||||||
broadcast=1,
|
broadcast=1,
|
||||||
|
axis=axis,
|
||||||
)
|
)
|
||||||
|
|
||||||
return biased, mean, stdev
|
return biased, mean, stdev
|
||||||
|
|
|
||||||
|
|
@ -155,7 +155,7 @@ class TestLayerNormOp(hu.HypothesisTestCase):
|
||||||
model,
|
model,
|
||||||
'input',
|
'input',
|
||||||
'output',
|
'output',
|
||||||
dim_in=scale_dim,
|
dim_in=X.shape[axis],
|
||||||
axis=axis,
|
axis=axis,
|
||||||
epsilon=1e-4,
|
epsilon=1e-4,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user