Commit Graph

391 Commits

Author SHA1 Message Date
Huanyu He
bae8d5853e [TorchRec][PT2 compile] enable dynamo in _get_user_embeddings (#136798)
Summary:
# context
* enable the `_get_user_embeddings` function
* run failed at P1610151892
```
  torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
  GuardOnDataDependentSymNode: Could not guard on data-dependent expression u22 <= 0 (unhinted: u22 <= 0).  (Size-like symbols: u22)

  ATTENTION: guard_size_oblivious would fix the error, evaluating expression to False.
  Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.

  Potential framework code culprit (scroll up for full backtrace):
    File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/38472faba4e3e6c1/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/_decomp/decompositions.py", line 1692, in native_layer_norm_backward
      if M <= 0 or N <= 0:
```
```
    N = prod(inner_dims)  # type: ignore[arg-type]
    M = prod(outer_dims)  # type: ignore[arg-type]
    if M <= 0 or N <= 0:
        return (
            input.new_zeros(input_shape) if output_mask[0] else None,
            input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
            input.new_zeros(input_shape[axis:]) if output_mask[2] else None,
        )
```
# changes
* use guard_size_oblivious since the new_zeros return is kind of optimization, shouldn't impact the correctness of the follow up code logic.
* the size `ret[i][j]` could be zero, so the change in V1 isn't valid
* for more details: [post](https://fb.workplace.com/groups/6829516587176185/permalink/8003616173099548/)
```
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
    if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):
```

# past
* found `u22` was introduced at
```
    def _wait_impl(self) -> List[List[int]]:
        # Can not use is_torchdynamo_compiling(), as every such condition should be independent for compilation with graph breaks.
        if isinstance(self._splits_awaitable, dist.Work):
            self._splits_awaitable.wait()

        ret = self._output_tensor.view(self.num_workers, -1).T.tolist()  # <------ u22 introduced here

        if not torch.jit.is_scripting() and is_torchdynamo_compiling():
            for i in range(len(ret)):
                for j in range(len(ret[i])):
                    torch._check_is_size(ret[i][j])   # <----------  my question: why the _check_is_size isn't enough??
                    torch._check(ret[i][j] > 0)   # <------ added by diff V1
```

Test Plan:
# run command
```
TORCH_SHOW_CPP_STACKTRACES=1 TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 TORCH_LOGS="+graph_code,output_code,dynamic,aot,guards,verbose_guards,recompiles,graph_breaks" TORCH_TRACE=/var/tmp/tt buck2 run fbcode//mode/opt fbcode//aps_models/ads/icvr:icvr_launcher_live -- mode=fmc/local_ig_fm_v4_mini training.pipeline_type=pt2 2>&1 | tee -a `tagT`.`tagH`.log
```

# results
* before
**without enabling `_get_user_embeddings`**
[14 Failures and Restarts](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmp2eNI7p/failures_and_restarts.html)
log: P1610151892
{F1889387940}
* V1
enable `_get_user_embeddings`
with `torch._check(ret[i][j] > 0)`
[13 Failures and Restarts](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmp6J1iY9/failures_and_restarts.html)
{F1889388378}
* V2
enable `_get_user_embeddings`
with `if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):`
[tlparse](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpFhZZyC/index.html)
if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):

Differential Revision: D63424929

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136798
Approved by: https://github.com/ezyang
2024-10-09 17:19:45 +00:00
niklasz
3f457ee1f6 Fix AOT Graph capture not propagating non_blocking copy parameter to … (#136513)
…inductor codegen.

Fixes #136260

**Note**: this is my first code contribution to torch so please let me know if there's anything I need to fix/some other convention I should follow.

Regarding the bug, re-running the issue's reproduction code:
```
import torch

def fn(x):
    return x.to(device="cuda", non_blocking=True)

inp = torch.randn(3, 4)

torch.compile(fn)(inp)
```

We now have the non_blocking being passed on to codegen properly:

```
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code] TRACED GRAPH
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]  ===== pre insert_deferred_runtime_asserts __compiled_fn_1 =====
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]  <eval_with_key>.0 class GraphModule(torch.nn.Module):
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]     def forward(self, L_x_: "f32[3, 4]"):
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]         l_x_ = L_x_
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]          # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]         to: "f32[3, 4]" = l_x_.to(device = 'cuda', non_blocking = True);  l_x_ = None
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]         return (to,)
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]
V0922 20:33:25.393000 679839 torch/fx/passes/runtime_assert.py:114] [0/0] [__graph_code]
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code] TRACED GRAPH
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]  ===== __compiled_fn_1 =====
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]  /home/niklasz/Desktop/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]     def forward(self, L_x_: "f32[3, 4][4, 1]cpu"):
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]         l_x_ = L_x_
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]          # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]         to: "f32[3, 4][4, 1]cuda:0" = l_x_.to(device = 'cuda', non_blocking = True);  l_x_ = None
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]         return (to,)
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]
V0922 20:33:25.394000 679839 torch/_dynamo/output_graph.py:1340] [0/0] [__graph_code]
V0922 20:33:25.404000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:114] [0/0] [__aot_graphs] aot_config id: 0, fw_metadata=ViewAndMutationMeta(input_info=[InputAliasInfo(is_leaf=True, mutates_data=False, mutates_metadata=False, mutations_hidden_from_autograd=True, mutations_under_no_grad_or_inference_mode=False, mutation_inductor_storage_resize=False, mutates_storage_metadata=False, requires_grad=False, keep_input_mutations=True)], output_info=[OutputAliasInfo(output_type=<OutputType.non_alias: 1>, raw_type=<class 'torch._subclasses.functional_tensor.FunctionalTensor'>, base_idx=None, dynamic_dims=set(), requires_grad=False, functional_tensor=None)], num_intermediate_bases=0, keep_input_mutations=True, traced_tangents=[], subclass_inp_meta=[0], subclass_fw_graph_out_meta=[0], subclass_tangent_meta=[], is_train=False, traced_tangent_metas=None, num_symints_saved_for_bw=None, grad_enabled_mutation=None, deterministic=None, static_input_indices=[], tokens={}, indices_of_inputs_that_requires_grad_with_mutations_in_bw=[], bw_donated_idxs=None, num_backward_tokens=0),subclass_metadata=None
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs] TRACED GRAPH
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]  ===== Forward graph 0 =====
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]  /home/niklasz/Desktop/pytorch/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]     def forward(self, arg0_1: "f32[3, 4][4, 1]cpu"):
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]          # File: /home/niklasz/Desktop/pytorch/temp/reproduction.py:4 in fn, code: return x.to(device="cuda", non_blocking=True)
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]         device_put: "f32[3, 4][4, 1]cuda:0" = torch.ops.prims.device_put.default(arg0_1, device(type='cuda', index=0), True);  arg0_1 = None
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]         convert_element_type: "f32[3, 4][4, 1]cuda:0" = torch.ops.prims.convert_element_type.default(device_put, torch.float32);  device_put = None
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]         return (convert_element_type,)
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]
I0922 20:33:25.409000 679839 torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py:204] [0/0] [__aot_graphs]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1134] [0/0] [__output_code] Output code written to: /tmp/torchinductor_niklasz/ha/chaai264g6ribfw3q2qhl6ayjtaqaavku5wivxtzw4nabgd6htsv.py
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] Output code:
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] # AOT ID: ['0_inference']
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from ctypes import c_void_p, c_long, c_int
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import torch
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import math
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import random
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import os
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] import tempfile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from math import inf, nan
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.hooks import run_intermediate_hooks
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.utils import maybe_profile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.codegen.memory_planning import _align as align
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch import device, empty_strided
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.async_compile import AsyncCompile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.select_algorithm import extern_kernels
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] from torch._inductor.codegen.multi_kernel import MultiKernelCall
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] aten = torch.ops.aten
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] inductor_ops = torch.ops.inductor
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] _quantized = torch.ops._quantized
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] alloc_from_pool = torch.ops.inductor._alloc_from_pool
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] async_compile = AsyncCompile()
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] async_compile.wait(globals())
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] del async_compile
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] def call(args):
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     arg0_1, = args
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     args.clear()
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     assert_size_stride(arg0_1, (3, 4), (4, 1))
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     with torch.cuda._DeviceGuard(0):
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]         torch.cuda.set_device(0)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]         buf0 = empty_strided_cuda((3, 4), (4, 1), torch.float32)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]         buf0.copy_(arg0_1, True)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]         del arg0_1
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     return (buf0, )
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] def benchmark_compiled_module(times=10, repeat=10):
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     from torch._dynamo.testing import rand_strided
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     from torch._inductor.utils import print_performance
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     arg0_1 = rand_strided((3, 4), (4, 1), device='cpu', dtype=torch.float32)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     fn = lambda: call([arg0_1])
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     return print_performance(fn, times=times, repeat=repeat)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code] if __name__ == "__main__":
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     from torch._inductor.wrapper_benchmark import compiled_module_main
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]     compiled_module_main('None', benchmark_compiled_module)
V0922 20:33:25.983000 679839 torch/_inductor/codecache.py:1135] [0/0] [__output_code]
```
See above line `buf0.copy_(arg0_1, True)`. Specific log setting used: `export TORCH_LOGS="graph_code,aot_graphs,output_code"`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136513
Approved by: https://github.com/eellison
2024-10-01 00:32:47 +00:00
IvanKobzarev
370c1c4297 [aotd] Fix rrelu compilation (#136008)
Issues:
https://github.com/pytorch/pytorch/issues/135083
https://github.com/pytorch/pytorch/issues/120292

rrelu decomposition contains mutation, copy_. Decompositions are executed below Functionalization, as a result AOT produces non-functional graph.

Also that decomposition is registered as python_dispatch kernel for AutogradCUDA.
Autograd dispatch happens above Functionalization, so registering it for Autograd to handle all backends makes functionalization running after this.

Testing:
```
python test/functorch/test_aotdispatch.py -k test_rrelu
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136008
Approved by: https://github.com/bdhirsh
2024-09-25 11:26:19 +00:00
Isuru Fernando
f276da7f98 Remove prims.slice_in_dim and prims.slice (#136150)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136150
Approved by: https://github.com/ezyang
2024-09-23 01:27:22 +00:00
Isuru Fernando
0c936c3ecb Add decomps for max_unpool (#133146)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133146
Approved by: https://github.com/amjames, https://github.com/eellison
2024-09-20 21:35:25 +00:00
Jan Wieczorek
908a5689eb Return unsafe_view instead of view from matmul when folding occurs (#134568)
When tensor folding occurs during matmul operation returned tensor is a view. This can cause issues when matmul is used inside a custom function and such view is then returned as output. Then it cannot be modified inplace and causes errors.
It can be especially problematic when after such function inplace allreduce is performed.
Issue is resolved when unsafe_view is returned from matmul instead. This solution aligns matmul decomposition with eager implementation in such a way that a non view tensor is returned.

Test included in this PR reproduces the issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134568
Approved by: https://github.com/zou3519
2024-09-19 11:52:16 +00:00
Isuru Fernando
dab7d646d5 Use a better decomposition for split_with_sizes (#135728)
This decomposition has less checks and improves the performance
of torch.compile.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135728
Approved by: https://github.com/ezyang
2024-09-12 16:38:51 +00:00
Sidney Tsang
5d964a5eb7 [Export] Fix SDPA decomposition (#135297)
Summary: Update SDPA decomposition to match updated stride from D62009189 which aligns strides with the `aten._scaled_dot_product_attention_math.default`, which makes `t.permute().continuous().permute()` no longer necessary.

Test Plan: CI

Differential Revision: D62278378

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135297
Approved by: https://github.com/drisspg
2024-09-11 20:21:59 +00:00
Bob Ren
ea89f01281 Remove unused comment (#135034)
As part of my rampup I've been reading through some of @ezyang's diffs. I noticed in https://github.com/pytorch/pytorch/pull/133439 there was a comment that he forgot to remove. This diff removes that comment.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135034
Approved by: https://github.com/albanD
2024-09-04 02:32:26 +00:00
Edward Z. Yang
bdfc1d3987 Remove unnecessary expect_true in split_with_sizes (#133439)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133439
Approved by: https://github.com/albanD
2024-08-27 01:34:00 +00:00
Jack Zhang
773a782249 Decompose _unsafe_index_put into index_put (#133365)
## Description
Create decomposition of _unsafe_index_put (non-core aten) that turns it into index_put (core aten)

## Testing
Phi3 mini + LoRA model successfully passed `to_edge` after failing due to a non-core aten `unsafe_index_put` getting introduced in a decomposition during joint graph calculations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133365
Approved by: https://github.com/pianpwk
2024-08-19 18:07:23 +00:00
Huanyu He
d5f6d68d68 [PT2] Resolve PT2 compatility issue in slice and diff (#133740)
Summary:
# context
* when running an IG FM training with PT2 we found there are a few graph break due to torch.diff call in [jagged_tensor.py](https://fburl.com/code/cwssxabc)
```
_length: List[int] = (
    _length_per_key_from_stride_per_key(torch.diff(offsets), stride_per_key)
    if variable_stride_per_key
    else torch.sum(torch.diff(offsets).view(-1, stride), dim=1).tolist()
)
```
* look into the failure, we found the TORCH_CHECK in diff should be TORCH_SYM_CHECK
* slice_forward error: df3d7729e, [tlparse](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpxXZ2em/index.html)
```
RestartAnalysis
Tried to use data-dependent value in the subsequent computation. This can happen when we encounter unbounded dynamic value that is unknown during tracing time.  You will need to explicitly give hint to the compiler. Please take a look at torch._check OR torch._check_is_size APIs.  Could not guard on data-dependent expression ((5*u37 + u38)//(u37 + u38)) < 0 (unhinted: ((5*u37 + u38)//(u37 + u38)) < 0).  (Size-like symbols: u38, u37)

ATTENTION: guard_size_oblivious would fix the error, evaluating expression to False.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.

Potential framework code culprit (scroll up for full backtrace):
  File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/_decomp/decompositions.py", line 771, in slice_forward
    if end_val < 0:
```
* after this diff: [tlparse](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpAhv2Sh/failures_and_restarts.html)

Test Plan:
# command
* run model
```
TORCH_SHOW_CPP_STACKTRACES=1 TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 TORCH_LOGS="+graph_code,output_code,dynamic,aot,guards,verbose_guards,recompiles,graph_breaks" TORCH_TRACE=/var/tmp/tt buck2 run fbcode//mode/opt fbcode//aps_models/ads/icvr:icvr_launcher_live -- mode=fmc/local_ig_fm_v4_mini training.pipeline_type=pt2
```
* generate tlparse
```
tlparse `ls -t /var/tmp/tt/* | head -1`
```

Reviewed By: ezyang

Differential Revision: D56339251

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133740
Approved by: https://github.com/ezyang
2024-08-17 06:07:21 +00:00
drisspg
1434e0b121 Add a private _safe_softmax (#131060)
# Summary
Changes the stance of SDPA on what to do for fully masked out rows

## Current Behavior
Several PyTorch users have expressed frustration over this issue:
- https://github.com/pytorch/pytorch/issues/41508
- https://github.com/pytorch/pytorch/issues/103749
- https://github.com/pytorch/pytorch/issues/103963

These are significant issues with extensive discussion but no satisfactory resolution. The PyTorch team's consensus, as stated here:
https://github.com/pytorch/pytorch/issues/24816#issuecomment-524415617

Can be paraphrased as follows:

When passing in fully masked out rows, attention becomes ambiguous. We have two main options:

1. Uniformly attend to all values:
   ```python
   scores[masked_out_rows] = 1 / len(row)
   out[masked_out_rows] = 1 / len(row) * value
   ```

2. Decide that attention between no queries (masked) and no keys (masked) is meaningless:
   ```python
   output[fully_masked_rows] = NaN
   ```

We went with option 2. Partially because it was easier to implement, but also people argued that users can slice the output to remove the NaNs:
``` Python
>fill_value = -float("inf")
>row0 = torch.randn(4)
>row1 = torch.tensor([(fill_value for _ in range(4)])
>matrix = torch.stack([row0, row1]).requires_grad_(True)
>out = torch.softmax(matrix, 1)
>out = out[0]
>print(out)
tensor([0.5377, 0.2729, 0.0692, 0.1201])
```
Cool, problem solved. But what happends when you call backwards..
```Python
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[3.0957e-08, 1.4157e-08, 7.7802e-10, 1.3713e-08],
        [       nan,        nan,        nan,        nan]])
```
Those pesky NaNs are back!

## Why do we see NaNs today?

The core of the problem revolves around using softmax function in sdpa:

```python
> row = torch.tensor([(-float("inf")) for _ in range(4)])
> torch.softmax(row, 0)
tensor([nan, nan, nan, nan])
```

## Quick Aside: Masking in Attention

Attention itself doesn't have a concept of masking. The `sdpa` function has an argument called `attn_mask`, which would be more accurately named `attn_bias`. This is because we don't actually "mask" entries when computing attention. Instead, due to implementation details([performance](https://github.com/pytorch/pytorch/issues/25110#issuecomment-524519087)), we add a value to the masked-out query/key pairs.

We use a large negative number (typically -inf) to decrease the attention weight, as softmax assigns more weight to larger values.

## Alternative Approaches

If we use a very large negative number instead of -inf:

```python
> row = torch.tensor([(-1e6) for _ in range(4)])
> torch.softmax(row, 0)
tensor([0.2500, 0.2500, 0.2500, 0.2500])
```
However if users always remembered to "slice" out their outputs i.e.:
```Python
>fill_value = -1e6
>...
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[-0.0563, -0.0564,  0.1613, -0.0486],
        [ 0.0000,  0.0000,  0.0000,  0.0000]])
```
This would bring us back into a better state.

## A Third Option

We don't necessarily need to alter the behavior of softmax for -inf or very large negative numbers. The fundamental goal is to exclude certain query/key pairs from attention, regardless of the underlying implementation.

This PR implements the new semantic for masking w/ attention in fully masked-out rows:
```python
out[masked_out_rows] = 0
```

**Important Note**: This idea isn't entirely new. The [MaskedTensor](https://pytorch.org/tutorials/prototype/maskedtensor_overview#safe-softmax) prototype, a tensor subclass, was designed to handle such cases. However, it remains a prototype feature and hasn't gained widespread adoption.

## Details
This PR stack does 3 things:
1. Adds a PRIVATE _safe_softmax op
2. Updates semantic for flash_cpu fused kernel
3. Updates semantic for efficient_cuda fused kernel

_safe_softmax is not supposed to be used generically and is only meant to be used within the context of SDPA. Due to this fact instead of decomposing softmax and checking for -inf rows we instead "cheat" and use nan_to_num.

Why I think this is okay? (please find a counter point if avail)
There are multiple ways NaNs can emerge. For the fully masked out rows case nan_to_num works. But what if there were other NaNs, wouldn't this silently remove them?

The only case that this can happen is if the input itself had a NaN or an Inf
For example:
```Python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = torch.finfo(torch.float16).max
print(a.softmax(-1))
```
Will return
`tensor([0., 1., 0., 0.], dtype=torch.float16)`

Where
```Python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = float("inf")
a.softmax(-1)
```
returns:
`tensor([nan, nan, nan, nan], dtype=torch.float16)`

If we dont want to even allow for the possibility of "inf" or "NaN" attention scores to be converted to 0 then we can implemented it something like this

```Python
max = torch.max(a, dim=-1, keepdim=True)
exp = torch.exp(a - max.values)
denom = torch.sum(exp, dim=-1, keepdim=True)
softmax = exp / denom
softmax = torch.where(max.values == float('-inf'), 0.0, softmax)
```
however we would be paying for this in math performance.

## Why Now
I think one point that has substantially changed where PyTorch should lie on this argument is the fact that we have fused implementations for SDPA now. And these fused implementations allow us to easily and performantly support this new semantic.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131060
Approved by: https://github.com/jbschlosser
2024-08-08 23:09:38 +00:00
Aart Bik
2f908ffa4a [traced-graph][sparse] sparsity propagation for all current tests (#132690)
This PR makes sure all current tests in the sparsity export test suite pass. Note that there will probably be anecdotal cases that need fixing after this, but the general idea of preserving sparsity metadata has been completed.

Fixes: https://github.com/pytorch/pytorch/issues/117188

```
$ PYTORCH_TEST_WITH_DYNAMO=0 python test/export/test_sparse.py ........................................................................................................................................................
 ----------------------------------------------------------------------
Ran 152 tests
OK
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132690
Approved by: https://github.com/ezyang
2024-08-06 21:18:13 +00:00
Xuehai Pan
e74ba1b34a [BE][Easy][15/19] enforce style for empty lines in import segments in torch/_d*/ (#129767)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129767
Approved by: https://github.com/anijain2305
2024-07-31 21:18:11 +00:00
Aaron Orenstein
5a0068cc69 [BE] mypy: disallow untyped decorators (#131428)
Untyped decorators strip the types from their decorated function so even if the underlying function is fully typed then callers to it don't get any benefit from type annotations.

Step 1 - Enable the error and override in all the offending files.

#131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428
Approved by: https://github.com/justinchuby, https://github.com/oulgen
2024-07-23 21:50:55 +00:00
Isuru Fernando
43a6d20883 Add decomposition for reflection_pad{1,2,3}d_backward (#130299)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130299
Approved by: https://github.com/lezcano
ghstack dependencies: #130130
2024-07-17 21:56:00 +00:00
rzou
b38de2f9e2 [decomps] Fix aten._to_copy decomp (#130381)
`aten._to_copy` can receive a python number as input. This occurs in
torch.compile support for vmap (see #130188). Previously, this would
raise an assertion error. This PR changes it so that if we see a python
number, we call torch.scalar_tensor on it first (h/t @bdhirsh).

Fixes #130362

Fixes #130188

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130381
Approved by: https://github.com/Chillee
2024-07-10 14:34:28 +00:00
Isuru Fernando
c12a4f2e65 Add decomposition for slice_scatter (#123744)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123744
Approved by: https://github.com/peterbell10
2024-06-28 17:02:10 +00:00
Isuru Fernando
e6bfa2958b Add aten._unsafe_masked_index (#116491)
To generate masked indexing operations that would generate
masked loads in triton code

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116491
Approved by: https://github.com/lezcano, https://github.com/peterbell10
2024-06-25 02:45:02 +00:00
Nikita Shulga
e47603a549 Fix weight_norm decomposition behavior (#128956)
By upcasting norm to float32 to align with CUDA and CPU behaviors
e6d4451ae8/aten/src/ATen/native/WeightNorm.cpp (L56-L59)

Discovered this when started running OpInfo tests, see https://github.com/pytorch/pytorch/actions/runs/9552858711/job/26332062502#step:20:1060
```
  File "/var/lib/jenkins/workspace/test/test_decomp.py", line 185, in op_assert_ref
    assert orig.dtype == decomp.dtype, f"{i} Operation:  {op}"
AssertionError: 1 Operation:  aten._weight_norm_interface.default
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128956
Approved by: https://github.com/albanD
ghstack dependencies: #128955
2024-06-18 21:24:12 +00:00
Nikita Shulga
44483972bd [EZ] Keep weight_norm var name aligned (#128955)
To keep it aligned with
e6d4451ae8/aten/src/ATen/native/native_functions.yaml (L6484)
I.e.  `x`->`v`, `y`->`g`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128955
Approved by: https://github.com/albanD, https://github.com/Skylion007
2024-06-18 18:40:59 +00:00
Edward Z. Yang
2229884102 Introduce int_oo (#127693)
In a previous life, we used sympy.oo to represent the lower/upper bounds of integer ranges. Later, we changed this to be sys.maxsize - 1 for a few reasons: (1) sometimes we do tests on a value being exactly sys.maxsize, and we wanted to avoid a data dependent guard in this case, (2) sympy.oo corresponds to floating point infinity, so you get incorrect types for value ranges with oo, and (3) you can do slightly better reasoning if you assume that input sizes fall within representable 64-bit integer range.

After working in the sys.maxsize regime for a bit, I've concluded that this was actually a bad idea. Specifically, the problem is that you end up with sys.maxsize in your upper bound, and then whenever you do any sort of size-increasing computation like size * 2, you end up with 2 * sys.maxsize, and you end up doing a ton of arbitrary precision int computation that is totally unnecessary. A symbolic bound is better.

But especially after #126905, we can't go back to using sympy.oo, because that advertises that it's not an integer, and now your ValueRanges is typed incorrectly. So what do we do? We define a new numeric constant `int_oo`, which is like `sympy.oo` but it advertises `is_integer`. **test/test_sympy_utils.py** describes some basic properties of the number, and **torch/utils/_sympy/numbers.py** has the actual implementation.

The rest of the changes of the PR are working out the implications of this change. I'll give more commentary as inline comments.

Fixes https://github.com/pytorch/pytorch/issues/127396

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127693
Approved by: https://github.com/lezcano
ghstack dependencies: #126905
2024-06-13 04:08:20 +00:00
PyTorch MergeBot
5d8c7f39d4 Revert "Introduce int_oo (#127693)"
This reverts commit 9cab5987bd.

Reverted https://github.com/pytorch/pytorch/pull/127693 on behalf of https://github.com/clee2000 due to sorry executorch CI is a bit weird regarding pins, I'll make a chat with mergen with the choices of what to do and how it'll affect executorch CI, reverting for now to prevent more divergences in the meantime ([comment](https://github.com/pytorch/pytorch/pull/127693#issuecomment-2161775400))
2024-06-11 23:36:08 +00:00
Edward Z. Yang
9cab5987bd Introduce int_oo (#127693)
In a previous life, we used sympy.oo to represent the lower/upper bounds of integer ranges. Later, we changed this to be sys.maxsize - 1 for a few reasons: (1) sometimes we do tests on a value being exactly sys.maxsize, and we wanted to avoid a data dependent guard in this case, (2) sympy.oo corresponds to floating point infinity, so you get incorrect types for value ranges with oo, and (3) you can do slightly better reasoning if you assume that input sizes fall within representable 64-bit integer range.

After working in the sys.maxsize regime for a bit, I've concluded that this was actually a bad idea. Specifically, the problem is that you end up with sys.maxsize in your upper bound, and then whenever you do any sort of size-increasing computation like size * 2, you end up with 2 * sys.maxsize, and you end up doing a ton of arbitrary precision int computation that is totally unnecessary. A symbolic bound is better.

But especially after #126905, we can't go back to using sympy.oo, because that advertises that it's not an integer, and now your ValueRanges is typed incorrectly. So what do we do? We define a new numeric constant `int_oo`, which is like `sympy.oo` but it advertises `is_integer`. **test/test_sympy_utils.py** describes some basic properties of the number, and **torch/utils/_sympy/numbers.py** has the actual implementation.

The rest of the changes of the PR are working out the implications of this change. I'll give more commentary as inline comments.

Fixes https://github.com/pytorch/pytorch/issues/127396

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127693
Approved by: https://github.com/lezcano
ghstack dependencies: #126905
2024-06-10 19:09:53 +00:00
Aaron Orenstein
dcfa7702c3 Flip default value for mypy disallow_untyped_defs [1/11] (#127838)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127838
Approved by: https://github.com/oulgen
2024-06-08 18:16:33 +00:00
Aaron Gokaslan
12c4a2c297 [BE]: Apply PLR1736 fixes (unnecessary index lookup) (#127716)
Applies the PLR1736 preview rule with some more autofixes to cut down on unnecessary accesses. Added a noqa since that test actually testing the dunder method.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127716
Approved by: https://github.com/ezyang
2024-06-03 17:22:13 +00:00
PyTorch MergeBot
d1fad416a8 Revert "Add aten._unsafe_masked_index (#116491)"
This reverts commit f03f8bc901.

Reverted https://github.com/pytorch/pytorch/pull/116491 on behalf of https://github.com/PaliC due to breaking onnx tests ([comment](https://github.com/pytorch/pytorch/pull/116491#issuecomment-2145557724))
2024-06-03 15:51:50 +00:00
Isuru Fernando
f03f8bc901 Add aten._unsafe_masked_index (#116491)
To generate masked indexing operations that would generate
masked loads in triton code

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116491
Approved by: https://github.com/lezcano, https://github.com/peterbell10
2024-06-03 14:44:03 +00:00
Peter Bell
39de62845a [decomp] Fix default values missing from inplace rrelu decomposition (#126978)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126978
Approved by: https://github.com/lezcano
2024-05-26 23:49:40 +00:00
Andres Lugo-Reyes
38b8b614a2 [ROCm] Implement forward AD for miopen_batch_norm (#125069)
Implements forward automatic differentiation support for miopen_batch_norm as well as unskips the associated unit tests. Also fixes a class of functorch related unit tests that fail due to failing a contiguous tensor assertion in BatchNorm_miopen.cpp. Solution was to just limit tensors to miopen_batch_norm that have at least 3 dimensions. The exact restriction already existed in the cudnn path and is why the tests in question only failed on ROCm.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125069
Approved by: https://github.com/jeffdaily, https://github.com/andrewor14
2024-05-14 19:09:50 +00:00
Edward Z. Yang
4731130ea8 Add a code comment about torch._check_is_size in tensor_split (#125292)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125292
Approved by: https://github.com/albanD
2024-05-02 02:25:38 +00:00
Aaron Orenstein
a8574a9719 Fix global flake8 issues (#124771)
Prior to this `lintrunner --all-files --take FLAKE8` failed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124771
Approved by: https://github.com/Skylion007
ghstack dependencies: #124428
2024-04-26 15:35:53 +00:00
PyTorch MergeBot
1ac60484c1 Revert "Fix global flake8 issues (#124771)"
This reverts commit f01275934b.

Reverted https://github.com/pytorch/pytorch/pull/124771 on behalf of https://github.com/jeanschmidt due to Unfortunately, I needed to revert #123735 and this one depends on it. So please check if there are no merge conflicts or breakages and feel free to merge this PR again ([comment](https://github.com/pytorch/pytorch/pull/124428#issuecomment-2078699836))
2024-04-26 06:15:17 +00:00
Aaron Orenstein
f01275934b Fix global flake8 issues (#124771)
Prior to this `lintrunner --all-files --take FLAKE8` failed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124771
Approved by: https://github.com/Skylion007
ghstack dependencies: #124428
2024-04-25 14:25:00 +00:00
Peter Bell
58806d6531 [decomp] Remove dead device_hint function (#124849)
The only use of this function is in `_to_copy` but the result is never used,
so this is just dead code.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124849
Approved by: https://github.com/lezcano
2024-04-25 11:25:51 +00:00
Isuru Fernando
edcd968b51 Add out wrappers to some decompositions (#115437)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115437
Approved by: https://github.com/lezcano
2024-04-23 06:26:11 +00:00
vfdev-5
6330acae76 Refactored implementation for upsample_nearest decompostions (#122783)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122783
Approved by: https://github.com/peterbell10
2024-04-17 23:05:40 +00:00
Edward Z. Yang
60d7fbe89a Register matmul out variant so it is used (#122979)
Fixes https://github.com/pytorch/pytorch/issues/122774

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122979
Approved by: https://github.com/Chillee, https://github.com/Skylion007
2024-04-09 22:21:37 +00:00
Andrew M. James
bde1a93bc4 Add lowering for resize, decomp for resize_as. (#122317)
This has been split off from #121354 as the inplace version of these
methods prove to be rather tricky.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122317
Approved by: https://github.com/peterbell10, https://github.com/lezcano
2024-04-03 17:47:29 +00:00
vfdev-5
38946bff51 Added DispatchKey.CompositeImplicitAutograd to all upsample_nearest*.default decompositions (#122782)
Related to https://github.com/pytorch/pytorch/pull/117632#issuecomment-2021321172
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122782
Approved by: https://github.com/ezyang
2024-03-29 13:55:25 +00:00
vfdev-5
b524a404e0 Fixed support for uint8 in upsample bicubic2d decomposition (#120411)
Superseeds https://github.com/pytorch/pytorch/pull/104248

Description:
- Fixed support for uint8 for upsample bicubic2d decomposition (on `main` results are wrong, so we can tolerate the slowdown)
- Added missing clamp(0, 1) for xscale and yscale
  - slowdown for f32 on cpu. PR on nodes fusion on CPU: https://github.com/pytorch/pytorch/pull/120077 can help for upsampling cases with align corners = true
  - the slowdown mainly due to the added clamp op and also partially reduced when using torch.stack in weights computation on cpu.
- Removed lowering implementation

Benchmarks:
```
[-------------------------------------------------------------------------------------------------------------------------------------------------------- Interpolate, cpu --------------------------------------------------------------------------------------------------------------------------------------------------------]
                                                                                                                                                   |  Eager (2.4.0a0+git0c61c20) PR  |  Compiled (2.4.0a0+git0c61c20) PR  |  Compiled (2.4.0a0+git069270d) Nightly  |  speed-up PR vs Nightly  |  Eager (2.4.0a0+git069270d) Nightly
1 threads: -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      Input (1, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)       |        613.029 (+-1.590)        |         5477.608 (+-9.027)         |           3060.314 (+-12.368)           |     0.559 (+-0.000)      |          608.735 (+-6.336)
      Input (1, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)      |        610.176 (+-1.428)        |        5718.503 (+-11.203)         |           3424.022 (+-12.836)           |     0.599 (+-0.000)      |          604.781 (+-6.229)
      Input (1, 3, 500, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)           |        325.001 (+-0.840)        |        6183.029 (+-10.893)         |            3275.032 (+-7.625)           |     0.530 (+-0.000)      |          325.693 (+-1.067)
      Input (1, 3, 500, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)          |        325.855 (+-1.108)        |        6391.394 (+-11.552)         |            3533.410 (+-7.666)           |     0.553 (+-0.000)      |          325.838 (+-1.457)
      Input (1, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)     |       2521.533 (+-14.857)       |        5025.217 (+-13.415)         |            2814.304 (+-6.742)           |     0.560 (+-0.000)      |         2520.308 (+-10.796)
      Input (1, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)    |       2531.204 (+-12.534)       |        5294.925 (+-11.994)         |            3147.590 (+-6.808)           |     0.594 (+-0.000)      |         2521.228 (+-11.732)
      Input (1, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)         |        758.352 (+-10.362)       |        5639.912 (+-14.495)         |            3014.123 (+-8.799)           |     0.534 (+-0.000)      |          756.114 (+-4.792)
      Input (1, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)        |        758.712 (+-5.781)        |         5927.541 (+-9.982)         |            3249.555 (+-7.226)           |     0.548 (+-0.000)      |          757.719 (+-5.653)
      Input (1, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)       |       1524.469 (+-12.860)       |        34321.641 (+-80.310)        |           19373.714 (+-56.351)          |     0.564 (+-0.000)      |         1518.082 (+-49.653)
      Input (1, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)      |       1521.746 (+-13.780)       |        35949.711 (+-81.010)        |           21782.366 (+-68.938)          |     0.606 (+-0.000)      |         1467.911 (+-15.901)
      Input (1, 3, 300, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)           |        712.311 (+-5.361)        |        38826.510 (+-92.267)        |           20762.314 (+-59.303)          |     0.535 (+-0.000)      |          712.669 (+-4.673)
      Input (1, 3, 300, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)          |        715.060 (+-4.757)        |        40269.353 (+-92.543)        |           22402.114 (+-81.574)          |     0.556 (+-0.000)      |          716.001 (+-8.945)

      Input (4, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)       |       2331.889 (+-29.159)       |        21541.096 (+-72.346)        |           12181.194 (+-45.288)          |     0.565 (+-0.000)      |         2304.864 (+-21.351)
      Input (4, 3, 500, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)      |       2333.697 (+-10.066)       |        22514.154 (+-57.798)        |           21709.449 (+-98.307)          |     0.964 (+-0.000)      |         2302.141 (+-13.041)
      Input (4, 3, 500, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)           |        1198.768 (+-5.364)       |       37652.371 (+-101.644)        |           42740.413 (+-98.571)          |     1.135 (+-0.000)      |          1197.104 (+-7.225)
      Input (4, 3, 500, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)          |        1196.851 (+-5.118)       |       39678.341 (+-173.750)        |           46807.738 (+-92.744)          |     1.180 (+-0.000)      |          1189.322 (+-5.681)
      Input (4, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)     |       10020.978 (+-54.855)      |        19955.290 (+-71.891)        |           11420.521 (+-53.179)          |     0.572 (+-0.000)      |         9999.583 (+-61.230)
      Input (4, 3, 1200, 1300), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)    |       10066.441 (+-62.700)      |       21058.334 (+-183.414)        |           19986.577 (+-65.304)          |     0.949 (+-0.000)      |         10018.672 (+-59.188)
      Input (4, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)         |       3171.135 (+-14.635)       |        19687.864 (+-54.320)        |           23313.699 (+-57.391)          |     1.184 (+-0.000)      |         3182.191 (+-17.686)
      Input (4, 3, 1200, 1300), torch.uint8, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)        |       3181.314 (+-13.784)       |        20224.468 (+-50.827)        |          30541.963 (+-381.385)          |     1.510 (+-0.000)      |         3183.578 (+-16.203)
      Input (4, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)       |       5879.450 (+-31.551)       |       136918.555 (+-480.320)       |          77723.568 (+-331.766)          |     0.568 (+-0.000)      |         5726.061 (+-87.517)
      Input (4, 3, 300, 400), torch.uint8, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)      |       5882.869 (+-30.325)       |       143378.094 (+-513.842)       |         137244.074 (+-4827.730)         |     0.957 (+-0.000)      |         5727.679 (+-22.164)
      Input (4, 3, 300, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)           |       2674.937 (+-45.003)       |      244829.360 (+-1930.579)       |         271283.073 (+-2243.245)         |     1.108 (+-0.000)      |         2676.054 (+-24.632)
      Input (4, 3, 300, 400), torch.uint8, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)          |       2676.217 (+-16.601)       |      248658.668 (+-2904.952)       |         296514.520 (+-2983.281)         |     1.192 (+-0.000)      |         2682.844 (+-19.886)

      Input (1, 3, 500, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)     |        1768.437 (+-6.294)       |        2934.013 (+-28.870)         |            2520.649 (+-6.797)           |     0.859 (+-0.000)      |          1759.292 (+-5.097)
      Input (1, 3, 500, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)    |        1748.660 (+-5.550)       |         3271.104 (+-7.557)         |            2891.306 (+-7.632)           |     0.884 (+-0.000)      |          1746.341 (+-5.845)
      Input (1, 3, 500, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)         |        2813.150 (+-6.656)       |         3258.973 (+-7.543)         |            2766.286 (+-6.473)           |     0.849 (+-0.000)      |          2805.077 (+-7.611)
      Input (1, 3, 500, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)        |        2812.102 (+-8.211)       |         3568.780 (+-9.018)         |            3125.870 (+-7.324)           |     0.876 (+-0.000)      |          2834.178 (+-9.034)
      Input (1, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)   |        1687.975 (+-9.527)       |         2752.085 (+-9.627)         |            2373.274 (+-7.888)           |     0.862 (+-0.000)      |          1698.782 (+-8.098)
      Input (1, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)  |        1696.606 (+-8.678)       |        3056.317 (+-13.303)         |           2699.160 (+-10.638)           |     0.883 (+-0.000)      |         1684.942 (+-10.519)
      Input (1, 3, 1200, 1300), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)       |        2613.491 (+-9.769)       |        3176.493 (+-13.366)         |            2730.193 (+-9.573)           |     0.859 (+-0.000)      |          2625.085 (+-9.943)
      Input (1, 3, 1200, 1300), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)      |       2614.946 (+-34.129)       |        3465.398 (+-11.165)         |           3044.396 (+-11.447)           |     0.879 (+-0.000)      |          2627.355 (+-9.608)
      Input (1, 3, 300, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)     |       10784.549 (+-58.181)      |        18292.452 (+-59.344)        |           15909.922 (+-49.864)          |     0.870 (+-0.000)      |         10837.656 (+-51.947)
      Input (1, 3, 300, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)    |       10786.513 (+-52.308)      |        20449.038 (+-56.204)        |           18295.997 (+-54.522)          |     0.895 (+-0.000)      |         10843.751 (+-44.781)
      Input (1, 3, 300, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)         |       17532.699 (+-64.807)      |        20425.699 (+-80.271)        |           17517.040 (+-79.705)          |     0.858 (+-0.000)      |         17595.597 (+-61.870)
      Input (1, 3, 300, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)        |       17530.816 (+-55.131)      |        22450.080 (+-92.899)        |           19827.828 (+-77.649)          |     0.883 (+-0.000)      |         17615.934 (+-71.716)

      Input (4, 3, 500, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)     |       6875.484 (+-40.543)       |        11569.509 (+-62.462)        |          10053.350 (+-208.136)          |     0.869 (+-0.000)      |         6864.501 (+-46.747)
      Input (4, 3, 500, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)    |       6843.126 (+-44.498)       |        12915.236 (+-60.654)        |          25335.058 (+-382.640)          |     1.962 (+-0.000)      |         6899.002 (+-46.861)
      Input (4, 3, 500, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (256, 256)         |       11103.418 (+-51.318)      |        28834.389 (+-78.395)        |          37405.463 (+-581.646)          |     1.297 (+-0.000)      |         11223.012 (+-60.709)
      Input (4, 3, 500, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (256, 256)        |       11092.994 (+-70.835)      |       36597.023 (+-118.988)        |           45761.267 (+-85.051)          |     1.250 (+-0.000)      |         11104.014 (+-61.288)
      Input (4, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)   |       7106.791 (+-63.666)       |        11191.071 (+-45.402)        |           9786.037 (+-75.781)           |     0.874 (+-0.000)      |         7129.419 (+-77.674)
      Input (4, 3, 1200, 1300), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)  |       7146.519 (+-28.376)       |        12443.571 (+-39.425)        |           20147.067 (+-74.771)          |     1.619 (+-0.000)      |         7179.622 (+-64.847)
      Input (4, 3, 1200, 1300), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (200, 300)       |       10533.849 (+-44.227)      |       34814.909 (+-138.127)        |          42803.001 (+-114.326)          |     1.229 (+-0.000)      |         10644.039 (+-59.681)
      Input (4, 3, 1200, 1300), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (200, 300)      |       10548.910 (+-44.221)      |       42876.940 (+-146.959)        |          49711.443 (+-139.276)          |     1.159 (+-0.000)      |         10652.375 (+-44.174)
      Input (4, 3, 300, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)     |      42814.521 (+-103.198)      |       73100.489 (+-435.262)        |          63587.659 (+-134.266)          |     0.870 (+-0.000)      |        43208.921 (+-195.287)
      Input (4, 3, 300, 400), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)    |      42812.373 (+-103.870)      |       81769.160 (+-373.369)        |         175159.813 (+-2028.558)         |     2.142 (+-0.000)      |         43007.691 (+-96.358)
      Input (4, 3, 300, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (600, 700)         |      69955.505 (+-373.373)      |      215248.616 (+-2040.775)       |         267511.246 (+-2094.161)         |     1.243 (+-0.000)      |        70382.679 (+-594.941)
      Input (4, 3, 300, 400), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (600, 700)        |      69852.157 (+-490.076)      |      242841.484 (+-19645.513)      |         317931.678 (+-2016.498)         |     1.309 (+-0.000)      |        70074.819 (+-352.919)

Times are in microseconds (us).

[-------------------------------------------------------------------------------------------------------------------------------------------------------- Interpolate, cuda ---------------------------------------------------------------------------------------------------------------------------------------------------------]
                                                                                                                                                     |  Eager (2.4.0a0+git0c61c20) PR  |  Compiled (2.4.0a0+git0c61c20) PR  |  Compiled (2.4.0a0+git069270d) Nightly  |  speed-up PR vs Nightly  |  Eager (2.4.0a0+git069270d) Nightly
1 threads: ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
      Input (1, 3, 2345, 2456), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (1234, 1345)   |         97.727 (+-0.018)        |          97.765 (+-0.025)          |             97.773 (+-0.027)            |     1.000 (+-0.000)      |           97.905 (+-0.040)
      Input (1, 3, 2345, 2456), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (1234, 1345)  |         97.615 (+-0.066)        |          97.332 (+-0.032)          |             97.950 (+-0.026)            |     1.006 (+-0.000)      |           97.690 (+-0.062)
      Input (1, 3, 2345, 2456), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (1234, 1345)       |        100.635 (+-0.033)        |         125.883 (+-0.020)          |            102.499 (+-0.116)            |     0.814 (+-0.000)      |          101.103 (+-0.027)
      Input (1, 3, 2345, 2456), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (1234, 1345)      |        100.898 (+-0.036)        |         109.717 (+-0.336)          |            102.558 (+-0.120)            |     0.935 (+-0.000)      |          101.642 (+-0.105)
      Input (4, 3, 2345, 2456), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (1234, 1345)   |        462.853 (+-0.028)        |         382.475 (+-0.047)          |            382.472 (+-0.033)            |     1.000 (+-0.000)      |          462.188 (+-0.014)
      Input (4, 3, 2345, 2456), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (1234, 1345)  |        462.783 (+-0.021)        |         382.806 (+-0.037)          |            382.563 (+-0.043)            |     0.999 (+-0.000)      |          462.089 (+-0.028)
      Input (4, 3, 2345, 2456), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (1234, 1345)       |        466.721 (+-0.022)        |         384.438 (+-0.027)          |            384.886 (+-0.037)            |     1.001 (+-0.000)      |          467.014 (+-0.025)
      Input (4, 3, 2345, 2456), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (1234, 1345)      |        466.993 (+-0.032)        |         384.212 (+-0.009)          |            383.946 (+-0.029)            |     0.999 (+-0.000)      |          466.575 (+-0.020)
      Input (1, 3, 1234, 1345), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (2345, 2456)   |        190.070 (+-0.082)        |         209.353 (+-1.096)          |            202.870 (+-0.888)            |     0.969 (+-0.000)      |          189.371 (+-0.164)
      Input (1, 3, 1234, 1345), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (2345, 2456)  |        190.021 (+-0.018)        |         210.504 (+-0.456)          |            201.814 (+-0.770)            |     0.959 (+-0.000)      |          189.314 (+-0.036)
      Input (1, 3, 1234, 1345), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (2345, 2456)       |        188.860 (+-0.207)        |         336.635 (+-0.023)          |            252.026 (+-0.510)            |     0.749 (+-0.000)      |          188.860 (+-0.170)
      Input (1, 3, 1234, 1345), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (2345, 2456)      |        188.725 (+-0.214)        |         276.329 (+-0.563)          |            251.439 (+-0.524)            |     0.910 (+-0.000)      |          188.776 (+-0.189)
      Input (4, 3, 1234, 1345), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: True, antialias: False, osize: (2345, 2456)   |        781.879 (+-0.086)        |         836.389 (+-7.177)          |            816.483 (+-6.626)            |     0.976 (+-0.000)      |          781.362 (+-0.106)
      Input (4, 3, 1234, 1345), torch.float32, torch.contiguous_format | mode: bicubic, align_corners: False, antialias: False, osize: (2345, 2456)  |        781.824 (+-0.099)        |         840.406 (+-7.111)          |            807.530 (+-6.514)            |     0.961 (+-0.000)      |          781.307 (+-0.129)
      Input (4, 3, 1234, 1345), torch.float32, torch.channels_last | mode: bicubic, align_corners: True, antialias: False, osize: (2345, 2456)       |        769.290 (+-0.309)        |         675.498 (+-1.537)          |            688.171 (+-4.326)            |     1.019 (+-0.000)      |          769.830 (+-0.222)
      Input (4, 3, 1234, 1345), torch.float32, torch.channels_last | mode: bicubic, align_corners: False, antialias: False, osize: (2345, 2456)      |        769.240 (+-0.179)        |         675.800 (+-1.113)          |            673.176 (+-1.740)            |     0.996 (+-0.000)      |          769.935 (+-0.171)

Times are in microseconds (us).

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120411
Approved by: https://github.com/lezcano
2024-03-29 13:15:25 +00:00
andrewor14
773ae817f7 Batch Norm Consolidation (#116092)
**Summary:**

This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:

```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
  aten.native_batch_norm ->
  aten._native_batch_norm_legit (export only) ->
  _batch_norm_legit_cpu/cuda (kernels, export only) ->
  _batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```

Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.

Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:

```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```

The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:

```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```

Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack is planned to
be removed after the 6 month BC window.

Test Plan: `OpInfo` tests for `batch_norm_with_update`.

Reviewers: albanD, bdhirsh

Subscribers: albanD, bdhirsh, supriyar

Tasks: https://github.com/pytorch/pytorch/issues/111384

Differential Revision: [D54805279](https://our.internmc.facebook.com/intern/diff/D54805279)
Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
2024-03-18 21:01:30 +00:00
PyTorch MergeBot
fd0dbcd891 Revert "Batch Norm Consolidation (#116092)"
This reverts commit 7b4f70eda5.

Reverted https://github.com/pytorch/pytorch/pull/116092 on behalf of https://github.com/osalpekar due to Causes build failure in //caffe2:aten-hip (AMD build) target. See [D54707318](https://www.internalfb.com/diff/D54707318) for more details, may require internal build system changes to resolve. ([comment](https://github.com/pytorch/pytorch/pull/116092#issuecomment-1989542965))
2024-03-11 22:22:41 +00:00
BowenBao
8c96b4367a Remove opmath cast for im2col decomp (#121363)
It is unclear why opmath cast is needed for im2col decomp, given that the decomposition is mainly performing padding, slicing, indexing and shape manipulation. There is no need for performing these operations in a higher precision, and in doing so it requires more memory and yields less performance.

Sample script to demonstrate inserted cast before this change

```python
import torch
from torch._decomp.decompositions import im2col

def func(x):
    return torch.nn.functional.unfold(
        x, kernel_size=[3, 1], padding=[2, 0], dilation=1, stride=1
    )

x = torch.rand(1, 1, 5, 5, dtype=torch.float16)

eo = torch._dynamo.export(
    func, aten_graph=True, decomposition_table={torch.ops.aten.im2col.default: im2col}
)(x)
eo.graph_module.print_readable()
```

```
class GraphModule(torch.nn.Module):
    def forward(self, x):
        arg0: "f16[1, 1, s0, s0]";

        arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        arg0_1 = arg0

        _to_copy: "f32[1, 1, s0, s0]" = torch.ops.aten._to_copy.default(arg0_1, dtype = torch.float32)
        ...
        constant_pad_nd: "f32[1, 1, s0 + 4, s0]" = torch.ops.aten.constant_pad_nd.default(_to_copy, [0, 0, 2, 2], 0.0);  _to_copy = None
        ...
        slice_1: "f32[1, 1, s0 + 4, s0]" = torch.ops.aten.slice.Tensor(constant_pad_nd, 0, 0, 9223372036854775807);  constant_pad_nd = None
        slice_2: "f32[1, 1, s0 + 4, s0]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 9223372036854775807);  slice_1 = None
        index: "f32[1, 1, 3, s0 + 2, 1, s0]" = torch.ops.aten.index.Tensor(slice_2, [None, None, unsqueeze_5, add_3]);  slice_2 = unsqueeze_5 = add_3 = None
        permute: "f32[1, 1, 3, 1, s0 + 2, s0]" = torch.ops.aten.permute.default(index, [0, 1, 2, 4, 3, 5]);  index = None
        ...
        view: "f32[1, 3, s0**2 + 2*s0]" = torch.ops.aten.view.default(permute, [1, 3, mul]);  permute = mul = None
        _to_copy_1: "f16[1, 3, s0**2 + 2*s0]" = torch.ops.aten._to_copy.default(view, dtype = torch.float16);  view = None
        return pytree.tree_unflatten([_to_copy_1], self._out_spec)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121363
Approved by: https://github.com/lezcano
2024-03-09 15:37:27 +00:00
Boyuan Feng
35d3adb4b0 Add ATen Op _chunk_cat and _chunk_cat.out (#121081)
# Motivation

In backward of per-parameter sharding FSDP, each rank performs reduce scatter to sync gradients across ranks. A rank chunks each gradient tensor into `world_size` slices along the 0-th dimension and concatenate all slices along the 1-th dimension. Gradient tensors will be padded before concatenation when tensor.size(0) % world_size != 0.

### Example 1
Consider `world_size=3` and tensors A (2x4), B (3x3), C (1x2):

Input tensors:
```
AAAA   BBB   CC
AAAA   BBB
       BBB
```

Reduce-scatter-copy-in Output:
```
AAAABBBCC
AAAABBB00
0000BBB00
```

### Example 2
Consider `world_size=2` and tensors A (2x4), B (3x3), C(1x2), D(4x2):

Input tensors:
```
AAAA   BBB   CC   DD
AAAA   BBB   00   DD
       BBB        DD
       000        DD
```

Reduce-scatter-copy-in first pad:
```
AAAA   BBB   CC   DD
AAAA   BBB   00   DD
       BBB        DD
       000        DD
```

Then chunk and cat along dim as the output:
```
AAAABBBBBBCCDDDD
AAAABBB00000DDDD
```

The performance of reduce-scatter-copy-in is critical to per-parameter sharding FSDP. However, reduce-scatter-copy-in via composing existing ATen ops involves `cat` and irregular `pad`, leading redundant data copies and unsatisfactory performance.

# PR
We provide aten native support for reduce-scatter-copy-in, namely `_chunk_cat()`:

```
_chunk_cat(Tensor[] tensors, int dim, int num_chunks) -> Tensor
```

This PR includes the registration of `_chunk_cat` and `_chunk_cat.out`, OpInfo tests, and basic implementation composing existing ATen ops.
In the next PR, we will add the CUDA implementation. Comparing with baselines of composing existing ATen ops, `_chunk_cat()` CUDA implementation improves copy bandwidth from 498 GB/s to 966 GB/s on a production benchmark.

## Requirements on input

1. If input tensors have different ndims, dim should be non-negative and be less than the ndims of every input tensors. If all input tensors have the same ndims, we support both negative and non-negative dim.
2. For wrapped_dim, all tensors should have the same size for 0,...,wrapped_dim-1 dimensions. No requirements for (wrapped_dim, ...)-th dimension.
3. Expect positive num_chunks
4. Expect non-empty input tensor list and each input tensor should have at least 1 element

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121081
Approved by: https://github.com/albanD
2024-03-08 21:48:12 +00:00
andrewor14
7b4f70eda5 Batch Norm Consolidation (#116092)
**Summary:**

This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:

```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
  aten.native_batch_norm ->
  aten._native_batch_norm_legit (export only) ->
  _batch_norm_legit_cpu/cuda (kernels, export only) ->
  _batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```

Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.

Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:

```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```

The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:

```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```

Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack is planned to
be removed after the 6 month BC window.

Test Plan: `OpInfo` tests for `batch_norm_with_update`.

Reviewers: albanD, bdhirsh

Subscribers: albanD, bdhirsh, supriyar

Tasks: https://github.com/pytorch/pytorch/issues/111384

Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
2024-03-08 15:07:15 +00:00
PyTorch MergeBot
b529c19bdf Revert "Batch Norm Consolidation (#116092)"
This reverts commit 5680f565d5.

Reverted https://github.com/pytorch/pytorch/pull/116092 on behalf of https://github.com/jeffdaily due to broke ROCm, PR signal was clean but trunk was not, the merge should have been blocked but wasn't ([comment](https://github.com/pytorch/pytorch/pull/116092#issuecomment-1981373237))
2024-03-06 17:10:01 +00:00
Tugsbayasgalan Manlaibaatar
5680f565d5 Batch Norm Consolidation (#116092)
**Summary:**

This commit simplifies the existing decomposition hierarchy
of batch norm ops by adding a single, backend agnostic op:
`batch_norm_with_update`. The existing hierarchy looks like:

```
aten.batch_norm ->
aten._batch_norm_impl_index ->
[
  aten.native_batch_norm ->
  aten._native_batch_norm_legit (export only) ->
  _batch_norm_legit_cpu/cuda (kernels, export only) ->
  _batch_norm_cpu/cuda (kernels)
] OR
[ aten.cudnn_batch_norm ] OR
[ aten.miopen_batch_norm ]
```

Aside from complexity, an important problem with the
above decomposition hierarchy is cuda numerics in
export flows. We observed significantly worse convergence
when training a mobilenetv2-like model when using the
`_batch_norm_cuda` kernel instead of the `cudnn_batch_norm`
kernel. This means users who export their models on CPU
first then move the models to cuda later may silently
see worse accuracies even when cudnn is installed,
because they are using the worse kernel. This issue is
summarized in https://github.com/pytorch/pytorch/issues/111384.

Instead, the new hierarchy proposed by consolidating
existing batch norm ops will look like:

```
aten.batch_norm ->
aten.batch_norm_with_update ->
[ _batch_norm_cpu (kernel) ] OR
[ _batch_norm_cuda (kernel) ] OR
[ cudnn_batch_norm (kernel) ] OR
[ miopen_batch_norm (kernel) ]
```

The new op `batch_norm_with_update` hides backend
implementation details and automatically picks the right
kernel based on what is installed. This commit also adds
the following variants to this op:

```
batch_norm_with_update_functional
batch_norm_with_update.out
batch_norm_no_update
batch_norm_no_update.out
batch_norm_backward
```

Note that this commit only adds this op and its variants,
but does not actually change the decomps to produce these
ops in the graph. This will be done after the 2 week FC
window, and the ops used in the old stack is planned to
be removed after the 6 month BC window.

Test Plan: `OpInfo` tests for `batch_norm_with_update`.

Reviewers: albanD, bdhirsh

Subscribers: albanD, bdhirsh, supriyar

Tasks: https://github.com/pytorch/pytorch/issues/111384

Co-authored-by: Tugsbayasgalan Manlaibaatar <tmanlaibaatar@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116092
Approved by: https://github.com/bdhirsh, https://github.com/albanD
2024-03-06 04:50:46 +00:00
Jane Xu
da559c98e3 Fix isin decomp and add python meta registration (#120821)
Fixes #119792

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120821
Approved by: https://github.com/malfet, https://github.com/peterbell10
2024-02-29 22:08:50 +00:00