The meta implementation for these _like function is wrong whenever device != "meta" (it doesn't fill the memory!).
zeros_like is special due to sparse and is fixed directly by always filling it with zeros.
Every other one is CompositeExplicit implementation, I went with removing their meta registration and tweaking code to avoid infinite recursions.
I can do the same as zeros_like (and add the proper filling for each) but that would duplicate the c++ logic and make the meta registrations non trivial. I can do it if you prefer to removal.
test_meta works fine with these fixes, relying on CI to see if other tests are breaking as well.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98160
Approved by: https://github.com/ezyang
Inductor codegen is suboptimal when calling all_reduce_coalesced with input args. We need to fix inductor's calling convention for that, or something else.
Might not work if any outputs is unused.
Test code:
```python
import torch
import torch.distributed as dist
import torch.nn.functional as F
from functorch import make_fx
import os
import torch.distributed._functional_collectives as ft_c
from torch.testing._internal.common_distributed import (
spawn_threads_and_init_comms,
)
from torch._inductor.compile_fx import compile_fx_inner
def my_fun(a, b):
c = a * 3
tensors = ft_c.all_reduce_coalesced([a, c, b], "sum", [0])
return ((tensors[1] + tensors[0] + tensors[2]).sum(), )
@spawn_threads_and_init_comms(world_size=1)
def inductor_main(self):
x = torch.arange(4).cuda() * (dist.get_rank() + 1)
y = torch.arange(4).cuda() * (dist.get_rank() + 1)
x = x.to(torch.float)
y = y.to(torch.float) * 0.5
res = make_fx(my_fun)(x, y)
print(f"fx graph:\n{res.graph}")
ind = compile_fx_inner(res, [x, y])
print(f"inductor done:\n{ind}")
os.environ["PROXY_TENSOR_TRACING"] = "1"
os.environ["TORCH_COMPILE_DEBUG"] = "1"
torch._dynamo.config.output_code = True
if __name__ == "__main__":
inductor_main(None)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97157
Approved by: https://github.com/fegin
My first attempt was to apply the same solution as how proxy_tensor.py
handles other inplace ops. However, foreach is different in the way
that it's schema is `native_functions.yaml` does not return anything,
whereas ops like `addcmul_` and `addcdiv_` do return Tensors (Thanks
bdhirsh for teaching me this!). As a result, the proxy output
during tracing does not wrap anything, and hence we cannot correctly
connect it with subsequent operators. Modifying `native_functions.yaml`
is not a preferred solution. After discussing with bdhirsh, the
temporary solution is to do foreach functionalization as a graph
pass for now. Later, when https://github.com/pytorch/pytorch/issues/97852
is addressed, we will switch to default functionalization.
Edit: the latest version follows @bdhirsh 's suggestion on using
`make_fx` `decomposition_table` instead of implementing manual
fx.Graph tranforms to functionalize `_foreach_add_`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97853
Approved by: https://github.com/fegin, https://github.com/wanchaol
# Summary
There exists an optimization within the scaled_dot_product_efficieint bacwkard attention path to, under the right conditions, output grad_q, grad_k, grad_v all as aliases of the same storage. This was done to optimize for the hot path where mha does packed linear_projection -> chunk -> (view stuff) -> sdpa. The thought was that chunk-> would be able to "trivially" cat inputs to chunk.backward(). However upon closer inspection chunk.backward will call ` cat` irregardless of the inputs so this is not being utilized.
I validated this by profiling on main and then this branch and the traces produced the same both with `split.backward()` calling into cat.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96880
Approved by: https://github.com/cpuhrsch
# Summary
This PR adds an optional kwarg to torch torch.nn.functional.scaled_dot_product_attention()
The new kwarg is a scaling factor that is applied after the q@k.T step of the computation. Made updates to the efficient kernel to support but flash and math were minimally updated to support as well.
Will reduce the complexity of: #94729 and has been asked for by a couple of users.
# Review Highlights
- As far as I know I did this the correct way and this both BC and FC compliant. However I always seem to break internal workloads so I would love if someone can advice I did this right?
- I named the optional arg 'scale'. This is probably dumb and I should name it 'scale_factor'. I will make this change but this is annoying and it will require someone thinking we should rename.
- 'scale' is interpreted as `Q@K.T * (scale)`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95259
Approved by: https://github.com/cpuhrsch
Fixes for PyTorch/XLA functionalization integration
---
Some notable changes include:
- More asserts in `FunctionalTensorWrapper`, so bugs show up more cleanly in cases where we e.g. forget to wrap an output
- Make the *_scatter ops `CompositeExplicitAutogradNonFunctional`, so we get a better error message and XLA doesn't accidentally try to us them
- Fix LTC/XLA codegen in core to handle multi-tensor out= ops with no returns
- Better erroring: Allow XLA to use the CPU fallback from core in a way so that it always errors on view ops, which XLA should no longer see.
- Update MetaConverter to exclude XLA tensors in raising NotImplemented…
- Add `_propagate_xla_data` op
- Add meta tensor support for some ops
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94537
Approved by: https://github.com/bdhirsh
Add _int_mm primitive that binds cuBLAS int8@int8 -> int32 matmul and that translates to Triton based mm templates under max autotune. This is a very useful first step towards better supporting quantization on the GPU. This is a not a user facing API, but an internal primitive.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94339
Approved by: https://github.com/ngimel, https://github.com/jansel
This takes the strategy described in https://docs.google.com/document/d/1lFRYAJo5nrfxRhwIzGnfi2pbLpU6T4ytSRSuLJ5qebI/edit#
It is essentially https://github.com/pytorch/pytorch/pull/95222 but squashed and with changes that are unnecessary given that we assume nonzero returns > 1.
What's in the PR:
* nonzero now supports meta propagation. When `capture_dynamic_output_shape_ops`, it will return a tensor with an unbacked SymInt representing the size in question.
* The unbacked SymInt is UNSOUNDLY assumed to be not equal to 0/1. We will still error if you guard otherwise.
* PrimTorch pointwise operators are updated to use empty_permuted, to avoid guarding on unbacked SymInt from empty_strided (tested in `test_dynamic_pointwise_scalar`)
* Convolution is updated to skip backend selection if batch is unbacked, to avoid guarding on unbacked SymInt (tested in `test_unbacked_batch_resnet`)
* I kept the helper utilities like `definitely_true` for working with possibly unbacked SymInts. They're not used right now but maybe someone will find them useful.
* Added `constrain_unify` to let you specify two unbacked SymInts must have the same value
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95387
Approved by: https://github.com/voznesenskym
Summary:
This PR tries to decompose the operators in torch.ops.quantized_decomposed namespace to more
primitive aten operators, this would free us from maintaining the semantics of the quantize/dequantize
operators, which can be expressed more precises in terms of underlying aten operators
Note: this PR just adds them to the decomposition table, we haven't enable this by default yet
Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_q_dq_decomposition
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93312
Approved by: https://github.com/vkuzo, https://github.com/SherlockNoMad
Summary:
This PR tries to decompose the operators in torch.ops.quantized_decomposed namespace to more
primitive aten operators, this would free us from maintaining the semantics of the quantize/dequantize
operators, which can be expressed more precises in terms of underlying aten operators
Note: this PR just adds them to the decomposition table, we haven't enable this by default yet
Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_q_dq_decomposition
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93312
Approved by: https://github.com/vkuzo, https://github.com/SherlockNoMad
Fixes#92676
`arange` infers the output dtype from the argument types, but in order to reduce
falling back to ATen, inductor preferred to cast whole number float arguments to
int which gave the wrong output dtype. Instead, this decomposes floating point
arange into the prim equivalent for integers.
This also changes the signature of `prims.arange` to
```python
prims.iota(length, *, start, step, **factory_kwargs)
```
which only supports integers arguments. This is done because calculating the
output size from `start, end, step` is surprisingly complex and liable to off by
one errors so should not be duplicated in each backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93353
Approved by: https://github.com/ngimel, https://github.com/lezcano
# Summary
This PR creates _flash_attention_backward and _scaled_dot_product_flash_attention_backward native functions and registers them to the respective derivatives.yaml.
The goal is to replicate the torch.autograd.Function defined in the FlashAttention repo [here](33e0860c9c/flash_attn/flash_attn_interface.py (L126)) natively in PyTorch. One thing that we don't have access to is ctx.save_for_backward in native PyTorch so in order to save these variables I extended the returned objects from the forward functions.
### MetaFunctions
I also updated the FlashAttention meta functions to mirror the real outputs now. As well I added a meta registration for backwards. I have an XLMR training script and while eager training now works with FlashAttention compiling this module fails with the inductor error down below.
### Questions?
Performance issues vs mem efficient when using torch.nn.mha_forward
TorchCompile -> See purposed solution below.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92917
Approved by: https://github.com/cpuhrsch
# Summary
In preparation for pt 2.0 launch this PR updates SDPA's API and makes the function a nn.funcitonal public function.
## Changes
### API
Previously the the function signature was:
`scaled_dot_product_attention(query, key, value, attn_mask=None, need_attn_weights=False, dropout_p=0.0, is_causal=False) -> (Tensor, Tensor)`
Updated signature:
`scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False) -> Tensor`
This PR removes the need_attn_weights optional boolean variable and updates the return type to a singular tensor.
#### Reasoning:
The main goal of this function is to provide an easy interface for users to call into fused attention kernels e.g. (FlashAttention). The fused kernels do not currently support arbitrary attn_mask or dropout but there is a PR to mem-efficient attention to enable these. We want to have the API surface ready for when the backing kernels get updated.
The fused kernels save on memory usage by not materializing the weights and it is unlikely that a fast fused implementation will enable this feature so we are removing.
Discussed with folks at FAIR/Xformers and +1 this API change.
#### Make function Public
In preparation for the pt 2.0 launch we make the function public to start to generate user feedback
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92189
Approved by: https://github.com/cpuhrsch
# Summary
This PR updates the second return value from SDPA to return an empty tensor of size 0 not what it would be if need_attn_weights is True. Also updates the meta function to account for this change.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91782
Approved by: https://github.com/cpuhrsch
Found this issue from [weekly running 7k github models](https://github.com/pytorch/torchdynamo/issues/1884). This caused regression on pass rate, there are 25 models failed due to this issue.
The reason is argument ```cx``` of ```aten._cudnn_rnn``` can be ```None```, but it doesn't handle well in meta registration, so throws the following error:
```
Traceback (most recent call last):
File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 1059, in run_node
return nnmodule(*args, **kwargs)
File "/scratch/ybliang/work/repos/pytorch/torch/nn/modules/module.py", line 1482, in _call_impl
return forward_call(*args, **kwargs)
File "/scratch/ybliang/work/repos/pytorch/torch/nn/modules/rnn.py", line 477, in forward
result = _VF.rnn_tanh(input, hx, self._flat_weights, self.bias, self.num_layers,
File "/scratch/ybliang/work/repos/pytorch/torch/_subclasses/fake_tensor.py", line 916, in __torch_dispatch__
r = func(*args, **kwargs)
File "/scratch/ybliang/work/repos/pytorch/torch/_ops.py", line 284, in __call__
return self._op(*args, **kwargs or {})
File "/scratch/ybliang/work/repos/pytorch/torch/_meta_registrations.py", line 2108, in _cudnn_rnn
cy = cx.new_empty(0 if cx is None else cell_shape)
AttributeError: 'NoneType' object has no attribute 'new_empty'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91333
Approved by: https://github.com/ezyang
It turns out that we *do* need to update *_scatter ops to return the exact same strides as their inputs. I added a test to `test/test_functionalization.py`, which now trips thanks to Ed's functionalization stride debugging check. It only actually ends up tripping silent correctness if you try to .backward() on that function.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91029
Approved by: https://github.com/ezyang
## Summary
Torch.compile was previously not working for transformerencoder because torch.SDPA calls a native function on tensors that returns an int. This PR instead creates a dispatch stub for the function called in order to not create a separate fx node for this native function.
As well this pr adds meta functions for the fused kerenels.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90576
Approved by: https://github.com/cpuhrsch
Previously, we hackily wrapped unspecialized integers into
tensors and treated them as tensor inputs. Sometimes, downstream
operations would not be able to deal with the tensor input. Now,
we wrap them into SymInt, so more correct overload selection occurs.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89639
Approved by: https://github.com/anjali411
Fake tensor behaves pretty differently depending on if you have
symbolic shapes or not. This leads to bugs; for example, we
weren't getting correct convolution_backward strides because we
bypassed the correct stride logic in fake tensor on symbolic
shapes.
This PR attempts to unify the two codepaths. I don't manage to
unify everything, but I get most of it. The algorithm is delicate
and I'm still hosing down test failures.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89038
Approved by: https://github.com/anjali411
This improves the memory compression of resnet18 from .84 -> .94 on inductor no-cudagraphs. It does mean that any extern kernel which incorrectly computes strides will be a hard error at runtime, but that's an issue we are going to have to face with dynamic shapes anyway. CC @ezyang, @SherlockNoMad
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88248
Approved by: https://github.com/ezyang
Fixes https://github.com/pytorch/torchdynamo/issues/1802
There are a few problems,
1. torch.fused_moving_avg_obs_fake_quant doesn't have OpInfo test
2. self.empty_like() is not a valid call. it should be torch.empty_like(self)
3. python meta function has some unexplained behavior for arguments with default value of bool type?
In particular, problem 3 is the most concerning one.
**UPDATE: This is expected behavior, see discussion below for explanation.**
Without setting the default value for `per_row_fake_quant` and `symmetric_quant`, it gets the following error when running with meta tensor.
```
meta__fused_moving_avg_obs_fq_helper() missing 2 required positional arguments: 'per_row_fake_quant' and 'symmetric_quant'
```
I can fix this by adding the default values to these two args. However, I observer something strange when examining the actual value in meta function.
```
print("per_row_fake_quant", per_row_fake_quant)
print("symmetric_quant", symmetric_quant)
```
When default values are False, printed value correctly reflect the args value populated from call site.
When default values are True, printed value is ALWAYS True, regardless of the populated value from call site.
When default Values are None, printed value is `None` when call site set the value to 'False', printed value is 'True' when call site sets the value to 'True'.
I also verify that this bug also affect for other meta function with default args....
My speculation is that this is something about pybind value packing when called from c++ dispatcher to python meta function, and default value parsing for python meta function (and other python dispatch functions) ?
I tried to find the c++ call stack, but gdb is missing symbols and C++ stacktrace is not working properly... Appreciate anyone who can point me to the source file for pybind value packing.
cc @ezyang
cc @bdhirsh. I know you had a fix in the symbolic shape branch...
cc @yanboliang who reported this bug
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88058
Approved by: https://github.com/bdhirsh, https://github.com/yanboliang
This is a policy update for meta registration. **We now prefer python meta implementation over C++ meta function.** This is a flip of the previous policy, where we prefer C++ meta function over python meta function if they both exist.
Here's the meta registration process:
1. register_meta and register_decomposition will place the python meta/decomp functions into the `global_decomp_table`. However, they will NOT register them into dispatcher.
2. After global_decomp_table is populated, we will compile an `active_meta_table`. For a given op, we pick the most specific decomp function from `global_decomp_table` in the preference order of Meta > PostAutograd > PreAutograd.
3. We will unconditionally register all of them into python dispatcher. And register them into C++ dispatcher, unless it one of the following 3 cases
- 1. the op is a CompositeImplicitAutograd, and should rely on decomposed op's meta
- 2. the op is a view op, as the MetaTensor doesn't support aliased storage
- 3. the op is in the blocklist (due to UT failures, and we will burn down this list op by op)
Over the long run, we wish to implement all meta functions in python. With this PR, 321 op_overloads will have cpp meta overridden by python meta. There are still 400 op_overloads is using cpp meta. The exact list can be found here https://gist.github.com/SherlockNoMad/d20bb736178df8eebd3b054c8bb7cdc5
cc @ngimel @jansel @lezcano @fdrocha @mlazos @soumith @voznesenskym @yanboliang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87426
Approved by: https://github.com/ezyang, https://github.com/jansel
`diag` was unnecessarily implemented as a kernel rather than as a composite
function, which made it unnecessarily difficult (explicit backward + all it entails).
We also change a few uses of `diag` on 2D tensors for `diagonal()`. The
latter returns a view rather than creating a new tensor.
We also upgrade its meta implementation to a fully-fledged
decomposition
I tried implementing the backwards of `diagonal()` via `diag_scatter` (or better `diag_scatter_` to keep the perf) but functionalisation was failing and I was not sure how to fix this, so I moved on. It may be possible to simplify that one as well if @soulitzer or someone knows how to do this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87180
Approved by: https://github.com/ngimel, https://github.com/albanD, https://github.com/mruberry
We recently fixed a bug on symbolic-shapes branch where
an isinstance(x, int) test failed when passed a SymIntNode.
To prevent this, I've added a lint for all the codepaths
where we may pass SymInt/SymFloat directly to reject
direct isinstance int/float tests, and instead use one of
the aliases. The lint rule explains the options. I then
go and fix all of them.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87345
Approved by: https://github.com/bdhirsh, https://github.com/albanD
Big-bang PR to symintify **all** .sizes() calls in derivatives.yaml, which will be needed for symbolic tracing.
* with the exception of `split()`, which is tougher to land because it requires internal changes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86610
Approved by: https://github.com/albanD
This reverts commit 978b46d7c9.
Reverted https://github.com/pytorch/pytorch/pull/86488 on behalf of https://github.com/osalpekar due to Broke executorch builds internally with the following message: RuntimeError: Missing out variant for functional op: aten::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[] . Make sure you have loaded your custom_ops_generated_lib
symintify split_with_sizes, dropout, fused_fake_obs_quant. meta for padding_2d ops
add meta_bernoulli_
meta kernel for at::gather
get pytorch_struct to pass: meta for scatter_add, fix backward
symintify split ops
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86488
Approved by: https://github.com/ezyang
symintify split_with_sizes, dropout, fused_fake_obs_quant. meta for padding_2d ops
add meta_bernoulli_
meta kernel for at::gather
get pytorch_struct to pass: meta for scatter_add, fix backward
symintify split ops
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86334
Approved by: https://github.com/ezyang
The output striding channels-last preservation logic differs between cuda and cpu. For the meta kernel, we can peek at the fake tensor device and use that to determine whether to do cpu or cuda.
You could argue there's a leaking of abstraction here but this seems like a pretty minimal leak and I'm not sure there's a much cleaner way forward for device-specific striding tracing logic.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82846
Approved by: https://github.com/ezyang
Fixes#79512
This PR adds support for convolutional meta modules and computes the output shape correctly for some meta input tensor.
Currently in progress, no tests written so far.
**Feature implementations**:
- [x] `Conv1d`
- [x] `Conv2d`
- [x] `Conv3d`
**Tests**:
- [x] `Conv1d`
- [x] `Conv2d`
- [x] `Conv3d`
cc @albanD @anjali411
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79834
Approved by: https://github.com/ezyang, https://github.com/albanD
When a function returns multiple parameters in PyTorch, the `out`
parameter takes a tuple of tensors (see `linalg.svd` for example).
The current implementation in `out_wrapper_multi` modelled this wrong,
as it assumed that it would take a number of different named
parameters.
This PR implements the correct behaviour in `out_wrapper`. As a small
side-effect, we now need to call `@out_wrapper()` when the output is
just one tensor.
This PR also implements an additional optional parameter that checks
whether the dtype of the given `out` is exactly the dtype that the meta
function requires. This is the behaviour that we currently have in
PyTorch, and this check is necessary in eager when we call with these
tensors into external libraries.
We also make the functions with several outputs return a namedtuple,
similar to what we do in PyTorch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79941
Approved by: https://github.com/mruberry, https://github.com/ezyang
This PR simplifies the logic of `linalg.qr` using structured kernels. I
also took this chance and merged a few `copy_` operations with other
ops.
This PR removes a the previous magma implementation as is never faster
than that of cusolver and it's rather buggy. This has the side-effect
that now `qr` is not supported in Rocm. Ivan confirmed that this is
fine, given how incredibly slow was QR on Rocm anyway (we were marking
some tests as slow because of this...).
This PR also corrects the dispatch in geqrf. Before, if we called it
with a matrix for which `input.size(-2) <= 256 && batchCount(input) >= std::max<int64_t>(2, input.size(-2) / 16)` is false, and we have cublas but not cusolver, we would end up calling magma rather than cublas. This is not what the heuristic suggested.
Probaly we should benchmark these heuristics again, but that's beyond the scope of this PR.
Note. It looks like `torch.geqrf` maybe broken in MAGMA as per the
previous comment in `linalg_qr_helper_magma`. IvanYashchuk wdyt?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79054
Approved by: https://github.com/IvanYashchuk, https://github.com/ezyang
This PR adds `linalg.lu_solve`. While doing so, I found a bug in MAGMA
when calling the batched MAGMA backend with trans=True. We work around
that by solving the system solving two triangular systems.
We also update the heuristics for this function, as they were fairly
updated. We found that cuSolver is king, so luckily we do not need to
rely on the buggy backend from magma for this function.
We added tests testing this function left and right. We also added tests
for the different backends. We also activated the tests for AMD, as
those should work as well.
Fixes https://github.com/pytorch/pytorch/issues/61657
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77634
Approved by: https://github.com/malfet
Decompositions can be used to fill in meta support where necessary,
assuming the operations they decompose to support meta key.
This PR adds register_meta kwarg to register_decomposition that
optionally lets you register the meta to the C++ dispatch table
for meta tensors. I use this to then get the meta function for
where and huber_loss for free.
Signed-off-by: Edward Z. Yang <ezyangfb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77353
Approved by: https://github.com/mruberry