Commit Graph

127 Commits

Author SHA1 Message Date
Xuehai Pan
4d7bf72d93 [BE][Easy] fix ruff rule needless-bool (SIM103) (#130206)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130206
Approved by: https://github.com/malfet
2024-07-14 08:17:52 +00:00
Michael Lazos
2129903aa3 Properly detect nested torch function args (#127496)
Dynamo was not detecting nested torch function classes in containers. This was due to pytree compatibility for variable trackers being removed.
Fixes https://github.com/pytorch/pytorch/issues/127174

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127496
Approved by: https://github.com/anijain2305
2024-06-02 03:43:22 +00:00
chilli
392dc45597 Made FlexAttention rewrite getitem calls to use aten.index in score_mod (#124799)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124799
Approved by: https://github.com/drisspg
ghstack dependencies: #124444
2024-04-26 17:22:13 +00:00
PyTorch MergeBot
e913f77c60 Revert "Made FlexAttention rewrite getitem calls to use aten.index in score_mod (#124799)"
This reverts commit 9bccafc31c.

Reverted https://github.com/pytorch/pytorch/pull/124799 on behalf of https://github.com/clee2000 due to broke tests but only on crossref https://github.com/pytorch/pytorch/actions/runs/8841521519/job/24279075171, added no td label so itll actually run this time ([comment](https://github.com/pytorch/pytorch/pull/124799#issuecomment-2078530797))
2024-04-26 02:35:14 +00:00
chilli
9bccafc31c Made FlexAttention rewrite getitem calls to use aten.index in score_mod (#124799)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124799
Approved by: https://github.com/drisspg
ghstack dependencies: #124444
2024-04-26 01:02:28 +00:00
PyTorch MergeBot
678662a557 Revert "Made FlexAttention rewrite getitem calls to use aten.index in score_mod (#124799)"
This reverts commit acc4cbea39.

Reverted https://github.com/pytorch/pytorch/pull/124799 on behalf of https://github.com/jeanschmidt due to checking if this diff introduced regressions on linux-focal-py3.11-clang10 and linux-focal-py3.8-clang10 ([comment](https://github.com/pytorch/pytorch/pull/124799#issuecomment-2076756876))
2024-04-25 09:29:57 +00:00
chilli
acc4cbea39 Made FlexAttention rewrite getitem calls to use aten.index in score_mod (#124799)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124799
Approved by: https://github.com/drisspg
2024-04-25 06:19:55 +00:00
Xuehai Pan
93e249969b [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261)
Remove useless parentheses in `raise` statements if the exception type is raised with no argument.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261
Approved by: https://github.com/albanD
2024-04-17 19:29:34 +00:00
cyy
fb90b4d4b2 [TorchGen] Use std::optional in generated code (#121454)
This PR changes TorchGen to generate std::optional.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121454
Approved by: https://github.com/ezyang
2024-03-29 14:11:09 +00:00
Aaron Gokaslan
3fe437b24b [BE]: Update flake8 to v6.1.0 and fix lints (#116591)
Updates flake8 to v6.1.0 and fixes a few lints using sed and some ruff tooling.
- Replace `assert(0)` with `raise AssertionError()`
- Remove extraneous parenthesis i.e.
  - `assert(a == b)` -> `assert a == b`
  - `if(x > y or y < z):`->`if x > y or y < z:`
  - And `return('...')` -> `return '...'`

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116591
Approved by: https://github.com/albanD, https://github.com/malfet
2024-01-03 06:04:44 +00:00
Tobias Ringwald
460fc9da62 Disabled UserWarnings for some public functions in torch.overrides (#109890)
Fixes #109842.

This disables the implicit `UserWarning`s that were raised for deprecated `torch` attributes. The filtering was designed to be as specific as possible, in order to not filter any other warnings that may be raised.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109890
Approved by: https://github.com/ezyang
2023-09-23 20:40:04 +00:00
Matthew Hoffman
e40d6ae0a7 Improve torch.cuda.amp type hints (#108630)
Fixes #108629

1. Add the following to their modules' `__all__` so that pyright considers them to be publicly exported:
* [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast)
* [`torch.cuda.amp.GradScaler`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler)
* [`torch.cuda.amp.autocast`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast)
* [`torch.cuda.amp.custom_fwd`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_fwd)
* [`torch.cuda.amp.custom_bwd`](https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_bwd)
2. Add `overload`s for `torch.cuda.amp.GradScaler.scale` to differentiate when a `torch.Tensor` is returned vs. an `Iterable[torch.Tensor]` is returned based on the type of the `outputs` parameter.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108630
Approved by: https://github.com/ezyang
2023-09-08 06:06:25 +00:00
dilililiwhy
5a9e82fa02 let torch.device be overrideable by TorchFunctionMode (#106514)
Fixes #103828
let torch.device be overrideable by TorchFunctionMode
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106514
Approved by: https://github.com/ezyang
2023-08-04 10:47:43 +00:00
Justin Chu
73e1455327 [BE] Enable ruff's UP rules and autoformat test/ (#105434)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105434
Approved by: https://github.com/albanD
2023-07-19 20:36:06 +00:00
Guang Yang
68cb06c752 Make gen_annotated_args support kwargs (#98396)
This PR is to address the issue seeing in PR #97417 where the newly added op requires `kwargs`, however, currently tools/autograd/gen_annotated_fn_args.py does not support `kwargs`, only `func_args` are generated for test_overrides.py.

The PR adds a new field "is_kwargs" to each argument indicating whether it's a `kwargs` or not. See example:
```
annotated_args = {
    torch._C._VariableFunctions._cast_Byte: [{'is_kwarg_only': 'False', 'name': 'self', 'simple_type': 'Tensor'}],
    ...
```

The full comparison of the generated file `annotated_fn_args.py` can be found here:
  - **Before**: [P681991116](https://www.internalfb.com/phabricator/paste/view/P681991116)
  - **After**: [P681994218](https://www.internalfb.com/intern/paste/P681994218/)

Differential Revision: D44698310

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98396
Approved by: https://github.com/ezyang
2023-04-06 19:42:26 +00:00
soulitzer
d0abc31428 Remove unnecessary retain_grad call from gradcheck (#96923)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96923
Approved by: https://github.com/albanD
2023-03-27 13:38:28 +00:00
Aaron Gokaslan
8fce9a09cd [BE]: pyupgrade Python to 3.8 - imports and object inheritance only (#94308)
Apply parts of pyupgrade to torch (starting with the safest changes).
This PR only does two things: removes the need to inherit from object and removes unused future imports.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94308
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-07 21:10:56 +00:00
Pearu Peterson
4a4520e74b Retire unsafe sparse tensor constructors in Python API (#91331)
This PR removes sparse tensor constructor functions `torch._sparse_coo/csr/csc/bsr/bsc/compressed_tensor_unsafe(...)` as unneeded. The equivalent functionality is provided via `torch.sparse_coo/csr/csc/bsr/bsc/compressed_tensor(..., check_invariants=False)`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91331
Approved by: https://github.com/amjames, https://github.com/cpuhrsch
2023-01-18 08:55:22 +00:00
samdow
b8252e07c7 [Reland] add DisableTorchFunction that matches DisableTorchDispatch (#88219) (#92012)
Reland of #88219

Closes #87990. This implements a new disable guard that matches DisableTorchDispatch (disables all subclasses and modes)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92012
Approved by: https://github.com/albanD
2023-01-12 01:27:47 +00:00
Edward Z. Yang
333540a458 Reland "Add torch.utils.device_mode" (#91796)
Original PR https://github.com/pytorch/pytorch/pull/91525

Signed-off-by: Edward Z. Yang <ezyangfb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91796
Approved by: https://github.com/albanD
2023-01-09 20:57:12 +00:00
PyTorch MergeBot
9b415240d4 Revert "Reland "Add torch.utils.device_mode" (#91796)"
This reverts commit 81b5eff3c3.

Reverted https://github.com/pytorch/pytorch/pull/91796 on behalf of https://github.com/huydhn due to This breaks trunk with the following failed test https://hud.pytorch.org/failure/test_jit_save%2CTestTracer
2023-01-09 04:45:47 +00:00
Edward Z. Yang
81b5eff3c3 Reland "Add torch.utils.device_mode" (#91796)
Original PR https://github.com/pytorch/pytorch/pull/91525

Signed-off-by: Edward Z. Yang <ezyangfb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91796
Approved by: https://github.com/albanD
2023-01-08 03:44:56 +00:00
PyTorch MergeBot
f571ae4fdb Revert "Make torch.device usable as a context manager (#91525)"
This reverts commit 619d52a5d2.

Reverted https://github.com/pytorch/pytorch/pull/91525 on behalf of https://github.com/mehtanirav due to Internal breakages
2023-01-05 21:34:50 +00:00
Samantha Andow
a7749ae177 [reland] rename DisableTorchFunction to DisableTorchFunctionSubclass (#88218) (#89221)
Summary: First half of #87990. This doesn't change any of the behavior and is just a rename

#88218 got reverted for internal breakages. This is the reland of started from internal

Differential Revision:
D41268423

LaMa Project: L1098534

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89221
Approved by: https://github.com/meliy-meyada, https://github.com/zou3519
2023-01-04 18:32:49 +00:00
Edward Z. Yang
619d52a5d2 Make torch.device usable as a context manager (#91525)
Fixes https://github.com/pytorch/pytorch/issues/82296
Fixes https://github.com/pytorch/pytorch/issues/27878
Fixes https://github.com/pytorch/pytorch/issues/260

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91525
Approved by: https://github.com/albanD
2023-01-04 01:32:00 +00:00
Pearu Peterson
b87682f555 Fix gradcheck for CSR and CSC inputs. (#89786)
Partially fix-es https://github.com/pytorch/pytorch/issues/87085

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89786
Approved by: https://github.com/albanD
2022-12-02 12:35:20 +00:00
PyTorch MergeBot
ba4d5aae06 Revert "rename DisableTorchFunction to DisableTorchFunctionSubclass (#88218)"
This reverts commit 7f28be10e5.

Reverted https://github.com/pytorch/pytorch/pull/88218 on behalf of https://github.com/izaitsevfb due to BC-breaking change, D41211901
2022-11-11 19:13:05 +00:00
PyTorch MergeBot
4e5d7afe84 Revert "add DisableTorchFunction that matches DisableTorchDispatch (#88219)"
This reverts commit c0ecce15b5.

Reverted https://github.com/pytorch/pytorch/pull/88219 on behalf of https://github.com/izaitsevfb due to BC-breaking change, D41211901
2022-11-11 19:08:30 +00:00
samdow
c0ecce15b5 add DisableTorchFunction that matches DisableTorchDispatch (#88219)
Closes #87990. This implements a new disable guard that matches DisableTorchDispatch (disables all subclasses and modes)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88219
Approved by: https://github.com/ezyang
2022-11-10 14:51:13 +00:00
samdow
7f28be10e5 rename DisableTorchFunction to DisableTorchFunctionSubclass (#88218)
First half of #87990. This doesn't change any of the behavior and is just a rename

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88218
Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-11-10 14:51:13 +00:00
Peter Bell
eb3f975c6e Fix segfault in has_torch_function (#88559)
Fixes #83908

`PySequence_Fast` may return `NULL` to indicate an error was raised, in which
case `sequence_has_torch_function` will dereference a null pointer.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88559
Approved by: https://github.com/ezyang, https://github.com/Skylion007, https://github.com/hameerabbasi
2022-11-07 23:48:39 +00:00
samdow
169ec120ef [Modes] refactor modes to only use a stack in cpp (#86458)
Refactors the mode code to only have the C++ mode stack and not the "C++ mode" like we originally had. This also simplifies the mode logic in a number of places
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86458
Approved by: https://github.com/zou3519
2022-10-21 19:18:23 +00:00
samdow
a106611055 [Modes] fix handle_torch_funcion logic (#85707)
Fixes #85696. I didn't totally get what was happening in handle_torch_function and so was trying to recreate the original logic instead of follow what the C++ is doing. This fixes that
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85707
Approved by: https://github.com/ezyang
2022-09-27 18:35:51 +00:00
samdow
18d8c548f4 [Modes] remove enable and rewrite mode stack (squashed) (#84774)
Based on @ezyang's suggestion, mode stack now has "one true mode" which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch is just to take the top mode in the stack, reenable itself (if we aren't at the end of the mode stack), and run the top mode's torch_{dispatch|function}

This maintains that in the middle of a mode's torch dispatch, the mode itself will not be active. It changes the function the user has to call to see what the current mode is (no longer queries the C++, it's python only) but allows the user to also see the entire mode stack easily

Removes `enable_torch_dispatch_mode` and `.restore()` since neither makes sense in this new setup

### Background
Why do we want this? Well, a pretty common pattern that was coming up was that users had to do something like

```python
## PRE-PR UX
def f(mode):
  with mode.restore():  # user needs to understand this restore thing?
    ...

with Mode() as m:
  pass
f(m)
```

Many users were getting error from forgetting to call `.restore` or from forgetting to add the (tbh weird) "mode instantiation"  step where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like context managers and just write
```python
## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR
def f(mode):
  with mode:
    ...
f(Mode())
```

** Technical Details **
With the old mode stack, we basically had a linked list so the mode itself could only be used once and had a fixed parent. In this new design, the mode stack is just a python list that we're pushing to and popping from. There's only one mode that's ever active at the C++ level and it runs the next mode in the Python list. The modes don't have state on them anymore
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774
Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-09-27 01:04:35 +00:00
Brian Hirsh
4a2d2e5e40 Change API type Tensor[] for structured kernels. (#73350)
Partially fixes: #66328

This PR:
- adds support for `ITensorList` to the dispatcher for:
  - computing the dispatch key
  - boxing and unboxing `ITensorList`
- modified the codegen for structured kernels:
  - codegen APIs use `ITensorList` instead of `ArrayRef<Tensor>`

**Changes summary:**

- Signature changes due to the different APIs:
  - dispatcher API (e.g. `BatchingRegistrations.cpp`)
  - C++ API (e.g. `TensorShape.cpp`)
- Miscelaneous functions used by codegen'd functions (e.g. `FunctionalTensorWrapper.*`)
- Dispatcher changes for handling `ITensorList` correctly (e.g. `DispatchKeyExtractor.h`)
- Signature changes of `at::cat` due to the need of `const` inside `TensorBody.h`
- Forward declarations of `ITensorList` (e.g. `MethodOperators.h`)
- Codegen changes, special casing structured kernels (e.g. `gen.py`)

**Short description of structured kernels special casing:**

I introduced, mainly, 5 types of changes to the codegen for generating code depending on
whether the kernel is structured or not:

1. Added a `structured_type_override` flag to the `argument_type` function definition of
the affected APIs (mainly the dispatcher and C++ APIs).
  - `api/cpp.py`, `api/dispatcher.py`, `api/native.py`
2. Added a `structured_type_override` member to the signature
classes (e.g. `CppSignature`), since `FunctionSchema` doesn't really know whether the
function is structured or not
  - `api/types.py`
3. Added a `part_of_structured_group` to `NativeFunction` class, which is just a
convenient function to forward to `structured_type_override` wherever needed
  - `model.py`
4. Appropriately changed the rest of the codegen, whenever it used either the signature
classes or the `arguments` function directly
5. Added a check for `const ITensorList&` type wherever there was a check for `TensorList`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73350
Approved by: https://github.com/bdhirsh
2022-09-26 21:46:38 +00:00
PyTorch MergeBot
f534b2c627 Revert "Remove split functional wrapper (#74727)"
This reverts commit a58876ace7.

Reverted https://github.com/pytorch/pytorch/pull/74727 on behalf of https://github.com/seemethere due to Fails internal use cases, might extend out to external use cases as well. Need to assess overall impact of this change more widely
2022-08-10 19:45:23 +00:00
albanD
e4ea751810 Fix hash for Tensor subclasses (#83174)
Fixes https://github.com/pytorch/pytorch/issues/82832
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83174
Approved by: https://github.com/ezyang
2022-08-10 19:23:56 +00:00
Peter Bell
a58876ace7 Remove split functional wrapper (#74727)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74727
Approved by: https://github.com/albanD, https://github.com/khabinov
2022-08-10 17:57:48 +00:00
Edward Z. Yang
a61c96492b Add EnableTorchFunction (#82647)
If you DisableTorchFunction, as is done in the default __torch_function__
implementation, if you want to reentrantly use TorchFunction (e.g., to
trace FX proxies), you have to be able to turn it back on.
enable_reentrant_dispatch does not work in this case because by the time
we snapshot TLS, torch function is already disabled.

Differential Revision: [D38354504](https://our.internmc.facebook.com/intern/diff/D38354504/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D38354504/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82647
Approved by: https://github.com/zou3519, https://github.com/albanD
2022-08-07 03:04:46 +00:00
Fabio Rocha
fd84c458f4 Add torch.unflatten and improve its docs (#81399)
unflatten now has a free function version in torch.flatten in addition to
    the method in torch.Tensor.flatten.

    Updated docs to reflect this and polished them a little.
    For consistency, changed the signature of the int version of unflatten in
    native_functions.yaml.

    Some override tests were failing because unflatten has unusual
    characteristics in terms of the .int and .Dimname versions having
    different number of arguments so this required some changes
    to test/test_override.py

    Removed support for using mix of integer and string arguments
    when specifying dimensions in unflatten.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81399
Approved by: https://github.com/Lezcano, https://github.com/ngimel
2022-07-29 15:02:42 +00:00
samdow
2ac24675cc get rid of push_torch_{dispatch, function}_mode (#78215)
Currently we have 2 ways of doing the same thing for torch dispatch and function modes:
`with push_torch_dispatch_mode(X)` or `with X.push(...)`
is now the equivalent of doing
`with X()`

This removes the first API (which is older and private so we don't need to go through a deprecation cycle)

There is some risk here that this might land race with a PR that uses the old API but in general it seems like most are using the `with X()` API or `enable_torch_dispatch_mode(X())` which isn't getting removed.

EDIT: left the `with X.push(...)` API since there were ~3 land races with that over the past day or so. But made it give a warning and ask users to use the other API
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78215
Approved by: https://github.com/ezyang
2022-07-22 18:56:37 +00:00
Edward Z. Yang
d4f065d261 Return mode object from __enter__ (#80998)
This makes `with Mode() as m:` work.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80998
Approved by: https://github.com/samdow
2022-07-12 23:22:26 +00:00
PyTorch MergeBot
7f3677d723 Revert "Remove split functional wrapper (#74727)"
This reverts commit cc3126083e.

Reverted https://github.com/pytorch/pytorch/pull/74727 on behalf of https://github.com/mehtanirav due to Breaking multiple internals builds and tests
2022-07-11 18:29:45 +00:00
Peter Bell
cc3126083e Remove split functional wrapper (#74727)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74727
Approved by: https://github.com/albanD
2022-07-08 19:21:22 +00:00
samdow
5e926aafab add utils for checking that all modes are in the same scope and finding the outermost mode
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78847

Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-06-10 19:31:05 +00:00
samdow
3734fcc8f8 add ability to push a mode if the current mode is an ancestor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78822

Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-06-10 18:27:04 +00:00
samdow
184e0065b3 add better error message for class method
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78821

Approved by: https://github.com/ezyang
2022-06-06 13:31:32 +00:00
Edward Z. Yang
7860ce5b79 Fix tests that were never running, add a new test
Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78843

Approved by: https://github.com/samdow
2022-06-04 01:09:52 +00:00
samdow
aa06d05297 enable with semantics
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78214

Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-06-01 21:14:45 +00:00
Edward Z. Yang
4941e72e40 Revert "Revert "Implement sym_sizes to create proper IR for sym ints representing tensor sizes (#76836)""
This reverts commit c35bd8d423.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77719

Approved by: https://github.com/Chillee, https://github.com/malfet
2022-05-18 18:40:57 +00:00