This patch fixes a corner case of `torch.isin` decompisition when both
inputs are scalars. This pattern showed up from #141196.
Fixes#141196.
Error stack befor this patch:
```
File "/home/ryanguo99/repos/pytorch/test/dynamo/test_misc.py", line 12503, in test_scalar_isin_decomposition
res = opt_f()
^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/eval_frame.py", line 691, in _fn
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/output_graph.py", line 1618, in _call_user_compiler
raise BackendCompilerFailed(
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/output_graph.py", line 1593, in _call_user_compiler
compiled_fn = compiler_fn(gm, self.example_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
compiled_gm = compiler_fn(gm, example_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/__init__.py", line 2365, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_inductor/compile_fx.py", line 2317, in compile_fx
return aot_autograd(
^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/backends/common.py", line 106, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_functorch/aot_autograd.py", line 1179, in aot_module_simplified
compiled_fn = AOTAutogradCache.load(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/autograd_cache.py", line 923, in load
compiled_fn = dispatch_and_compile()
^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_functorch/aot_autograd.py", line 1164, in dispatch_and_compile
compiled_fn, _ = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_functorch/aot_autograd.py", line 576, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_functorch/aot_autograd.py", line 826, in _create_aot_dispatcher_function
compiled_fn, fw_metadata = compiler_fn(
^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 180, in aot_dispatch_base
fw_module, updated_flat_args, maybe_subclass_meta = aot_dispatch_base_graph( # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 2199, in _trace_inner
t = dispatch_trace(
^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_compile.py", line 51, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/eval_frame.py", line 872, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 1223, in dispatch_trace
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_dynamo/eval_frame.py", line 872, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/fx/_symbolic_trace.py", line 850, in trace
(self.create_arg(fn(*args)),),
^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 1278, in wrapped
out = f(*tensors) # type:ignore[call-arg]
^^^^^^^^^^^
File "<string>", line 1, in <lambda>
File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 720, in inner_fn
outs = fn(*args)
^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 419, in _functionalized_f_helper
f_outs = fn(*f_args)
^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 81, in inner_fn
outs = fn(*args)
^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 902, in functional_call
out = PropagateUnbackedSymInts(mod).run(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/fx/interpreter.py", line 171, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/symbolic_shapes.py", line 7387, in run_node
result = super().run_node(n)
^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/fx/interpreter.py", line 240, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/fx/interpreter.py", line 320, in call_function
return target(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 1326, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_subclasses/functional_tensor.py", line 511, in __torch_dispatch__
outs_unwrapped = func._op_dk(
^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/utils/_stats.py", line 27, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 1428, in __torch_dispatch__
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 797, in proxy_call
r = maybe_handle_decomp(proxy_mode, func, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/fx/experimental/proxy_tensor.py", line 2358, in maybe_handle_decomp
out = CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_prims_common/wrappers.py", line 309, in _fn
result = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_decomp/decompositions.py", line 5108, in isin
return isin_default(elements, test_elements, invert=invert)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ryanguo99/repos/pytorch/torch/_decomp/decompositions.py", line 5137, in isin_default
x = elements.view(*elements.shape, *((1,) * test_elements.ndim))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
TypeError: view() received an invalid combination of arguments - got (), but expected one of:
* (torch.dtype dtype)
* (tuple of ints size)
While executing %isin : [num_users=1] = call_function[target=torch.isin](args = (%x, %x), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
def forward(self):
# File: /home/ryanguo99/repos/pytorch/test/dynamo/test_misc.py:12498 in f, code: x = torch.tensor(0)
x: "i64[][]" = torch.tensor(0)
# File: /home/ryanguo99/repos/pytorch/test/dynamo/test_misc.py:12499 in f, code: return torch.isin(x, x)
isin: "b8[][]" = torch.isin(x, x); x = None
return (isin,)
Original traceback:
File "/home/ryanguo99/repos/pytorch/test/dynamo/test_misc.py", line 12499, in f
return torch.isin(x, x)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153216
Approved by: https://github.com/williamwen42, https://github.com/peterbell10
Currently, `torch._chunk_cat` only supports contiguous inputs (due to `.view()` usage in `_pad_chunk()` supporting only contiguous tensor). This doesn't work for internal models where there can be non-contiguous input tensors:
- size=[8192, 16416], stride=[16448, 1] # stride[0] is larger than size[1]
- size=[1152, 384], stride=[1, 1152] # column-major tensor
In this PR, we relax the assumption on contiguous input tensor, by switching from `.view()` to `.reshape()`. Note that since `.reshape()` will try to use `.view()` under the hood whenever possible, this should not cause regression to existing use cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151263
Approved by: https://github.com/BoyuanFeng
This is a resubmission of my previous PR that I accidentally deleted, apology in advance if any inconvenience caused. Below are details of this PR.
Fix an issue when torch.addmv behaves inconsistent between torch.compile mode and eager mode. Here is the code to reproduce:
```
import torch
import numpy as np
@torch.compile
def test_optimized(input, mat, vec):
return torch.addmv(input, mat, vec)
def test(input, mat, vec):
return torch.addmv(input, mat, vec)
input = torch.tensor([2], dtype=torch.int32)
mat = torch.tensor(np.random.randn(0, 0), dtype=torch.int32)
vec = torch.tensor([])
origin_out = test(input, mat, vec)
optimized_out = test_optimized(input, mat, vec)
print(origin_out) # tensor([2.])
print(optimized_out) # tensor([])
```
According to the equation (https://pytorch.org/docs/stable/generated/torch.addmv.html), when matrix and vector is empty, returning `[2.]` seems more reasonable to me.
Following the cpu implementation of this API:e97b97af56/aten/src/ATen/native/Blas.cpp (L62)
I add an additional branch to handle empty matrix
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143792
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
Fix: #141251
This PR adds a few static guard checks when decomposing and lowering the `slice`
operation, so that we avoid adding unnecessary guards. Specifically, when clamping the end
values.
In summary, the changes are:
- `slice` dynamo decomposition: checks `end >= sizes[dim]` statically. If we don't know
that, the following guard ensures that we (don't) need clamping.
- `evaluate_min` inductor `sizevar` function: checks whether we can solve it statically or
not, before actually creating a new guard.
The latter had to be changed because `evaluate_min` (called by `ir.SliceView` constructor)
would always try to create a guard based on the hints operation result. However, if both
`left` and `right` hints were true, it would default to `left <= right` guard. By checking
the guards statically before, we can avoid that.
```python
N = 16
@torch.compile(backend="inductor", dynamic=False, fullgraph=True)
def fn(x):
splits = torch.ops.aten.split.Tensor(x, N)
first = splits[0]
return torch.ops.aten.slice.Tensor(first, 0, 0, N)
x = torch.arange(N)
torch._dynamo.mark_dynamic(x, 0)
fn(x)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142372
Approved by: https://github.com/ezyang
Fixes https://github.com/pytorch/pytorch/issues/141332
`F.logsigmoid` will return two outputs: `output` and `buffer`.
For `F.logsigmoid` cpu path, it will use buffer to store some intermediate values and use them when computing gradients, so it returns a `buffer` tensor with nonzero size. For cuda and xpu paths, buffer is useless, so the `buffer ` tensor size of xpu `F.logsigmoid` will be zero, just like cuda. The root cause of the issue is that the codes in `decompositions.py` (ref:https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py#L2803) only handle the cuda cases, when the a fake tensor with device is xpu run to here, it will use the cpu path and return a `buffer` with nonzero size, which is conflict to the implementation of intel xpu concrete tensor. Therefore this pr add conditions to handle xpu cases. Make sure the two returned buffer sizes match each other.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141333
Approved by: https://github.com/guangyey, https://github.com/EikanWang, https://github.com/ezyang
Previously the split decomp would return the input when there were no splits. this errors in torch.compile (or FakeTensorMode) with :
> RuntimeError: View operation returned a tensor that is the same as the input base tensor. This is no longer allowed; you must explicitly create a new tensor (e.g., using .detach()). As a user, you could have made a mistake implementing __torch_dispatch__ or a Python operator decomposition or meta registration; if that's not the case, please report a bug to PyTorch or the backend you are using.
Fix for https://github.com/pytorch/pytorch/issues/133394
Differential Revision: [D65635070](https://our.internmc.facebook.com/intern/diff/D65635070)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140065
Approved by: https://github.com/bdhirsh
Previously the split decomp would return the input when there were no splits. this errors in torch.compile (or FakeTensorMode) with :
> RuntimeError: View operation returned a tensor that is the same as the input base tensor. This is no longer allowed; you must explicitly create a new tensor (e.g., using .detach()). As a user, you could have made a mistake implementing __torch_dispatch__ or a Python operator decomposition or meta registration; if that's not the case, please report a bug to PyTorch or the backend you are using.
Fix for https://github.com/pytorch/pytorch/issues/133394
Differential Revision: [D65635070](https://our.internmc.facebook.com/intern/diff/D65635070)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140065
Approved by: https://github.com/bdhirsh
**Summary**
In the case of LLaMA2, for a linear operation with an activation size of `(4, 1, 4096)` and a stride of `(4096, 128, 1)` which has been decomposed into `matmul`. And the decomposition of `matmul` results in `bmm` due to a strict continuity check. We can align the continuity check with ATen by skip dim of size 1 to enable decomposition into `mm` instead.
**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k test_linear_input_non_contiguous_3D_wo_bias
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139172
Approved by: https://github.com/jgong5, https://github.com/ezyang
Summary:
# context
* enable the `_get_user_embeddings` function
* run failed at P1610151892
```
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
GuardOnDataDependentSymNode: Could not guard on data-dependent expression u22 <= 0 (unhinted: u22 <= 0). (Size-like symbols: u22)
ATTENTION: guard_size_oblivious would fix the error, evaluating expression to False.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.
Potential framework code culprit (scroll up for full backtrace):
File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/38472faba4e3e6c1/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/_decomp/decompositions.py", line 1692, in native_layer_norm_backward
if M <= 0 or N <= 0:
```
```
N = prod(inner_dims) # type: ignore[arg-type]
M = prod(outer_dims) # type: ignore[arg-type]
if M <= 0 or N <= 0:
return (
input.new_zeros(input_shape) if output_mask[0] else None,
input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
input.new_zeros(input_shape[axis:]) if output_mask[2] else None,
)
```
# changes
* use guard_size_oblivious since the new_zeros return is kind of optimization, shouldn't impact the correctness of the follow up code logic.
* the size `ret[i][j]` could be zero, so the change in V1 isn't valid
* for more details: [post](https://fb.workplace.com/groups/6829516587176185/permalink/8003616173099548/)
```
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):
```
# past
* found `u22` was introduced at
```
def _wait_impl(self) -> List[List[int]]:
# Can not use is_torchdynamo_compiling(), as every such condition should be independent for compilation with graph breaks.
if isinstance(self._splits_awaitable, dist.Work):
self._splits_awaitable.wait()
ret = self._output_tensor.view(self.num_workers, -1).T.tolist() # <------ u22 introduced here
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
for i in range(len(ret)):
for j in range(len(ret[i])):
torch._check_is_size(ret[i][j]) # <---------- my question: why the _check_is_size isn't enough??
torch._check(ret[i][j] > 0) # <------ added by diff V1
```
Test Plan:
# run command
```
TORCH_SHOW_CPP_STACKTRACES=1 TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 TORCH_LOGS="+graph_code,output_code,dynamic,aot,guards,verbose_guards,recompiles,graph_breaks" TORCH_TRACE=/var/tmp/tt buck2 run fbcode//mode/opt fbcode//aps_models/ads/icvr:icvr_launcher_live -- mode=fmc/local_ig_fm_v4_mini training.pipeline_type=pt2 2>&1 | tee -a `tagT`.`tagH`.log
```
# results
* before
**without enabling `_get_user_embeddings`**
[14 Failures and Restarts](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmp2eNI7p/failures_and_restarts.html)
log: P1610151892
{F1889387940}
* V1
enable `_get_user_embeddings`
with `torch._check(ret[i][j] > 0)`
[13 Failures and Restarts](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmp6J1iY9/failures_and_restarts.html)
{F1889388378}
* V2
enable `_get_user_embeddings`
with `if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):`
[tlparse](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpFhZZyC/index.html)
if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):
Differential Revision: D63424929
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136798
Approved by: https://github.com/ezyang
When tensor folding occurs during matmul operation returned tensor is a view. This can cause issues when matmul is used inside a custom function and such view is then returned as output. Then it cannot be modified inplace and causes errors.
It can be especially problematic when after such function inplace allreduce is performed.
Issue is resolved when unsafe_view is returned from matmul instead. This solution aligns matmul decomposition with eager implementation in such a way that a non view tensor is returned.
Test included in this PR reproduces the issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134568
Approved by: https://github.com/zou3519
Summary: Update SDPA decomposition to match updated stride from D62009189 which aligns strides with the `aten._scaled_dot_product_attention_math.default`, which makes `t.permute().continuous().permute()` no longer necessary.
Test Plan: CI
Differential Revision: D62278378
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135297
Approved by: https://github.com/drisspg
## Description
Create decomposition of _unsafe_index_put (non-core aten) that turns it into index_put (core aten)
## Testing
Phi3 mini + LoRA model successfully passed `to_edge` after failing due to a non-core aten `unsafe_index_put` getting introduced in a decomposition during joint graph calculations.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133365
Approved by: https://github.com/pianpwk
Summary:
# context
* when running an IG FM training with PT2 we found there are a few graph break due to torch.diff call in [jagged_tensor.py](https://fburl.com/code/cwssxabc)
```
_length: List[int] = (
_length_per_key_from_stride_per_key(torch.diff(offsets), stride_per_key)
if variable_stride_per_key
else torch.sum(torch.diff(offsets).view(-1, stride), dim=1).tolist()
)
```
* look into the failure, we found the TORCH_CHECK in diff should be TORCH_SYM_CHECK
* slice_forward error: df3d7729e, [tlparse](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpxXZ2em/index.html)
```
RestartAnalysis
Tried to use data-dependent value in the subsequent computation. This can happen when we encounter unbounded dynamic value that is unknown during tracing time. You will need to explicitly give hint to the compiler. Please take a look at torch._check OR torch._check_is_size APIs. Could not guard on data-dependent expression ((5*u37 + u38)//(u37 + u38)) < 0 (unhinted: ((5*u37 + u38)//(u37 + u38)) < 0). (Size-like symbols: u38, u37)
ATTENTION: guard_size_oblivious would fix the error, evaluating expression to False.
Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.
Potential framework code culprit (scroll up for full backtrace):
File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/e99934938a0abe90/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/_decomp/decompositions.py", line 771, in slice_forward
if end_val < 0:
```
* after this diff: [tlparse](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpAhv2Sh/failures_and_restarts.html)
Test Plan:
# command
* run model
```
TORCH_SHOW_CPP_STACKTRACES=1 TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 TORCH_LOGS="+graph_code,output_code,dynamic,aot,guards,verbose_guards,recompiles,graph_breaks" TORCH_TRACE=/var/tmp/tt buck2 run fbcode//mode/opt fbcode//aps_models/ads/icvr:icvr_launcher_live -- mode=fmc/local_ig_fm_v4_mini training.pipeline_type=pt2
```
* generate tlparse
```
tlparse `ls -t /var/tmp/tt/* | head -1`
```
Reviewed By: ezyang
Differential Revision: D56339251
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133740
Approved by: https://github.com/ezyang