Summary:
# context
* enable the `_get_user_embeddings` function
* run failed at P1610151892
```
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
GuardOnDataDependentSymNode: Could not guard on data-dependent expression u22 <= 0 (unhinted: u22 <= 0). (Size-like symbols: u22)
ATTENTION: guard_size_oblivious would fix the error, evaluating expression to False.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.
Potential framework code culprit (scroll up for full backtrace):
File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/38472faba4e3e6c1/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/_decomp/decompositions.py", line 1692, in native_layer_norm_backward
if M <= 0 or N <= 0:
```
```
N = prod(inner_dims) # type: ignore[arg-type]
M = prod(outer_dims) # type: ignore[arg-type]
if M <= 0 or N <= 0:
return (
input.new_zeros(input_shape) if output_mask[0] else None,
input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
input.new_zeros(input_shape[axis:]) if output_mask[2] else None,
)
```
# changes
* use guard_size_oblivious since the new_zeros return is kind of optimization, shouldn't impact the correctness of the follow up code logic.
* the size `ret[i][j]` could be zero, so the change in V1 isn't valid
* for more details: [post](https://fb.workplace.com/groups/6829516587176185/permalink/8003616173099548/)
```
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):
```
# past
* found `u22` was introduced at
```
def _wait_impl(self) -> List[List[int]]:
# Can not use is_torchdynamo_compiling(), as every such condition should be independent for compilation with graph breaks.
if isinstance(self._splits_awaitable, dist.Work):
self._splits_awaitable.wait()
ret = self._output_tensor.view(self.num_workers, -1).T.tolist() # <------ u22 introduced here
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
for i in range(len(ret)):
for j in range(len(ret[i])):
torch._check_is_size(ret[i][j]) # <---------- my question: why the _check_is_size isn't enough??
torch._check(ret[i][j] > 0) # <------ added by diff V1
```
Test Plan:
# run command
```
TORCH_SHOW_CPP_STACKTRACES=1 TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 TORCH_LOGS="+graph_code,output_code,dynamic,aot,guards,verbose_guards,recompiles,graph_breaks" TORCH_TRACE=/var/tmp/tt buck2 run fbcode//mode/opt fbcode//aps_models/ads/icvr:icvr_launcher_live -- mode=fmc/local_ig_fm_v4_mini training.pipeline_type=pt2 2>&1 | tee -a `tagT`.`tagH`.log
```
# results
* before
**without enabling `_get_user_embeddings`**
[14 Failures and Restarts](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmp2eNI7p/failures_and_restarts.html)
log: P1610151892
{F1889387940}
* V1
enable `_get_user_embeddings`
with `torch._check(ret[i][j] > 0)`
[13 Failures and Restarts](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmp6J1iY9/failures_and_restarts.html)
{F1889388378}
* V2
enable `_get_user_embeddings`
with `if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):`
[tlparse](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpFhZZyC/index.html)
if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):
Differential Revision: D63424929
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136798
Approved by: https://github.com/ezyang
When tensor folding occurs during matmul operation returned tensor is a view. This can cause issues when matmul is used inside a custom function and such view is then returned as output. Then it cannot be modified inplace and causes errors.
It can be especially problematic when after such function inplace allreduce is performed.
Issue is resolved when unsafe_view is returned from matmul instead. This solution aligns matmul decomposition with eager implementation in such a way that a non view tensor is returned.
Test included in this PR reproduces the issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134568
Approved by: https://github.com/zou3519
Summary: Update SDPA decomposition to match updated stride from D62009189 which aligns strides with the `aten._scaled_dot_product_attention_math.default`, which makes `t.permute().continuous().permute()` no longer necessary.
Test Plan: CI
Differential Revision: D62278378
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135297
Approved by: https://github.com/drisspg
## Description
Create decomposition of _unsafe_index_put (non-core aten) that turns it into index_put (core aten)
## Testing
Phi3 mini + LoRA model successfully passed `to_edge` after failing due to a non-core aten `unsafe_index_put` getting introduced in a decomposition during joint graph calculations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133365
Approved by: https://github.com/pianpwk
Summary:
# context
* when running an IG FM training with PT2 we found there are a few graph break due to torch.diff call in [jagged_tensor.py](https://fburl.com/code/cwssxabc)
```
_length: List[int] = (
_length_per_key_from_stride_per_key(torch.diff(offsets), stride_per_key)
if variable_stride_per_key
else torch.sum(torch.diff(offsets).view(-1, stride), dim=1).tolist()
)
```
* look into the failure, we found the TORCH_CHECK in diff should be TORCH_SYM_CHECK
* slice_forward error: df3d7729e, [tlparse](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpxXZ2em/index.html)
```
RestartAnalysis
Tried to use data-dependent value in the subsequent computation. This can happen when we encounter unbounded dynamic value that is unknown during tracing time. You will need to explicitly give hint to the compiler. Please take a look at torch._check OR torch._check_is_size APIs. Could not guard on data-dependent expression ((5*u37 + u38)//(u37 + u38)) < 0 (unhinted: ((5*u37 + u38)//(u37 + u38)) < 0). (Size-like symbols: u38, u37)
ATTENTION: guard_size_oblivious would fix the error, evaluating expression to False.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.
Potential framework code culprit (scroll up for full backtrace):
File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/_decomp/decompositions.py", line 771, in slice_forward
if end_val < 0:
```
* after this diff: [tlparse](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpAhv2Sh/failures_and_restarts.html)
Test Plan:
# command
* run model
```
TORCH_SHOW_CPP_STACKTRACES=1 TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 TORCH_LOGS="+graph_code,output_code,dynamic,aot,guards,verbose_guards,recompiles,graph_breaks" TORCH_TRACE=/var/tmp/tt buck2 run fbcode//mode/opt fbcode//aps_models/ads/icvr:icvr_launcher_live -- mode=fmc/local_ig_fm_v4_mini training.pipeline_type=pt2
```
* generate tlparse
```
tlparse `ls -t /var/tmp/tt/* | head -1`
```
Reviewed By: ezyang
Differential Revision: D56339251
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133740
Approved by: https://github.com/ezyang
# Summary
Changes the stance of SDPA on what to do for fully masked out rows
## Current Behavior
Several PyTorch users have expressed frustration over this issue:
- https://github.com/pytorch/pytorch/issues/41508
- https://github.com/pytorch/pytorch/issues/103749
- https://github.com/pytorch/pytorch/issues/103963
These are significant issues with extensive discussion but no satisfactory resolution. The PyTorch team's consensus, as stated here:
https://github.com/pytorch/pytorch/issues/24816#issuecomment-524415617
Can be paraphrased as follows:
When passing in fully masked out rows, attention becomes ambiguous. We have two main options:
1. Uniformly attend to all values:
```python
scores[masked_out_rows] = 1 / len(row)
out[masked_out_rows] = 1 / len(row) * value
```
2. Decide that attention between no queries (masked) and no keys (masked) is meaningless:
```python
output[fully_masked_rows] = NaN
```
We went with option 2. Partially because it was easier to implement, but also people argued that users can slice the output to remove the NaNs:
``` Python
>fill_value = -float("inf")
>row0 = torch.randn(4)
>row1 = torch.tensor([(fill_value for _ in range(4)])
>matrix = torch.stack([row0, row1]).requires_grad_(True)
>out = torch.softmax(matrix, 1)
>out = out[0]
>print(out)
tensor([0.5377, 0.2729, 0.0692, 0.1201])
```
Cool, problem solved. But what happends when you call backwards..
```Python
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[3.0957e-08, 1.4157e-08, 7.7802e-10, 1.3713e-08],
[ nan, nan, nan, nan]])
```
Those pesky NaNs are back!
## Why do we see NaNs today?
The core of the problem revolves around using softmax function in sdpa:
```python
> row = torch.tensor([(-float("inf")) for _ in range(4)])
> torch.softmax(row, 0)
tensor([nan, nan, nan, nan])
```
## Quick Aside: Masking in Attention
Attention itself doesn't have a concept of masking. The `sdpa` function has an argument called `attn_mask`, which would be more accurately named `attn_bias`. This is because we don't actually "mask" entries when computing attention. Instead, due to implementation details([performance](https://github.com/pytorch/pytorch/issues/25110#issuecomment-524519087)), we add a value to the masked-out query/key pairs.
We use a large negative number (typically -inf) to decrease the attention weight, as softmax assigns more weight to larger values.
## Alternative Approaches
If we use a very large negative number instead of -inf:
```python
> row = torch.tensor([(-1e6) for _ in range(4)])
> torch.softmax(row, 0)
tensor([0.2500, 0.2500, 0.2500, 0.2500])
```
However if users always remembered to "slice" out their outputs i.e.:
```Python
>fill_value = -1e6
>...
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[-0.0563, -0.0564, 0.1613, -0.0486],
[ 0.0000, 0.0000, 0.0000, 0.0000]])
```
This would bring us back into a better state.
## A Third Option
We don't necessarily need to alter the behavior of softmax for -inf or very large negative numbers. The fundamental goal is to exclude certain query/key pairs from attention, regardless of the underlying implementation.
This PR implements the new semantic for masking w/ attention in fully masked-out rows:
```python
out[masked_out_rows] = 0
```
**Important Note**: This idea isn't entirely new. The [MaskedTensor](https://pytorch.org/tutorials/prototype/maskedtensor_overview#safe-softmax) prototype, a tensor subclass, was designed to handle such cases. However, it remains a prototype feature and hasn't gained widespread adoption.
## Details
This PR stack does 3 things:
1. Adds a PRIVATE _safe_softmax op
2. Updates semantic for flash_cpu fused kernel
3. Updates semantic for efficient_cuda fused kernel
_safe_softmax is not supposed to be used generically and is only meant to be used within the context of SDPA. Due to this fact instead of decomposing softmax and checking for -inf rows we instead "cheat" and use nan_to_num.
Why I think this is okay? (please find a counter point if avail)
There are multiple ways NaNs can emerge. For the fully masked out rows case nan_to_num works. But what if there were other NaNs, wouldn't this silently remove them?
The only case that this can happen is if the input itself had a NaN or an Inf
For example:
```Python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = torch.finfo(torch.float16).max
print(a.softmax(-1))
```
Will return
`tensor([0., 1., 0., 0.], dtype=torch.float16)`
Where
```Python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = float("inf")
a.softmax(-1)
```
returns:
`tensor([nan, nan, nan, nan], dtype=torch.float16)`
If we dont want to even allow for the possibility of "inf" or "NaN" attention scores to be converted to 0 then we can implemented it something like this
```Python
max = torch.max(a, dim=-1, keepdim=True)
exp = torch.exp(a - max.values)
denom = torch.sum(exp, dim=-1, keepdim=True)
softmax = exp / denom
softmax = torch.where(max.values == float('-inf'), 0.0, softmax)
```
however we would be paying for this in math performance.
## Why Now
I think one point that has substantially changed where PyTorch should lie on this argument is the fact that we have fused implementations for SDPA now. And these fused implementations allow us to easily and performantly support this new semantic.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131060
Approved by: https://github.com/jbschlosser
This PR makes sure all current tests in the sparsity export test suite pass. Note that there will probably be anecdotal cases that need fixing after this, but the general idea of preserving sparsity metadata has been completed.
Fixes: https://github.com/pytorch/pytorch/issues/117188
```
$ PYTORCH_TEST_WITH_DYNAMO=0 python test/export/test_sparse.py ........................................................................................................................................................
----------------------------------------------------------------------
Ran 152 tests
OK
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132690
Approved by: https://github.com/ezyang
`aten._to_copy` can receive a python number as input. This occurs in
torch.compile support for vmap (see #130188). Previously, this would
raise an assertion error. This PR changes it so that if we see a python
number, we call torch.scalar_tensor on it first (h/t @bdhirsh).
Fixes#130362Fixes#130188
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130381
Approved by: https://github.com/Chillee
In a previous life, we used sympy.oo to represent the lower/upper bounds of integer ranges. Later, we changed this to be sys.maxsize - 1 for a few reasons: (1) sometimes we do tests on a value being exactly sys.maxsize, and we wanted to avoid a data dependent guard in this case, (2) sympy.oo corresponds to floating point infinity, so you get incorrect types for value ranges with oo, and (3) you can do slightly better reasoning if you assume that input sizes fall within representable 64-bit integer range.
After working in the sys.maxsize regime for a bit, I've concluded that this was actually a bad idea. Specifically, the problem is that you end up with sys.maxsize in your upper bound, and then whenever you do any sort of size-increasing computation like size * 2, you end up with 2 * sys.maxsize, and you end up doing a ton of arbitrary precision int computation that is totally unnecessary. A symbolic bound is better.
But especially after #126905, we can't go back to using sympy.oo, because that advertises that it's not an integer, and now your ValueRanges is typed incorrectly. So what do we do? We define a new numeric constant `int_oo`, which is like `sympy.oo` but it advertises `is_integer`. **test/test_sympy_utils.py** describes some basic properties of the number, and **torch/utils/_sympy/numbers.py** has the actual implementation.
The rest of the changes of the PR are working out the implications of this change. I'll give more commentary as inline comments.
Fixes https://github.com/pytorch/pytorch/issues/127396
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127693
Approved by: https://github.com/lezcano
ghstack dependencies: #126905
In a previous life, we used sympy.oo to represent the lower/upper bounds of integer ranges. Later, we changed this to be sys.maxsize - 1 for a few reasons: (1) sometimes we do tests on a value being exactly sys.maxsize, and we wanted to avoid a data dependent guard in this case, (2) sympy.oo corresponds to floating point infinity, so you get incorrect types for value ranges with oo, and (3) you can do slightly better reasoning if you assume that input sizes fall within representable 64-bit integer range.
After working in the sys.maxsize regime for a bit, I've concluded that this was actually a bad idea. Specifically, the problem is that you end up with sys.maxsize in your upper bound, and then whenever you do any sort of size-increasing computation like size * 2, you end up with 2 * sys.maxsize, and you end up doing a ton of arbitrary precision int computation that is totally unnecessary. A symbolic bound is better.
But especially after #126905, we can't go back to using sympy.oo, because that advertises that it's not an integer, and now your ValueRanges is typed incorrectly. So what do we do? We define a new numeric constant `int_oo`, which is like `sympy.oo` but it advertises `is_integer`. **test/test_sympy_utils.py** describes some basic properties of the number, and **torch/utils/_sympy/numbers.py** has the actual implementation.
The rest of the changes of the PR are working out the implications of this change. I'll give more commentary as inline comments.
Fixes https://github.com/pytorch/pytorch/issues/127396
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127693
Approved by: https://github.com/lezcano
ghstack dependencies: #126905
Implements forward automatic differentiation support for miopen_batch_norm as well as unskips the associated unit tests. Also fixes a class of functorch related unit tests that fail due to failing a contiguous tensor assertion in BatchNorm_miopen.cpp. Solution was to just limit tensors to miopen_batch_norm that have at least 3 dimensions. The exact restriction already existed in the cudnn path and is why the tests in question only failed on ROCm.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125069
Approved by: https://github.com/jeffdaily, https://github.com/andrewor14
**Summary:**
This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:
```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
aten.native_batch_norm ->
aten._native_batch_norm_legit (export only) ->
_batch_norm_legit_cpu/cuda (kernels, export only) ->
_batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```
Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.
Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:
```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```
The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:
```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```
Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack is planned to
be removed after the 6 month BC window.
Test Plan: `OpInfo` tests for `batch_norm_with_update`.
Reviewers: albanD, bdhirsh
Subscribers: albanD, bdhirsh, supriyar
Tasks: https://github.com/pytorch/pytorch/issues/111384
Differential Revision: [D54805279](https://our.internmc.facebook.com/intern/diff/D54805279)
Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
# Motivation
In backward of per-parameter sharding FSDP, each rank performs reduce scatter to sync gradients across ranks. A rank chunks each gradient tensor into `world_size` slices along the 0-th dimension and concatenate all slices along the 1-th dimension. Gradient tensors will be padded before concatenation when tensor.size(0) % world_size != 0.
### Example 1
Consider `world_size=3` and tensors A (2x4), B (3x3), C (1x2):
Input tensors:
```
AAAA BBB CC
AAAA BBB
BBB
```
Reduce-scatter-copy-in Output:
```
AAAABBBCC
AAAABBB00
0000BBB00
```
### Example 2
Consider `world_size=2` and tensors A (2x4), B (3x3), C(1x2), D(4x2):
Input tensors:
```
AAAA BBB CC DD
AAAA BBB 00 DD
BBB DD
000 DD
```
Reduce-scatter-copy-in first pad:
```
AAAA BBB CC DD
AAAA BBB 00 DD
BBB DD
000 DD
```
Then chunk and cat along dim as the output:
```
AAAABBBBBBCCDDDD
AAAABBB00000DDDD
```
The performance of reduce-scatter-copy-in is critical to per-parameter sharding FSDP. However, reduce-scatter-copy-in via composing existing ATen ops involves `cat` and irregular `pad`, leading redundant data copies and unsatisfactory performance.
# PR
We provide aten native support for reduce-scatter-copy-in, namely `_chunk_cat()`:
```
_chunk_cat(Tensor[] tensors, int dim, int num_chunks) -> Tensor
```
This PR includes the registration of `_chunk_cat` and `_chunk_cat.out`, OpInfo tests, and basic implementation composing existing ATen ops.
In the next PR, we will add the CUDA implementation. Comparing with baselines of composing existing ATen ops, `_chunk_cat()` CUDA implementation improves copy bandwidth from 498 GB/s to 966 GB/s on a production benchmark.
## Requirements on input
1. If input tensors have different ndims, dim should be non-negative and be less than the ndims of every input tensors. If all input tensors have the same ndims, we support both negative and non-negative dim.
2. For wrapped_dim, all tensors should have the same size for 0,...,wrapped_dim-1 dimensions. No requirements for (wrapped_dim, ...)-th dimension.
3. Expect positive num_chunks
4. Expect non-empty input tensor list and each input tensor should have at least 1 element
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121081
Approved by: https://github.com/albanD
**Summary:**
This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:
```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
aten.native_batch_norm ->
aten._native_batch_norm_legit (export only) ->
_batch_norm_legit_cpu/cuda (kernels, export only) ->
_batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```
Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.
Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:
```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```
The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:
```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```
Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack is planned to
be removed after the 6 month BC window.
Test Plan: `OpInfo` tests for `batch_norm_with_update`.
Reviewers: albanD, bdhirsh
Subscribers: albanD, bdhirsh, supriyar
Tasks: https://github.com/pytorch/pytorch/issues/111384
Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
**Summary:**
This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:
```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
aten.native_batch_norm ->
aten._native_batch_norm_legit (export only) ->
_batch_norm_legit_cpu/cuda (kernels, export only) ->
_batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```
Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.
Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:
```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```
The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:
```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```
Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack is planned to
be removed after the 6 month BC window.
Test Plan: `OpInfo` tests for `batch_norm_with_update`.
Reviewers: albanD, bdhirsh
Subscribers: albanD, bdhirsh, supriyar
Tasks: https://github.com/pytorch/pytorch/issues/111384
Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD