Commit Graph

31 Commits

Author SHA1 Message Date
Aaron Orenstein
5b5766665d PEP585 update - torch/_C torch/_decomp torch/_lazy torch/_library torch/_numpy torch/_prims torch/_refs torch/_strobelight (#145102)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145102
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #145105
2025-01-18 20:47:12 +00:00
Xuehai Pan
e7eeee473c [BE][Easy][14/19] enforce style for empty lines in import segments in torch/_[a-c]*/ and torch/_[e-h]*/ and torch/_[j-z]*/ (#129765)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129765
Approved by: https://github.com/ezyang
2024-07-31 10:42:50 +00:00
Aaron Orenstein
4a5a87168e [BE] typing for decorators - _prims_common/wrappers (#131567)
See #131429
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131567
Approved by: https://github.com/oulgen, https://github.com/zou3519
2024-07-25 14:35:13 +00:00
Aaron Orenstein
5a0068cc69 [BE] mypy: disallow untyped decorators (#131428)
Untyped decorators strip the types from their decorated function so even if the underlying function is fully typed then callers to it don't get any benefit from type annotations.

Step 1 - Enable the error and override in all the offending files.

#131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428
Approved by: https://github.com/justinchuby, https://github.com/oulgen
2024-07-23 21:50:55 +00:00
blorange-amd
df9b44436a [ROCm] Enable float16/complex32 fft tests on ROCm (#117296)
This PR is to enable float16/complex32 fft tests on ROCm.
Sample results are attached here:
[test_spectral_ops_results.log](https://github.com/pytorch/pytorch/files/13908533/test_spectral_ops_results.log)

test_decomp::TestDecompCUDA::test_comprehensive_fft*
test_decomp::TestDecompCUDA::test_quick_fft*
test_jit_fuser_te::TestNNCOpInfoCUDA::test_nnc_correctness_fft*
test_meta::TestMetaCUDA::test_dispatch_meta_inplace_fft*
test_meta::TestMetaCUDA::test_dispatch_meta_outplace_fft*
test_meta::TestMetaCUDA::test_dispatch_symbolic_meta_inplace_fft*
test_meta::TestMetaCUDA::test_dispatch_symbolic_meta_outplace_fft*
test_meta::TestMetaCUDA::test_meta_inplace_fft*
test_meta::TestMetaCUDA::test_meta_outplace_fft*
test_ops::TestCommonCUDA::test_complex_half_reference_testing_fft*
test_ops::TestCommonCUDA::test_python_ref__refs_fft*
test_ops::TestCommonCUDA::test_python_ref_executor__refs_fft*
test_ops::TestCommonCUDA::test_python_ref_meta__refs*
test_ops::TestCommonCUDA::test_python_ref_torch_fallback__refs_fft*
test_schema_check::TestSchemaCheckModeOpInfoCUDA::test_schema_correctness_fft*
test_spectral_ops::TestFFTCUDA::test_empty_fft__refs_fft*
test_spectral_ops::TestFFTCUDA::test_empty_fft_fft*
test_spectral_ops::TestFFTCUDA::test_fft_half_and_chalf_not_power_of_two_error__refs_fft*
test_spectral_ops::TestFFTCUDA::test_fft_half_and_chalf_not_power_of_two_error_fft*
test_spectral_ops::TestFFTCUDA::test_fft_round_trip_cuda*
test_spectral_ops::TestFFTCUDA::test_fft_type_promotion_cuda*
test_spectral_ops::TestFFTCUDA::test_fftn_round_trip_cuda*
test_spectral_ops::TestFFTCUDA::test_hfftn_cuda_float16
test_spectral_ops::TestFFTCUDA::test_ihfftn_cuda_float16
test_utils::TestDeviceUtilsCUDA::test_device_mode_ops_fft

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117296
Approved by: https://github.com/pruthvistony, https://github.com/malfet
2024-02-13 22:35:32 +00:00
Catherine Lee
4f5785b6b3 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Co-authored-by: Catherine Lee <csl@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 21:07:01 +00:00
PyTorch MergeBot
40ece2e579 Revert "Enable possibly-undefined error code (#118533)"
This reverts commit 4f13f69a45.

Reverted https://github.com/pytorch/pytorch/pull/118533 on behalf of https://github.com/clee2000 due to sorry i'm trying to figure out a codev merge conflict, if this works i'll be back to rebase and merge ([comment](https://github.com/pytorch/pytorch/pull/118533#issuecomment-1917695185))
2024-01-30 19:00:34 +00:00
Edward Z. Yang
4f13f69a45 Enable possibly-undefined error code (#118533)
Fixes https://github.com/pytorch/pytorch/issues/118129

Suppressions automatically added with

```
import re

with open("error_file.txt", "r") as f:
    errors = f.readlines()

error_lines = {}
for error in errors:
    match = re.match(r"(.*):(\d+):\d+: error:.*\[(.*)\]", error)
    if match:
        file_path, line_number, error_type = match.groups()
        if file_path not in error_lines:
            error_lines[file_path] = {}
        error_lines[file_path][int(line_number)] = error_type

for file_path, lines in error_lines.items():
    with open(file_path, "r") as f:
        code = f.readlines()
    for line_number, error_type in sorted(lines.items(), key=lambda x: x[0], reverse=True):
        code[line_number - 1] = code[line_number - 1].rstrip() + f"  # type: ignore[{error_type}]\n"
    with open(file_path, "w") as f:
        f.writelines(code)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118533
Approved by: https://github.com/Skylion007, https://github.com/zou3519
2024-01-30 05:08:10 +00:00
isdanni
2f53085f3f [BE] Enable Ruff's Flake8 PYI030 (#111103)
Enable [unnecessary-literal-union (PYI030)](https://docs.astral.sh/ruff/rules/unnecessary-literal-union/)

Link: #110950
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111103
Approved by: https://github.com/albanD
2023-10-12 13:31:59 +00:00
Peter Bell
d796518485 [refs] Fix size check from #108360 (#109083)
PR #108360 uses the same default `last_dim_size` formula from complex-to-real (C2R) transforms for
complex-to-complex (C2C) and real-to-complex (R2C). However, this is not correct because for C2R
the input is only half the size of the full tensor, which is not the case for C2C and C2R.

This error is mostly benign since `last_dim_size` was only used for the `>= 1` condition which is
almost always met anyway.

For this PR I now use it as the argument to `_apply_norm` which makes it load-bearing for correctness
and so is thoroughly tested now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109083
Approved by: https://github.com/lezcano
2023-09-27 23:59:29 +00:00
ekamiti
0f88d93b10 decomposition spectral ops fixes (#108360)
Fixes https://github.com/pytorch/pytorch/issues/105986, https://github.com/pytorch/pytorch/issues/108204, https://github.com/pytorch/pytorch/issues/108205

Fix all issues flagged when making changes for https://github.com/pytorch/pytorch/pull/107421

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108360
Approved by: https://github.com/ezyang
2023-09-09 04:48:09 +00:00
Kurt Mohler
ee83c646bb Replace _prims_common.check with torch._check* (#103240)
This relands most of the changes from #102219 which were backed out by #103128. However, instead of removing `_prims_common.check`, it adds a warning and a comment mentioning that it will be removed in the future and `torch._check*` should be used instead. As mentioned in https://github.com/pytorch/pytorch/pull/103128#pullrequestreview-1466414415, `_prims_common.check` cannot yet be removed because of some internal usage

Part of #72948

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103240
Approved by: https://github.com/albanD
2023-06-21 00:46:17 +00:00
Ivan Zaitsev
821493715c Back out "Remove check from _prims_common, replace with torch._check* (#102219)", Back out "Forwatd fix for D46427687" (#103128)
Test Plan: revertitparrot

Reviewed By: malfet

Differential Revision: D46506433

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103128
Approved by: https://github.com/malfet
2023-06-07 01:41:41 +00:00
Kurt Mohler
a84bb2709a Remove check from _prims_common, replace with torch._check* (#102219)
Part of #72948

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102219
Approved by: https://github.com/lezcano, https://github.com/albanD
2023-06-03 02:23:21 +00:00
PyTorch MergeBot
a7efa0ce35 Revert "Remove check from _prims_common, replace with torch._check* (#102219)"
This reverts commit fb79d43649.

Reverted https://github.com/pytorch/pytorch/pull/102219 on behalf of https://github.com/malfet due to Broke lint, see https://github.com/pytorch/pytorch/actions/runs/5158949959/jobs/9293466925 ([comment](https://github.com/pytorch/pytorch/pull/102219#issuecomment-1574245414))
2023-06-02 20:00:48 +00:00
Kurt Mohler
fb79d43649 Remove check from _prims_common, replace with torch._check* (#102219)
Part of #72948

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102219
Approved by: https://github.com/lezcano, https://github.com/albanD
2023-06-02 19:13:45 +00:00
Xuehai Pan
69e0bda999 [BE] Import Literal, Protocol, and Final from standard library typing as of Python 3.8+ (#94490)
Changes:

1. `typing_extensions -> typing-extentions` in dependency. Use dash rather than underline to fit the [PEP 503: Normalized Names](https://peps.python.org/pep-0503/#normalized-names) convention.

```python
import re

def normalize(name):
    return re.sub(r"[-_.]+", "-", name).lower()
```

2. Import `Literal`, `Protocal`, and `Final` from standard library as of Python 3.8+
3. Replace `Union[Literal[XXX], Literal[YYY]]` to `Literal[XXX, YYY]`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94490
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-09 19:17:49 +00:00
Nikita Shulga
fd3a7264ae [MPS] Add group_norm[fwd+backward] and mean_var (take 2) (#91190)
Use Prims to implement group_norm, group_norm_backward and mean_var

Use `torch._ops.ops` instead of `torch.ops` in numerous subpackages in
order to be able to make them importable from `torch/backend/mps/__init__.py` as this alias is defined in
15af4b1cee/torch/__init__.py (L1095)
is executed last during init process.

Add `__all__` to `torch/backends/mps/__init__.py` as well as alias all imports as private

Add `TestNNMPS.test_group_norm_backward` that validates no NaNs are generated during the backward pass

Fixes https://github.com/pytorch/pytorch/issues/88331
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91190
Approved by: https://github.com/albanD
2022-12-22 08:54:37 +00:00
PyTorch MergeBot
645eda0a00 Revert "[MPS] Add group_norm[fwd+backward] and mean_var (#91190)"
This reverts commit 371716eb36.

Reverted https://github.com/pytorch/pytorch/pull/91190 on behalf of https://github.com/kit1980 due to Broke test_correct_module_names because of underscore _ops
2022-12-21 19:37:43 +00:00
Nikita Shulga
371716eb36 [MPS] Add group_norm[fwd+backward] and mean_var (#91190)
Use Prims to implement group_norm, group_norm_backward and mean_var

Use `torch._ops.ops` instead of `torch.ops` in numerous subpackages in
order to be able to make them importable from `torch/backend/mps/__init__.py` as this alias is defined in
15af4b1cee/torch/__init__.py (L1095)
is executed last during init process.

Depends on https://github.com/pytorch/pytorch/pull/91203

Fixes https://github.com/pytorch/pytorch/issues/88331
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91190
Approved by: https://github.com/albanD
2022-12-21 17:33:27 +00:00
Nikita Shulga
c8546c930f [BE] Use aten global in torch._refs (#91189)
Similar to pattern used in `torch._decomp`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91189
Approved by: https://github.com/ngimel
2022-12-21 02:28:51 +00:00
Peter Bell
ac19c5be82 FFT: disable dimension wrapping for scalar tensors (#89234)
Fixes #88985

By default, `maybe_wrap_dim` allows through `dim=0` or `dim=-1`
for scalar tensors which leads to an invalid dimension being used to
index into `tensor.sizes()` as in the code sample from the issue.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89234
Approved by: https://github.com/mruberry
2022-11-23 21:55:00 +00:00
Khushi
a3f8495b84 [primTorch fix] use _maybe_convert_to_dtype (#85163)
Fixes #84561

- [x] fix lint tests

cc: @Lezcano!!

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85163
Approved by: https://github.com/lezcano, https://github.com/mruberry
2022-10-31 17:08:55 +00:00
Huy Do
12cb26509a Apply ufmt to torch internal (#81643)
This is a big bang PR, merge conflicts are probably expected and will be addressed at merge.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81643
Approved by: https://github.com/ezyang
2022-07-22 02:19:50 +00:00
Horace He
a5fb41e3d3 Revert "Revert "Refactored prim utils into _prims_utils folder (#81746)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81746
Approved by: https://github.com/anijain2305, https://github.com/Krovatkin
2022-07-20 23:43:57 +00:00
PyTorch MergeBot
e43a02c314 Revert "Refactored prim utils into _prims_utils folder (#81088)"
This reverts commit 80231d0a72.

Reverted https://github.com/pytorch/pytorch/pull/81088 on behalf of https://github.com/jeanschmidt due to breaking internal tests
2022-07-19 19:56:41 +00:00
Horace He
80231d0a72 Refactored prim utils into _prims_utils folder (#81088)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81088
Approved by: https://github.com/ngimel
2022-07-19 03:55:51 +00:00
Peter Bell
443b13fa23 [primTorch] Implement fftshift and ifftshift (#80737)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80737
Approved by: https://github.com/mruberry
2022-07-15 15:13:46 +00:00
Peter Bell
c0ff72b3ab [primTorch] Implement two-dimensional fft transforms (#80736)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80736
Approved by: https://github.com/mruberry
2022-07-15 15:13:46 +00:00
Peter Bell
353180e1bf [primTorch] Implement n-dimensional fft transforms (#80571)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80571
Approved by: https://github.com/mruberry
2022-07-15 15:13:45 +00:00
Peter Bell
bf36d8b987 [primTorch] Implement one-dimensional fft transforms (#80570)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80570
Approved by: https://github.com/mruberry
2022-07-15 15:13:43 +00:00