pytorch/docs/source
Thomas J. Fan 6ff001c125 DOC Improve documentation for LayerNorm (#59178)
Summary:
Closes https://github.com/pytorch/pytorch/issues/51455

I think the current implementation is aggregating over the correct dimensions. The length of `normalized_shape` only determines how many trailing dimensions are aggregated over; the actual values of `normalized_shape` are used when `elementwise_affine=True` to initialize the weight and bias.
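
For example (a minimal sketch, not part of this PR; the shapes are chosen just for illustration):

```python
import torch
import torch.nn as nn

x = torch.randn(10, 20, 64, 64)

# normalized_shape has two entries, so only the last two dimensions are normalized over.
ln = nn.LayerNorm([64, 64])
print(ln.weight.shape)  # torch.Size([64, 64]): the values size the affine parameters
print(ln(x).shape)      # torch.Size([10, 20, 64, 64]): output shape is unchanged
```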

This PR updates the docstring to clarify how `normalized_shape` is used. Here is a short script comparing the TensorFlow and PyTorch implementations:

```python
import numpy as np
import torch
import torch.nn as nn

import tensorflow as tf
from tensorflow.keras.layers import LayerNormalization

rng = np.random.RandomState()
x = rng.randn(10, 20, 64, 64).astype(np.float32)
# slightly non-trivial data
x[:, :10, ...] = x[:, :10, ...] * 10 + 20
x[:, 10:, ...] = x[:, 10:, ...] * 30 - 100

# Tensorflow Layer norm
x_tf = tf.convert_to_tensor(x)
layer_norm_tf = LayerNormalization(axis=[-3, -2, -1], epsilon=1e-5)
output_tf = layer_norm_tf(x_tf)
output_tf_np = output_tf.numpy()

# PyTorch Layer norm
x_torch = torch.as_tensor(x)
layer_norm_torch = nn.LayerNorm([20, 64, 64], elementwise_affine=False)
output_torch = layer_norm_torch(x_torch)
output_torch_np = output_torch.detach().numpy()

# check that the TensorFlow and PyTorch outputs match
torch.testing.assert_allclose(output_tf_np, output_torch_np)

# manual computation
manual_output = ((x_torch - x_torch.mean(dim=(-3, -2, -1), keepdim=True)) /
                 (x_torch.var(dim=(-3, -2, -1), keepdim=True, unbiased=False) + 1e-5).sqrt())

torch.testing.assert_allclose(output_torch, manual_output)
```

To get the layer normalization shown here:

<img width="157" alt="Screen Shot 2021-05-29 at 2 13 52 PM" src="https://user-images.githubusercontent.com/5402633/120080691-1e37f100-c088-11eb-9060-4f263e4cd093.png">
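
For reference, the screenshot is (presumably) the layer normalization formula from the `LayerNorm` docs:

```latex
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
```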

one needs to pass in a `normalized_shape` with `x.dim() - 1` entries: the size of the channel dimension followed by the sizes of all spatial dimensions.
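
Concretely (a small self-contained sketch of that recipe, with the same made-up shapes as above):

```python
import torch
import torch.nn as nn

x = torch.randn(10, 20, 64, 64)  # (N, C, H, W)

# normalized_shape has x.dim() - 1 entries: channels plus all spatial dims.
layer_norm = nn.LayerNorm(list(x.shape[1:]), elementwise_affine=False)
out = layer_norm(x)  # equivalent to nn.LayerNorm([20, 64, 64]) for this input
```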

Pull Request resolved: https://github.com/pytorch/pytorch/pull/59178

Reviewed By: ejguan

Differential Revision: D28931877

Pulled By: jbschlosser

fbshipit-source-id: 193e05205b9085bb190c221428c96d2ca29f2a70
2021-06-07 14:34:10 -07:00
| Name | Last commit | Date |
| --- | --- | --- |
| `_static` | DOC Improve documentation for LayerNorm (#59178) | 2021-06-07 14:34:10 -07:00 |
| `_templates` | Remove master documentation from being indexable by search engines (#58056) | 2021-05-18 06:20:09 -07:00 |
| `community` | Lint trailing newlines (#54737) | 2021-03-30 13:09:52 -07:00 |
| `elastic` | [torch/elastic] Update the rendezvous docs (#58160) | 2021-05-12 16:54:28 -07:00 |
| `notes` | Add no-grad inference mode note (#58513) | 2021-05-25 13:06:54 -07:00 |
| `rpc` | Forbid trailing whitespace (#53406) | 2021-03-05 17:22:55 -08:00 |
| `scripts` | Add mish activation function (#58648) | 2021-05-25 10:36:21 -07:00 |
| `__config__.rst` | Fix __config__ docs (#48557) | 2020-11-29 23:57:06 -08:00 |
| `amp.rst` | Adds grid_sampler to autocast fp32 list for 1.9 (#58679) | 2021-05-20 14:05:09 -07:00 |
| `autograd.rst` | Add no-grad inference mode note (#58513) | 2021-05-25 13:06:54 -07:00 |
| `backends.rst` | Forbid trailing whitespace (#53406) | 2021-03-05 17:22:55 -08:00 |
| `benchmark_utils.rst` | Expand benchmark utils docs (#51664) | 2021-02-04 00:22:41 -08:00 |
| `bottleneck.rst` | | |
| `checkpoint.rst` | | |
| `complex_numbers.rst` | Abladawood patch 1 (#58496) | 2021-05-20 10:32:18 -07:00 |
| `conf.py` | Use proper Google Analytics id (#56578) | 2021-05-04 13:23:16 -07:00 |
| `cpp_extension.rst` | | |
| `cpp_index.rst` | | |
| `cuda.rst` | breakup optim, cuda documentation (#55673) | 2021-04-14 12:44:00 -07:00 |
| `cudnn_persistent_rnn.rst` | Forbid trailing whitespace (#53406) | 2021-03-05 17:22:55 -08:00 |
| `cudnn_rnn_determinism.rst` | Forbid trailing whitespace (#53406) | 2021-03-05 17:22:55 -08:00 |
| `data.rst` | [DataLoader][doc] Randomness for base_seed generator and NumPy seed (#56528) | 2021-04-22 09:40:45 -07:00 |
| `ddp_comm_hooks.rst` | [Gradient Compression] Remove unnecessary warning on the rst file and the check on C++ version (#58170) | 2021-05-12 14:15:10 -07:00 |
| `distributed.elastic.rst` | [1/n][torch/elastic] Move torchelastic docs *.rst (#148) | 2021-05-04 00:57:56 -07:00 |
| `distributed.optim.rst` | [Reland] Update and expose ZeroRedundancyOptimizer docs (#53112) | 2021-03-02 14:16:12 -08:00 |
| `distributed.rst` | Document monitored barrier (#58322) | 2021-05-21 19:04:57 -07:00 |
| `distributions.rst` | Add sample validation for LKJCholesky.log_prob (#52763) | 2021-02-25 16:12:29 -08:00 |
| `dlpack.rst` | Lint trailing newlines (#54737) | 2021-03-30 13:09:52 -07:00 |
| `docutils.conf` | | |
| `fft.rst` | Use autosummary on torch.fft, torch.linalg (#55748) | 2021-04-13 12:02:36 -07:00 |
| `futures.rst` | Update docs to mention CUDA support for Future (#50048) | 2021-05-11 08:26:33 -07:00 |
| `fx.rst` | s/foward/forward/g (#58497) | 2021-05-19 11:42:42 -07:00 |
| `hub.rst` | Add a torch.hub.load_local() function that can load models from any local directory with a hubconf.py (#44204) | 2020-09-21 14:17:21 -07:00 |
| `index.rst` | add torch.testing to docs (#57247) | 2021-05-07 09:16:39 -07:00 |
| `jit_builtin_functions.rst` | Lint trailing newlines (#54737) | 2021-03-30 13:09:52 -07:00 |
| `jit_language_reference_v2.rst` | Fix hasattr support type (#57950) | 2021-05-10 12:21:56 -07:00 |
| `jit_language_reference.rst` | add type annotations to torch.nn.modules.conv (#49564) | 2021-01-15 11:16:11 -08:00 |
| `jit_python_reference.rst` | [JIT] improve documentation (#57991) | 2021-05-19 11:47:32 -07:00 |
| `jit_unsupported.rst` | [JIT] Update docs for recently added features (#45232) | 2020-09-28 18:17:42 -07:00 |
| `jit.rst` | Remove caption for Lang Reference (#56526) | 2021-04-20 14:33:42 -07:00 |
| `linalg.rst` | Add torch.linalg.inv_ex without checking for errors by default (#58039) | 2021-05-13 09:42:15 -07:00 |
| `math-quantizer-equation.png` | | |
| `mobile_optimizer.rst` | Mod lists to neutral+descriptive terms in caffe2/docs (#49803) | 2020-12-23 11:37:11 -08:00 |
| `model_zoo.rst` | | |
| `multiprocessing.rst` | Forbid trailing whitespace (#53406) | 2021-03-05 17:22:55 -08:00 |
| `name_inference.rst` | Abladawood patch 1 (#58496) | 2021-05-20 10:32:18 -07:00 |
| `named_tensor.rst` | Forbid trailing whitespace (#53406) | 2021-03-05 17:22:55 -08:00 |
| `nn.functional.rst` | Add mish activation function (#58648) | 2021-05-25 10:36:21 -07:00 |
| `nn.init.rst` | | |
| `nn.rst` | DOC Adds register_module_full_backward_hook into docs (#58954) | 2021-06-01 15:47:10 -07:00 |
| `onnx.rst` | [ONNX] Add hardsigmoid symbolic in opset 9 #49649 (#54193) | 2021-04-07 14:28:31 -07:00 |
| `optim.rst` | To add single and chained learning schedulers to docs (#56705) | 2021-04-23 09:36:00 -07:00 |
| `package.rst` | [package] Add an intern keyword (#57341) | 2021-05-12 16:22:43 -07:00 |
| `pipeline.rst` | Add tutorials to pipeline docs. (#55209) | 2021-04-05 20:01:00 -07:00 |
| `profiler.rst` | docs: fix profiler docstring (#55750) | 2021-04-13 00:23:14 -07:00 |
| `quantization-support.rst` | [docs][quant] Add fx graph mode quant api doc (#55306) | 2021-04-05 13:56:23 -07:00 |
| `quantization.rst` | Fix broken link to fx graph quant guide in quantization.rst (#56776) | 2021-04-26 08:22:28 -07:00 |
| `random.rst` | | |
| `rpc.rst` | Add a disclaimer about limited CUDA support in RPC (#58023) | 2021-05-12 00:11:22 -07:00 |
| `sparse.rst` | Add CSR (compressed sparse row) layout for sparse tensors (#50937) | 2021-04-12 10:09:12 -07:00 |
| `special.rst` | Alias for i0 to special namespace (#59141) | 2021-06-01 23:04:09 -07:00 |
| `storage.rst` | Lint trailing newlines (#54737) | 2021-03-30 13:09:52 -07:00 |
| `tensor_attributes.rst` | Remove legacy constructor calls from pytorch codebase. (#54142) | 2021-04-11 15:45:17 -07:00 |
| `tensor_view.rst` | Conjugate View (#54987) | 2021-06-04 14:12:41 -07:00 |
| `tensorboard.rst` | | |
| `tensors.rst` | Conjugate View (#54987) | 2021-06-04 14:12:41 -07:00 |
| `testing.rst` | add torch.testing to docs (#57247) | 2021-05-07 09:16:39 -07:00 |
| `torch.nn.intrinsic.qat.rst` | [quantization] Add some support for 3d operations (#50003) | 2021-03-10 16:40:35 -08:00 |
| `torch.nn.intrinsic.quantized.rst` | Lint trailing newlines (#54737) | 2021-03-30 13:09:52 -07:00 |
| `torch.nn.intrinsic.rst` | [quantization] Add some support for 3d operations (#50003) | 2021-03-10 16:40:35 -08:00 |
| `torch.nn.qat.rst` | Lint trailing newlines (#54737) | 2021-03-30 13:09:52 -07:00 |
| `torch.nn.quantized.dynamic.rst` | Forbid trailing whitespace (#53406) | 2021-03-05 17:22:55 -08:00 |
| `torch.nn.quantized.rst` | [quant] add docs for embedding/embedding_bag (#51770) | 2021-02-05 11:43:15 -08:00 |
| `torch.overrides.rst` | Add documentation for torch.overrides submodule. (#48170) | 2020-11-30 11:25:31 -08:00 |
| `torch.quantization.rst` | Lint trailing newlines (#54737) | 2021-03-30 13:09:52 -07:00 |
| `torch.rst` | Conjugate View (#54987) | 2021-06-04 14:12:41 -07:00 |
| `type_info.rst` | | |