Commit Graph

482 Commits

Author SHA1 Message Date
Jesse Cai
5accae4197 [sparse] add extra options to _cslt_spare_mm (#137427)
Summary:

Splitting this PR into two, one for the cuSPARSELt improvements, and one
for the inductor lowering.

This PR adds in the additional cuSPARSELt bindings into pytorch.

* `torch._cslt_sparse_mm_search` will be deprecated in a future PR,
  so a warning has been added

* Added a header file for cuSPARSELtOps.cpp

* max_id is now available in `torch.backends.cusparselt` via
  `torch.backends.cusparselt.get_max_alg_id()`

* fixed meta registrations for float8

Test Plan:

python test/test_sparse_semi_structured.py

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137427
Approved by: https://github.com/cpuhrsch, https://github.com/eqy
2024-11-27 05:32:45 +00:00
vasiliy
3d5fe0ce78 torch._scaled_mm: support dims of size 0 for tensorwise scaling (#140967)
Summary:

Ensures we support dims of size 0 properly in `torch._scaled_mm`. Follows the behavior from `torch.mm`.

For now only enable support for tensorwise, we can tackle rowwise in a future PR.

Test Plan:

```
python test/test_matmul_cuda.py -k test_zero_dim
```

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140967
Approved by: https://github.com/eqy, https://github.com/drisspg
2024-11-27 04:07:52 +00:00
Joel Schlosser
8ba555ec8a Fix where() for NJT (#141500)
**Background:** It's common to use `scalar_tensor()` in the input to `where()` to convert any scalars present to compatible tensors with matching options, *including layout*. This shows up in various places, notably including derivative formulas ([example](78491d6afc/tools/autograd/derivatives.yaml (L432-L434))). It causes problems for NJTs because they have `layout=torch.jagged` and it never makes sense to create a scalar tensor with this layout. Some of the breakage only seems to happen in CI for reasons I don't fully understand (see the revert of #140736 due to softshrink's derivative formula).

**This PR:**
* Allows non-contiguous NJT inputs to `where()` + adds tests for this
* Handles scalar tensor / dense tensor inputs for `condition` / `other` + adds tests for this
    * Uses limited `broadcast_tensors()` / `broadcast_to()` support
    * Improves `expand()` to work on non-contig NJTs
* Changes `scalar_tensor()` to use `torch.strided` instead of `torch.jagged` in both eager and torch.compile (i.e. meta registration)
* Changes backward formulas for `sinc`, `pow`, `special.i1`, and `special.i1e` to uses `scalar_tensor()` instead of e.g. `zeros({})`

**Alternative approach:** Update all problematic usages of `scalar_tensor()` to avoid ever passing `layout=torch.jagged`. This is an extensive change and includes `torch.where()` logic, a bunch of derivative formulas, and likely other places not yet discovered.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141500
Approved by: https://github.com/malfet, https://github.com/cpuhrsch, https://github.com/soulitzer
2024-11-26 20:13:27 +00:00
PyTorch MergeBot
5318bf8baf Revert "[sparse] add extra options to _cslt_spare_mm (#137427)"
This reverts commit f1451163ec.

Reverted https://github.com/pytorch/pytorch/pull/137427 on behalf of https://github.com/huydhn due to This looks like the test is still failing, plz do a rebase ([comment](https://github.com/pytorch/pytorch/pull/137427#issuecomment-2499918590))
2024-11-26 08:01:24 +00:00
Jesse Cai
f1451163ec [sparse] add extra options to _cslt_spare_mm (#137427)
Summary:

Splitting this PR into two, one for the cuSPARSELt improvements, and one
for the inductor lowering.

This PR adds in the additional cuSPARSELt bindings into pytorch.

* `torch._cslt_sparse_mm_search` will be deprecated in a future PR,
  so a warning has been added

* Added a header file for cuSPARSELtOps.cpp

* max_id is now available in `torch.backends.cusparselt` via
  `torch.backends.cusparselt.get_max_alg_id()`

* fixed meta registrations for float8

Test Plan:

python test/test_sparse_semi_structured.py

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137427
Approved by: https://github.com/cpuhrsch, https://github.com/eqy
2024-11-25 23:45:41 +00:00
PyTorch MergeBot
cc90ba8924 Revert "[sparse] add extra options to _cslt_spare_mm (#137427)"
This reverts commit 45b30a5aec.

Reverted https://github.com/pytorch/pytorch/pull/137427 on behalf of https://github.com/huydhn due to Sorry for reverting your change but test_sparse_semi_structured is failing in trunk after it lands ([comment](https://github.com/pytorch/pytorch/pull/137427#issuecomment-2494047577))
2024-11-22 15:40:21 +00:00
Jesse Cai
45b30a5aec [sparse] add extra options to _cslt_spare_mm (#137427)
Summary:

Splitting this PR into two, one for the cuSPARSELt improvements, and one
for the inductor lowering.

This PR adds in the additional cuSPARSELt bindings into pytorch.

* `torch._cslt_sparse_mm_search` will be deprecated in a future PR,
  so a warning has been added

* Added a header file for cuSPARSELtOps.cpp

* max_id is now available in `torch.backends.cusparselt` via
  `torch.backends.cusparselt.get_max_alg_id()`

* fixed meta registrations for float8

Test Plan:

python test/test_sparse_semi_structured.py

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137427
Approved by: https://github.com/cpuhrsch, https://github.com/eqy
2024-11-21 23:37:36 +00:00
Yukio Siraichi
216b6a952c triangular_solve: fix meta function output argument dtype check. (#140286)
Tracking issue: #138399
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140286
Approved by: https://github.com/ezyang
ghstack dependencies: #140186
2024-11-14 15:25:14 +00:00
pralay
f06ee3e546 [pt2] Add meta for _add_relu (#140009)
aten._add_relu doesn't have meta function registered, so in dynamic shape case it is throwing an error in dynamo logs:
Error:
`V1107 11:25:32.344000 140481543555072 torch/_dynamo/symbolic_convert.py:534] [0/1] [__graph_breaks] NotImplementedError: aten::_add_relu.Tensor: attempted to run this operator with Meta tensors, but there was no fake impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add a fake impl.`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140009
Approved by: https://github.com/ezyang
2024-11-13 06:30:58 +00:00
Yukio Siraichi
c182c7ccfc Fix triangular_solve meta function out parameter names. (#140186)
This PR replaces the parameter names specified in the `triangular_solve_meta`
function (specifically in its `@out_wrapper(...)` decorator) by those written in the
_native_functions.yaml_ file.

This name mismatch caused the operation to fail when using the meta device (see error
below):

```python
Traceback (most recent call last):
  File "examples/test.py", line 23, in <module>
    torch.triangular_solve(b.to("meta"), A.to("meta"), out=meta_out)
  File "torch/_decomp/__init__.py", line 100, in _fn
    return f(*args, **kwargs, out=None if is_none else out_kwargs)
  File "torch/_prims_common/wrappers.py", line 289, in _fn
    result = fn(*args, **kwargs)
TypeError: triangular_solve_meta() got an unexpected keyword argument 'X'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140186
Approved by: https://github.com/ezyang
2024-11-12 19:04:34 +00:00
Jiang, Yanbing
f77eb07662 Split int4wo weight packing (#139611)
Fixes https://github.com/pytorch/ao/issues/1117.

This PR is to seperate int4wo weight packing between CPU and other devices, to help implement `INT4CPULayout` in torchao based on https://github.com/pytorch/ao/issues/1117#issuecomment-2451252756.

Now, for CPU, the input `weight` of `_convert_weight_to_int4pack_for_cpu` is [n, k] int32, output is [n, k / 2] uint8. The input packed weight of `_weight_int4pack_mm_for_cpu` is [n, k / 2] uint8.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139611
Approved by: https://github.com/jerryzh168
2024-11-12 10:12:50 +00:00
Colin Peppler
63b01f328e [inductor] support masked_scatter w/ unbacked sized source (#138083)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138083
Approved by: https://github.com/jansel
2024-11-06 02:16:25 +00:00
Pian Pawakapan
a678eaf1ad check fake/real mismatches during real tensor prop (#137747)
Summary:
While testing exportability for PT2 Inference models, we found various cases of invalid op inputs during tracing, for example errors like: `a and b must have same reduction dim`, `expected scalar type Long but found Int`, etc. Looking more closely, these happened to due the same few meta kernels & eager kernels producing mismatched outputs upstream (e.g. different output tensor dtype, int output).

Adding checks to catch mismatched outputs in real tensor prop upstream, so errors are raised at the mismatched op, instead of the downstream ops taking them as inputs. Relies a lot on utils from [CrossRefFakeMode](929797dedb/torch/_subclasses/fake_utils.py (L78))

Follow ups: could add more checks, and maybe have a flag to only enable these for cases like draft mode, so perf doesn't suffer?

Test Plan: test_export, test_fake_tensor

Differential Revision: D64210055

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137747
Approved by: https://github.com/zou3519
2024-11-04 23:39:48 +00:00
PyTorch MergeBot
8197e4c70d Revert "[sparse] add search for optimal alg_id to torch.compile (#137427)"
This reverts commit 39bfba3f56.

Reverted https://github.com/pytorch/pytorch/pull/137427 on behalf of https://github.com/jcaip due to this PR breaks AO tests ([comment](https://github.com/pytorch/pytorch/pull/137427#issuecomment-2435906592))
2024-10-24 17:27:06 +00:00
Laith Sakka
ed313a5ca2 Introduce torch.sym_add, variadic add (#138660)
Tested internally here: https://www.internalfb.com/diff/D64057744
This is a reland after previous internal failures.
main change is
```
 if min is None and max is None:
        torch._check_is_size(size)
        return
```

Partially addresses https://github.com/pytorch/pytorch/issues/128150

When you have big sums of values, we end up computing long chains of
binary addition in our FX graph representation.  Not only is this ugly,
it also is quadratic, as the sympy.Add constructor is O(N) in number
of arguments.  Instead, ensure that we maintain the summation as a
single FX node so we can do the entire addition all in one go.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138660
Approved by: https://github.com/ezyang, https://github.com/bobrenjc93
2024-10-23 17:42:41 +00:00
Jesse Cai
39bfba3f56 [sparse] add search for optimal alg_id to torch.compile (#137427)
Summary:

This PR adds a lowering for `torch._cslt_sparse_mm` to find the optimal
alg_id and cache it when running with `torch.compile`

Seeing speedups on both bfloat16 and float8 dtypes:
<img width="641" alt="Screenshot 2024-10-17 at 2 10 38 PM" src="https://github.com/user-attachments/assets/b928cd11-32a3-43e5-b209-8e4028896f0b">
<img width="1274" alt="Screenshot 2024-10-17 at 1 39 03 PM" src="https://github.com/user-attachments/assets/d9edd684-a8ec-46fd-b3da-2e76dbcb7bb6">

* `torch._cslt_sparse_mm_search` has been modified to return optimal
  split-k parameters as well as max alg_id.

* max_id is now available in `torch.backends.cusparselt` via
  `torch.backends.cusparselt.get_max_alg_id()`

* fixed meta registrations for float8

Test Plan:

python test/test_sparse_semi_structured.py

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137427
Approved by: https://github.com/cpuhrsch
2024-10-22 22:39:42 +00:00
Will Feng
1a8b4c65ac Fix scatter and gather shape check error message (#138310)
The error message seems incorrect based on the surrounding code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138310
Approved by: https://github.com/Microve, https://github.com/fegin
2024-10-18 07:49:07 +00:00
Yukio Siraichi
030ba03681 Add meta functions for lerp, addcmul, and addcdiv. (#136909)
This PR adds new meta functions for `lerp`, `addcmul`, and `addcdiv` (including their
respective inplace versions).

These functions only had refs implementations, which was being the root cause of a
significant overhead ([issue][1]) when running `AdamW` optimizer step on PyTorch/XLA
backend. Running the meta functions resulted in the following improvements:

- `lerp` calls: 1,550ms to 140ms (10x)
- `addcdiv` calls: 640ms to 350ms (1.8x)
- `addcmul` calls: 620ms to 300ms (2.05x)

[1]: https://github.com/pytorch/xla/issues/7923

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136909
Approved by: https://github.com/jansel
2024-10-12 12:40:46 +00:00
Angel Yang
a777dea3b3 Remove dtype check on meta device (#136774)
Summary:
# Latest Update

This diff is no longer needed because we did need the check to exist, to make meta behave the same as other devices, see D54526190.

---------------------------------

# Background

T176105639

| case | embedding bag weight | per_sample_weight | fbgemm lookup | forward in meta |
| A | fp32 | fp32 | good | good |
| B | fp16 | fp32 | good| failed [check](https://fburl.com/code/k3n3h031) that forces weight dtype ==  per_sample_weights dtype |
| C | fp16 | fp16 | P1046999270, RuntimeError: "expected scalar type Float but found Half from fbgemm call" | good |
| D | fp32 | fp16 | N/A | N/A |

Currently we are in case A. Users need to add `use_fp32_embedding` in training to force embedding bag dtype to be fp32. However, users actually hope for case B to use fp16 as the embedding bag weight. When deleting `use_fp32_embedding`, they would fail the [check](https://fburl.com/code/k3n3h031) that forces `weight dtype ==  per_sample_weights dtype ` in meta_registration.

The check is actually not necessary. Is it because the backend fbgemm does support case B. Additionally, later on in the `meta_embedding_bag`, `weight` and `per_sample_weights` don't need to be in the same dtype (https://fburl.com/code/q0tho05h, weight is src, per_sample_weights is scale) for `is_fast_path_index_select`.

# This diff
Therefore, this diff remove the unnecessary [check](https://fburl.com/code/k3n3h031) to support case B in meta forward. With such, users are able to use fp16 to be the emb bag dtype without the need to force per_sample_weights the same dtype in meta forward (see Test Plan).

# Reference diffs to resolve this issue
Diff 1: D52591217
This passes embedding bag dtype to feature_processor to make per_sample_weights same dtype as emb bag weight. However, `is_meta` also needs to be passed because of case C. fbgemm still does not support per_sample_weights = fp16 (see the above table). Therefore users are forced to only make per_sample_weights fp16 when it is on meta. The solution requires too many hacks.

Diff 2: D53232739
Basically doing the same thing in diff 1 D52591217, except that the hack is added in TorchRec library. This adds an if in EBC and PEA for: when emb bag weight is fp16, it forces per_sample_weight fp16 too. However, it would then result in fbgemm issue too and has broken a bunch of prod models.

Test Plan:
# APS
The following command will run icvr_launcher which triggers ads_launcher and run forward in meta device:
```
buck2 run mode/opt -c python.package_style=inplace //aps_models/ads/icvr:icvr_launcher_publish -- mode=mast_ig_fm_when_combo0_uhm_publish launcher.fbl_entitlement=ads_global_tc_ads_score launcher.data_project=oncall_ads_model_platform launcher.tags=[ads_ranking_taxonomy_exlarge_fm_prod] stages.train=false
```

Result:
 {F1461463993}

Reviewed By: ezyang

Differential Revision: D54175438

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136774
Approved by: https://github.com/ezyang
2024-10-12 05:45:21 +00:00
PyTorch MergeBot
16a2c2cfd4 Revert "Introduce torch.sym_sum (#136429)"
This reverts commit 90bed32b98.

Reverted https://github.com/pytorch/pytorch/pull/136429 on behalf of https://github.com/ezyang due to fails internal stuff ([comment](https://github.com/pytorch/pytorch/pull/136429#issuecomment-2403335147))
2024-10-09 20:08:01 +00:00
Brian Hirsh
53af729a66 add meta for _segment_reduce_backward (#137442)
reland of https://github.com/pytorch/pytorch/pull/124988

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137442
Approved by: https://github.com/albanD
2024-10-08 18:40:06 +00:00
Edward Z. Yang
90bed32b98 Introduce torch.sym_sum (#136429)
Partially addresses https://github.com/pytorch/pytorch/issues/128150

When you have big sums of values, we end up computing long chains of
binary addition in our FX graph representation.  Not only is this ugly,
it also is quadratic, as the sympy.Add constructor is O(N) in number
of arguments.  Instead, ensure that we maintain the summation as a
single FX node so we can do the entire addition all in one go.

update_hint_regression benchmark, before and after:

```
update_hint_regression,compile_time_instruction_count,2648328980
update_hint_regression,compile_time_instruction_count,2563748678
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136429
Approved by: https://github.com/isuruf
2024-10-08 18:12:57 +00:00
Benjamin Glass
a968576777 Add lowering for aten.searchsorted (#135701)
Adds lowering for `aten.searchsorted`. This entails:

1. Adding support for multi-dimensional bucket tensors to `ops.bucketize`.
2. Adding support for striding to `ops.bucketize`.
3. Adding support for sorting tensors to `ops.bucketize`.
4. Adding a lowering for `aten.searchsorted.Tensor`.
5. Adding a basic decomposition for `aten.searchsorted.Scalar` that calls into the lowering for tensors.
6. Updating the meta-function for `aten.searchsorted` to properly check some of the sizing conditions.

Closes #135873

Differential Revision: [D63766514](https://our.internmc.facebook.com/intern/diff/D63766514)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135701
Approved by: https://github.com/amjames, https://github.com/eellison, https://github.com/davidberard98
2024-10-04 19:26:05 +00:00
PyTorch MergeBot
f56f7476d3 Revert "Add meta functions for lerp, addcmul, and addcdiv. (#136909)"
This reverts commit e4b98b1149.

Reverted https://github.com/pytorch/pytorch/pull/136909 on behalf of https://github.com/albanD due to breaks trunk jobs ([comment](https://github.com/pytorch/pytorch/pull/136909#issuecomment-2393774694))
2024-10-04 14:01:54 +00:00
Yukio Siraichi
e4b98b1149 Add meta functions for lerp, addcmul, and addcdiv. (#136909)
This PR adds new meta functions for `lerp`, `addcmul`, and `addcdiv` (including their
respective inplace versions).

These functions only had refs implementations, which was being the root cause of a
significant overhead ([issue][1]) when running `AdamW` optimizer step on PyTorch/XLA
backend. Running the meta functions resulted in the following improvements:

- `lerp` calls: 1,550ms to 140ms (10x)
- `addcdiv` calls: 640ms to 350ms (1.8x)
- `addcmul` calls: 620ms to 300ms (2.05x)

[1]: https://github.com/pytorch/xla/issues/7923

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136909
Approved by: https://github.com/jansel
2024-10-04 02:47:25 +00:00
Isuru Fernando
0c936c3ecb Add decomps for max_unpool (#133146)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133146
Approved by: https://github.com/amjames, https://github.com/eellison
2024-09-20 21:35:25 +00:00
Duygu Altinok
775517693a Add type checks for Tensor.add_ (#135864)
Fixes  #127049

There's already a meta func in `meta_registrations.py` for `add_` and `sub_` methods. I added a second meta function for error checking, i.e `int.add/sub_(float)` and `bool.add/sub_(other types)` .

Also the corresponding test with Dynamo passes, removed `@xfailIfTorchDynamo`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135864
Approved by: https://github.com/williamwen42
2024-09-19 03:09:36 +00:00
Aaron Gokaslan
b491e2974c [BE][Ez]: Add full half/bfloat16 dtype for unique and isin (#136114)
Fixes #136090

* Add support for isin to tensor half dtypes for CPU (just add a few extra dispatches).
* Seems like the CUDA implementation for bfloat16 was mostly compiled and available all along (it just calls sort internally AND unique). To enable it, we just need to remove an assert to access it (since sort's functionality was updated since the assert was added) and add missing dtype support to unique.
* This unlocks more GPU functionality with minimal code bloat. I also added CPU kernels for the dtypes for parity.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136114
Approved by: https://github.com/malfet
2024-09-16 17:49:12 +00:00
Joel Schlosser
525bec804c NJT <-> padded dense conversions (#125947)
This PR:
* Implements the pre-existing `nt.to_padded_tensor(padding_val)` ATen op via the FBGEMM kernel + appropriate view gymnastics (since that kernel only handles 2D values)
* Introduces a new `_nested_from_padded_tensor` op for the reverse conversion, implemented via the reverse FBGEMM kernel + view gymnastics
    * Note: there is currently no public API for this; design booted to a future PR

TODO:
* ~~Propagate min / max sequence length via the new factory function `_nested_from_padded_tensor`~~
* ~~Verify that Inductor does computation fusion via test logic~~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125947
Approved by: https://github.com/soulitzer
2024-09-12 17:54:25 +00:00
Amadeusz Skrzypczak
0226fcaacf Disable cuda specific restrictions in _scaled_mm for other devices (#135579)
Fixes #135576

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135579
Approved by: https://github.com/drisspg
2024-09-11 11:05:38 +00:00
Valentine233
0dbc72887b [CPU][flash attention] make the stride of output align with input (#134656)
Fixes #133671

Currently, the output of CPU flash attention has a fixed layout, no matter what the input is. This PR makes the stride of output align with input q/k/v, which is the same behavior as math backend.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134656
Approved by: https://github.com/jgong5, https://github.com/drisspg
2024-08-29 16:04:25 +00:00
David Berard
289486d007 Move attention kernels back from fake_impls to meta_registrations (#134288)
See #121528 for additional context.

In #120682, we moved the attention kernels from meta_registrations to fake_impls with the intent of fixing the device handling for seed/offset: these are typically on CPU. We needed to put the registrations in fake_impls to do this because meta_registrations doesn't have a way to specify device, whereas fake_impls does. But when we tried to actually fix the device types (#120839), we had to revert the PR because it broke cudagraph handling (during which seed/offset _are_ on CUDA).

Now, we want to put the registrations back in meta_registrations so that we can call these kernels with meta tensors. The use case is later in this stack - we want to be able to use the flop counter with these kernels.

Also - I specifically skip the `compare_tensor_meta()` check in test_fake / test_fake_autocast tests for the `_efficient_attention_forward` and `_flash_attention_forward` kernels, which fails because of the device mismatch from the seed/offset tensors. Then we can un-skip these opinfos. I verified that the efficient_attention_forward bug (#120842) is now caught by these opinfos if I revert the fix from this PR.

Differential Revision: [D61687369](https://our.internmc.facebook.com/intern/diff/D61687369)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134288
Approved by: https://github.com/drisspg
2024-08-27 21:10:36 +00:00
Amadeusz Skrzypczak
38f97ec8e3 [pt2] Add meta for poisson (#134103)
Because aten.poisson doesn't have meta function registered, there is one additional eager execution of this op during compilation phase of torch.compile.

There are more ops without meta registration. Is there any reason for it?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134103
Approved by: https://github.com/ezyang
2024-08-26 06:14:38 +00:00
Andrew Gu
b0803129e8 Added meta registration for _fused_adamw_ (#133728)
See https://github.com/pytorch/pytorch/issues/123461#issuecomment-2294335273

<img width="1463" alt="Screenshot 2024-08-16 at 5 38 25 PM" src="https://github.com/user-attachments/assets/fe940c0e-775f-4047-bf69-34a3677d539b">
same signature so should be ok to just add the op to the decorator
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133728
Approved by: https://github.com/janeyx99, https://github.com/fegin
2024-08-17 00:28:31 +00:00
Xuehai Pan
758a0a88a2 [BE][Easy] enable ruff rule PIE790: unnecessary pass statement (#133200)
This PR removes unnecessary `pass` statement. This is semanticly safe because the bytecode for the Python code does not change.

Note that if there is a docstring in the function, a empty function does not need a `pass` statement as placeholder.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133200
Approved by: https://github.com/malfet, https://github.com/eqy, https://github.com/kit1980
2024-08-15 15:50:19 +00:00
Xuehai Pan
4226ed1585 [BE] Format uncategorized Python files with ruff format (#132576)
Remove patterns `**`, `test/**`, and `torch/**` in `tools/linter/adapters/pyfmt_linter.py` and run `lintrunner`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132576
Approved by: https://github.com/ezyang, https://github.com/Skylion007
ghstack dependencies: #132574
2024-08-04 17:13:31 +00:00
Siyu Yang
882d80fd92 Add lowering for updated _scaled_mm (fixing submodules) (#130422)
Add the Inductor lowering for `torch._scaled_mm`, whose API was last updated in https://github.com/pytorch/pytorch/pull/128683.

The lowering does:
- for tensor-wise scaling, auto-tune between the default ATen kernel (cuBLAS) and Triton kernel configurations.
- for row-wise scaling, auto-tune between the default ATen kernel (CUTLASS kernel added in https://github.com/pytorch/pytorch/pull/125204) and Triton kernel configurations.

The Triton kernel template is based on 3ad9031d02 (D56337896) by @choutim, without using SPLIT_K, and that of mm `torch/_inductor/kernel/mm.py`

## Testing:
- Logging shows max-autotune tuning (`AUTOTUNE scaled_mm`) for both tensor-wise and row-wise scaling when called with the two scaling types.
- Row-wise scaling allows operator fusion between preceding pointwise/reduction op and amax/cast:
    - output code Evaluating m=256, n=256, k=256, fusion_case='pointwise', scaling_mode='row'
        - P1477224245 - 2 kernels
    - output code Evaluating m=2048, n=256, k=2048, fusion_case='reduction', scaling_mode='row'
        - P1477227340 - 2 kernels

- UT `python test/inductor/test_fp8.py -- TestFP8Lowering`

## Benchmarking

Eager/compiled tensor-wise/row-wise scaling for various shapes:
https://docs.google.com/spreadsheets/d/1VfWEVuyrwoWysfbS0_u2VHJ-PsdWkF1qIsiD60AzTes/edit?gid=2113587669#gid=2113587669
- Some of the “compiled” cases are slightly slower than “eager”. It’s because max-autotune selected the ATen kernel in the compiled case, and I think the discrepancy is variance.

Eager/compiled tensor-wise/row-wise scaling with pointwise/reduction preceding op for various shapes:
https://docs.google.com/spreadsheets/d/1Nv07NrdffQIoDeMjo9E0V-E-EYrEN0WysO_bn1bc6ns/edit?gid=1715488446#gid=1715488446

## Questions for reviewers:
- Should the type of the accumulator `ACC_TYPE` always be in float32? If not, where is this type set (output layout?)?

## Todo:
- Make the Triton template use the improved persistent kernel version (https://github.com/pytorch/FBGEMM/pull/2735 by @htyu)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130422
Approved by: https://github.com/ipiszy
2024-07-30 23:48:48 +00:00
PyTorch MergeBot
fd5b7d4bf9 Revert "[BE] typing for decorators - _meta_registrations (#131572)"
This reverts commit bfe0079b72.

Reverted https://github.com/pytorch/pytorch/pull/131572 on behalf of https://github.com/clee2000 due to breaking lint internally D60265575 ([comment](https://github.com/pytorch/pytorch/pull/131572#issuecomment-2254328359))
2024-07-28 03:29:32 +00:00
Jiang, Yanbing
bceb91222c Fix meta error in _convert_weight_to_int4pack (#130915)
This PR is to fix meta error in _convert_weight_to_int4pack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130915
Approved by: https://github.com/jerryzh168
2024-07-26 08:36:30 +00:00
Aaron Orenstein
bfe0079b72 [BE] typing for decorators - _meta_registrations (#131572)
See #131429
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131572
Approved by: https://github.com/oulgen, https://github.com/zou3519
ghstack dependencies: #131568, #131569, #131570, #131571
2024-07-25 22:24:19 +00:00
Aaron Orenstein
5a0068cc69 [BE] mypy: disallow untyped decorators (#131428)
Untyped decorators strip the types from their decorated function so even if the underlying function is fully typed then callers to it don't get any benefit from type annotations.

Step 1 - Enable the error and override in all the offending files.

#131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428
Approved by: https://github.com/justinchuby, https://github.com/oulgen
2024-07-23 21:50:55 +00:00
Isuru Fernando
bb4251213b Add decomposition for channel_shuffle (#118775)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118775
Approved by: https://github.com/peterbell10
2024-07-20 01:24:41 +00:00
Xuehai Pan
b29b23137c [Easy] Fix argument name collision in dispatched functions (#129562)
Use positional-only argument to avoid naming collision with aten ops arguments that are named "self".

```python
In [1]: def foo(self, *args, **kwargs):
   ...:     print(self, args, kwargs)
   ...:

In [2]: def bar(self, /, *args, **kwargs):
   ...:     print(self, args, kwargs)
   ...:

In [3]: foo(1, 2, self=3)
TypeError: foo() got multiple values for argument 'self'

In [4]: bar(1, 2, self=3)
1
(2,)
{'self': 3}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129562
Approved by: https://github.com/zou3519, https://github.com/fegin
2024-07-17 14:39:56 +00:00
Jiang, Yanbing
93a03edcf9 Update error message in meta__convert_weight_to_int4pack (#130707)
This PR is to fix error message in https://github.com/pytorch/pytorch/pull/129940.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130707
Approved by: https://github.com/lezcano, https://github.com/malfet
2024-07-16 00:44:35 +00:00
Colin Peppler
a7f54c7f8a [dynamo] add meta fn for aten.kthvalue.default (#130562)
I saw
```
torch._dynamo.exc.Unsupported: unsupported operator: aten.kthvalue.default
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130562
Approved by: https://github.com/jingsh, https://github.com/zou3519
2024-07-12 23:48:31 +00:00
Jiang, Yanbing
6f662e9575 update the input weight of _convert_weight_to_int4pack to [n][k / 2] uint8 (#129940)
This PR is to update the input `weight` of `_convert_weight_to_int4pack` from `[n][k] int32` to `[n][k / 2] uint8`, both for CPU, CUDA and MPS, which can help decouple int4 model checkpoint with different ISAs and different platforms in `gpt-fast`. The advantage is int4 model checkpoint can be shared in different test machines, without re-generating in one certain platform. Meanwhile, the size of input `weight` can be reduced to `1 / 8`.

Before this PR, packed weight stored in CUDA specific layout: `[n/8][k/(InnerKTiles*16)][32][InnerKTiles/2]`, dtype int32, where InnerKTiles = 2, 4, 8. CPU packed weight viewed as the SAME shape but stored in different layout: `[n/64][k][32]`, dtype uint8. Weight is strongly coupled with platforms (CPU/CUDA) and ISAs (AVX512/AVX2/scalar). And users cannot use a generated weight in another different ISA or platform, because when loading weight into devices, the compute format is different.
![image](https://github.com/pytorch/pytorch/assets/61222868/64971c4b-29b9-42cf-9aeb-ffa01cea93dd)

Now, we use common serialized layout (`[n][k/2] uint8`) for different devices or ISAs as input `weight` of `_convert_weight_to_int4pack`, and each back chooses how to interpret as compute layout.
![image](https://github.com/pytorch/pytorch/assets/61222868/c7990761-c723-417b-aca2-7c60db7785c7)

### Performance
Intel (R) Xeon (R) CPU Max 9480, single socket (56 cores)
There is no obvious regression of this PR.
![image](https://github.com/pytorch/pytorch/assets/61222868/6046dcf4-920b-4c63-9ca3-1c8c3cafebde)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129940
Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/mingfeima
2024-07-11 15:26:48 +00:00
PyTorch MergeBot
637cc8d27f Revert "update the input weight of _convert_weight_to_int4pack to [n][k / 2] uint8 (#129940)"
This reverts commit 6367f02a0e.

Reverted https://github.com/pytorch/pytorch/pull/129940 on behalf of https://github.com/albanD due to Broke rocm tests on main 6367f02a0e ([comment](https://github.com/pytorch/pytorch/pull/129940#issuecomment-2220554681))
2024-07-10 13:48:32 +00:00
Jiang, Yanbing
6367f02a0e update the input weight of _convert_weight_to_int4pack to [n][k / 2] uint8 (#129940)
This PR is to update the input `weight` of `_convert_weight_to_int4pack` from `[n][k] int32` to `[n][k / 2] uint8`, both for CPU, CUDA and MPS, which can help decouple int4 model checkpoint with different ISAs and different platforms in `gpt-fast`. The advantage is int4 model checkpoint can be shared in different test machines, without re-generating in one certain platform. Meanwhile, the size of input `weight` can be reduced to `1 / 8`.

Before this PR, packed weight stored in CUDA specific layout: `[n/8][k/(InnerKTiles*16)][32][InnerKTiles/2]`, dtype int32, where InnerKTiles = 2, 4, 8. CPU packed weight viewed as the SAME shape but stored in different layout: `[n/64][k][32]`, dtype uint8. Weight is strongly coupled with platforms (CPU/CUDA) and ISAs (AVX512/AVX2/scalar). And users cannot use a generated weight in another different ISA or platform, because when loading weight into devices, the compute format is different.
![image](https://github.com/pytorch/pytorch/assets/61222868/64971c4b-29b9-42cf-9aeb-ffa01cea93dd)

Now, we use common serialized layout (`[n][k/2] uint8`) for different devices or ISAs as input `weight` of `_convert_weight_to_int4pack`, and each back chooses how to interpret as compute layout.
![image](https://github.com/pytorch/pytorch/assets/61222868/c7990761-c723-417b-aca2-7c60db7785c7)

### Performance
Intel (R) Xeon (R) CPU Max 9480, single socket (56 cores)
There is no obvious regression of this PR.
![image](https://github.com/pytorch/pytorch/assets/61222868/6046dcf4-920b-4c63-9ca3-1c8c3cafebde)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129940
Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/mingfeima
2024-07-10 07:38:42 +00:00
Yukio Siraichi
a79bb8db91 Make _embedding_bag_backward explicitly dispatch to CPU and CUDA. (#129691)
This PR modifies `_embedding_bag_backward` item inside _native_functions.yaml_, so that it
dispatches to CPU and CUDA directly, instead of `CompositeImplicitAutograd`.

*Context:* PyTorch operations that have the `CompositeImplicitAutograd` dispatch do not
allow third party backends (e.g. XLA) to modify its implementation, since this dispatch
key has higher priority. When calling `_embedding_bag_backward` operation using XLA, a
dispatch error will be thrown, since PyTorch/XLA doesn't support sparse tensors.

*Problem:* `_embedding_bag_backward` has a `sparse` parameter that controls whether the
operation should return a sparse or dense tensor. However, at the moment, PyTorch/XLA does
not support sparse tensors. In order to fallback that execution to dense, i.e. change the
flag at runtime, we need to be able to modify its implementation.

*Solution:* we have changed the dispatch of `_embedding_bag_backward` to CPU and CUDA,
which allowed us to introduce our own kernel for it.

Additionally, this PR refactored the representation of its mode from constant integers
into an enum class. It also introduces two additional operators: `int == EmbeddingBagMode`
and `int != EmbeddingBagMode`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129691
Approved by: https://github.com/lezcano
2024-07-03 21:54:49 +00:00
eqy
f845a7a91a [cuDNN][SDPA] Remove TORCH_CUDNN_SDPA_ENABLED=1, enable cuDNN SDPA by default on H100 and 2nd on other archs >= sm80 (#125343)
Looks like one of the first failures seen is `test_causal_variants_compile_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` when `test_causal_variants_causal_variant_CausalVariant_LOWER_RIGHT_shape0_cuda` passes.

What seems interesting here is that the `torch.compile` version fails while the eager version passes. Not sure what the difference would be here...

Nevertheless, is there a recommended mechanism to skip cuDNN SDPA as a backend for this test? CC @drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125343
Approved by: https://github.com/Skylion007
2024-06-30 19:22:16 +00:00