92 Commits

Author SHA1 Message Date
Yuanyuan Chen
f7ab8a2710 [1/N] Fix ruff warnings (#164333)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164333
Approved by: https://github.com/albanD
2025-10-01 16:48:32 +00:00
Yuanyuan Chen
3cda34ebde [2/N] Apply ruff UP035 check in torch files (#164054)
This is the result of applying the ruff `UP035` check.
`Callable` is imported from `collections.abc` instead of `typing`.
`TypeAlias` is also imported from `typing`.
This PR is the follow-up of #163947.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164054
Approved by: https://github.com/ezyang, https://github.com/Skylion007
2025-09-29 03:35:32 +00:00
Xuehai Pan
db259bd6b8 [BE][12/16] fix typos in torch/ (#156602)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156602
Approved by: https://github.com/justinchuby, https://github.com/albanD
ghstack dependencies: #156318, #156320
2025-07-02 22:55:29 +00:00
Lucas Beyer
8a88c6e85a [nit] fix xavier init doc (#157100)
Remove part of the documentation that is irrelevant and confusing at best, probably a copy-paste mistake:

<img src="https://github.com/user-attachments/assets/77fa5734-5a5a-4f8d-80a5-bc3269668e07" width="500">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157100
Approved by: https://github.com/mikaylagawarecki
2025-06-27 19:13:40 +00:00
Xuehai Pan
596b418391 [BE][PYFMT] migrate PYFMT for {torch,test}/{nn,optim}/** to ruff format (#144548)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144548
Approved by: https://github.com/ezyang
2025-06-14 11:27:04 +00:00
bubuss
31405a69fb [typing] Add missing type annotations to torch.nn.init module (#154504)
## Summary

Adds missing type annotations to `torch.nn.init` and removes `# mypy: allow-untyped-defs` since all functions are now properly typed.

## Changes

- Added missing type annotations to initialization functions in the module.
- Added missing typing imports: `Any`, `Callable`, `Union`
- Removed `# mypy: allow-untyped-defs` comment
- Created Literal types for kaiming initialization mode and nonlinearity.
- Created `__all__`

## Why

Better IDE support, catches type errors earlier, and brings the module up to PyTorch's typing standards. No runtime changes - purely additive typing improvements.

Tested with existing test suite and lintrunner.
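To make the `Literal` and `__all__` items above concrete, here is a minimal sketch of how a mode argument can be typed. The alias name `FanMode` and the helper `select_fan` are assumptions for illustration; the PR description does not list the exact identifiers it introduced.

```python
from typing import Literal, get_args

# Hypothetical alias name; restricts the accepted strings for type checkers.
FanMode = Literal["fan_in", "fan_out"]

# An explicit __all__ documents the public surface of the module.
__all__ = ["select_fan"]


def select_fan(fan_in: int, fan_out: int, mode: FanMode = "fan_in") -> int:
    """Return the fan value selected by ``mode``."""
    # Literal is not enforced at runtime, so a runtime check is still needed.
    if mode not in get_args(FanMode):
        raise ValueError(f"Mode {mode} not supported, use one of {get_args(FanMode)}")
    return fan_in if mode == "fan_in" else fan_out
```

A type checker such as mypy or pyright would flag `select_fan(3, 5, "fan_avg")` statically, while the runtime check covers untyped callers.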
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154504
Approved by: https://github.com/Skylion007
2025-06-03 17:33:32 +00:00
cora-codes
40142978d7 Add type annotation to orthogonal_ (#154927)
Trivial change, but I want pyright to stop yelling at me
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154927
Approved by: https://github.com/cyyever, https://github.com/Skylion007
2025-06-03 17:00:02 +00:00
Andrew Gu
2b43fab555 [DTensor] Added naive support for nn.init.orthogonal_ (#132104)
Try to unblock https://github.com/pytorch/pytorch/issues/131991

- `nn.init.orthogonal_` uses `tensor.new`, which is the legacy factory function. We change this to `tensor.new_empty` (empty is okay since it will be immediately followed by `.normal_()` to fill the tensor) so that it preserves `DTensor`-ness.
- `nn.init.orthogonal_` uses QR decomposition (`aten.linalg_qr.default`) and `torch.diag` (calling into `aten.diagonal_copy.default`). For simplicity, we use naive replicate strategies for now. `aten.diagonal_copy.default` could do something more sophisticated for sharded inputs, but I would rather defer that to later due to the complexity. For `orthogonal_` support specifically, since the result of the QR decomp will be replicated, the input to `aten.diagonal_copy.default` will be replicated.
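The steps named above (`new_empty` followed by `.normal_()`, QR decomposition, then `torch.diag`) can be sketched roughly as follows. This is a simplified illustration under the assumption of a plain `torch.Tensor` input, not the code from the PR; the function name `orthogonal_sketch` is hypothetical.

```python
import torch


def orthogonal_sketch(tensor: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
    """Fill ``tensor`` in place with a (semi-)orthogonal matrix scaled by ``gain``."""
    if tensor.dim() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")
    if tensor.numel() == 0:
        return tensor  # nothing to initialize

    rows = tensor.size(0)
    cols = tensor.numel() // rows
    # new_empty keeps the dtype/device (and tensor subclass) of the input;
    # the uninitialized memory is overwritten immediately by normal_().
    flattened = tensor.new_empty((rows, cols)).normal_(0, 1)
    if rows < cols:
        flattened = flattened.t()

    # QR decomposition; Q has orthonormal columns.
    q, r = torch.linalg.qr(flattened)
    # Multiply each column by the sign of the matching diagonal entry of R
    # so the result does not depend on the sign convention of the QR routine.
    d = torch.diag(r, 0)
    q = q * d.sign()
    if rows < cols:
        q = q.t()

    with torch.no_grad():
        tensor.view_as(q).copy_(q)
        tensor.mul_(gain)
    return tensor
```

For a `4 x 6` input, the rows of the result are orthonormal, so `w @ w.T` is close to the identity.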

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132104
Approved by: https://github.com/albanD, https://github.com/wanchaol
2024-07-30 21:55:09 +00:00
sradc
451fc029fe docs: note transposed weight initialisations (#130122)
Fixes #129834

Co-authored-by: mikaylagawarecki <mikaylagawarecki@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130122
Approved by: https://github.com/mikaylagawarecki
2024-07-19 15:23:03 +00:00
Xuehai Pan
f85d1e845a [BE] enable UFMT for torch/nn/*.py (#128593)
Part of #123062

- #123062
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128593
Approved by: https://github.com/mikaylagawarecki
2024-06-23 16:05:13 +00:00
PyTorch MergeBot
aace8ffc00 Revert "[BE] enable UFMT for torch/nn/*.py (#128593)"
This reverts commit a87d82abd7.

Reverted https://github.com/pytorch/pytorch/pull/128593 on behalf of https://github.com/fbgheith due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/128593#issuecomment-2181562604))
2024-06-20 21:09:44 +00:00
Xuehai Pan
a87d82abd7 [BE] enable UFMT for torch/nn/*.py (#128593)
Part of #123062

- #123062
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128593
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #128596, #128594, #128592
2024-06-17 16:29:29 +00:00
Aaron Orenstein
27f9d3b0a1 Flip default value for mypy disallow_untyped_defs [8/11] (#127845)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127845
Approved by: https://github.com/oulgen
ghstack dependencies: #127842, #127843, #127844
2024-06-08 18:49:56 +00:00
Xuehai Pan
67ef2683d9 [BE] wrap deprecated function/class with typing_extensions.deprecated (#127689)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

Resolves #126888

- #126888

This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
2024-06-02 12:30:43 +00:00
PyTorch MergeBot
033e733021 Revert "[BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)"
This reverts commit 749a132fb0.

Reverted https://github.com/pytorch/pytorch/pull/126898 on behalf of https://github.com/fbgheith due to switching typing-extensions=4.3.0 to 4.9.0 causes internal failure ([comment](https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456))
2024-05-31 19:47:24 +00:00
Xuehai Pan
749a132fb0 [BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.
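The fallback path described here, passing `category=FutureWarning` to `warnings.warn`, can be sketched with the standard library alone. The decorator `deprecated_sketch` and the function `old_fill` are hypothetical names; the PR itself prefers `typing_extensions.deprecated` where possible. `FutureWarning` is shown to end users by default, whereas `DeprecationWarning` is ignored by default outside `__main__`.

```python
import functools
import warnings


def deprecated_sketch(message: str):
    """Hypothetical decorator: emit a FutureWarning each time the function is called."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # stacklevel=2 attributes the warning to the caller, not the wrapper.
            warnings.warn(message, category=FutureWarning, stacklevel=2)
            return fn(*args, **kwargs)

        return wrapper

    return decorator


@deprecated_sketch("`old_fill` is deprecated in favor of `new_fill`.")
def old_fill(values, x):
    return [x for _ in values]


# Capture the warning to show which category was raised.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    filled = old_fill([1, 2, 3], 0)
```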

Resolves #126888

- #126888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
2024-05-29 12:09:27 +00:00
Antonio Kim
7fc292930c Add support for torch.Generator type in TorchScript (#110413)
- Add support for `torch.Generator` type in TorchScript
- Add `generator` args to all `torch.nn.init` functions that call `uniform_` or `normal_`
- Add support for `torch.Generator` in LTC's TorchScript backend (CC: @wconstab)

CC: @eellison @davidberard98 @GlebKazantaev @behzad-a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110413
Approved by: https://github.com/wconstab, https://github.com/albanD, https://github.com/glebk-cerebras, https://github.com/davidberard98
2023-11-21 23:07:21 +00:00
PyTorch MergeBot
252e68a83b Revert "Add support for torch.Generator type in TorchScript (#110413)"
This reverts commit 54493fe8c4.

Reverted https://github.com/pytorch/pytorch/pull/110413 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it is, unfortunately, still breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/110413#issuecomment-1811625557))
2023-11-15 00:51:23 +00:00
Antonio Kim
54493fe8c4 Add support for torch.Generator type in TorchScript (#110413)
- Add support for `torch.Generator` type in TorchScript
- Add `generator` args to all `torch.nn.init` functions that call `uniform_` or `normal_`
- Add support for `torch.Generator` in LTC's TorchScript backend (CC: @wconstab)

CC: @eellison @davidberard98 @GlebKazantaev @behzad-a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110413
Approved by: https://github.com/wconstab, https://github.com/albanD, https://github.com/glebk-cerebras, https://github.com/davidberard98
2023-11-13 23:18:14 +00:00
PyTorch MergeBot
9a28a7b498 Revert "Add support for torch.Generator type in TorchScript (#110413)"
This reverts commit 27e31ab6e8.

Reverted https://github.com/pytorch/pytorch/pull/110413 on behalf of https://github.com/PaliC due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/110413#issuecomment-1799003164))
2023-11-07 15:53:32 +00:00
Antonio Kim
27e31ab6e8 Add support for torch.Generator type in TorchScript (#110413)
- Add support for `torch.Generator` type in TorchScript
- Add `generator` args to all `torch.nn.init` functions that call `uniform_` or `normal_`
- Add support for `torch.Generator` in LTC's TorchScript backend (CC: @wconstab)

CC: @eellison @davidberard98 @GlebKazantaev @behzad-a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110413
Approved by: https://github.com/wconstab, https://github.com/albanD, https://github.com/glebk-cerebras, https://github.com/davidberard98
2023-11-06 21:27:02 +00:00
stan
57a3af900e Add suggested changes to init.py (#112864)
A follow-up of PR #112617  on issue #112596

Added suggested changes from the review.
-  More specific on the type of uniform and normal distribution used.

```py
def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor:
    r"""Fill the input `Tensor` with values using a Xavier uniform distribution.

    The method is described in `Understanding the difficulty of training...
"""
```

```py
def kaiming_normal_(
    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
):
    r"""Fill the input `Tensor` with values using a Kaiming normal distribution.

    The method is described in `Delving deep into rectifiers: Surpassing...
"""
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112864
Approved by: https://github.com/kit1980
2023-11-03 22:46:48 +00:00
stanleyedward
12dab00173 Fix Docstring errors in init.py (#112617)
Fixes #112596

Fix docstring errors in init.py

### Before the change -> 38 errors
```
╭─user@pc ~/Path/to/pytorch  ‹fix/docstring_init›
╰─➤  pydocstyle torch/nn/init.py --count                                                                                                                                             127 ↵
torch/nn/init.py:1 at module level:
        D100: Missing docstring in public module
torch/nn/init.py:68 in public function `calculate_gain`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:123 in public function `uniform_`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:123 in public function `uniform_`:
        D400: First line should end with a period (not 'm')
torch/nn/init.py:123 in public function `uniform_`:
        D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:141 in public function `normal_`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:141 in public function `normal_`:
        D400: First line should end with a period (not 'l')
torch/nn/init.py:141 in public function `normal_`:
        D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:165 in public function `trunc_normal_`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:165 in public function `trunc_normal_`:
        D400: First line should end with a period (not 'd')
torch/nn/init.py:165 in public function `trunc_normal_`:
        D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:187 in public function `constant_`:
        D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:203 in public function `ones_`:
        D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:216 in public function `zeros_`:
        D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:229 in public function `eye_`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:229 in public function `eye_`:
        D400: First line should end with a period (not 'y')
torch/nn/init.py:229 in public function `eye_`:
        D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:249 in public function `dirac_`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:249 in public function `dirac_`:
        D400: First line should end with a period (not 'c')
torch/nn/init.py:249 in public function `dirac_`:
        D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:311 in public function `xavier_uniform_`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:311 in public function `xavier_uniform_`:
        D400: First line should end with a period (not 'd')
torch/nn/init.py:311 in public function `xavier_uniform_`:
        D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:338 in public function `xavier_normal_`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:338 in public function `xavier_normal_`:
        D400: First line should end with a period (not 'd')
torch/nn/init.py:338 in public function `xavier_normal_`:
        D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:376 in public function `kaiming_uniform_`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:376 in public function `kaiming_uniform_`:
        D400: First line should end with a period (not 'd')
torch/nn/init.py:376 in public function `kaiming_uniform_`:
        D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:425 in public function `kaiming_normal_`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:425 in public function `kaiming_normal_`:
        D400: First line should end with a period (not 'd')
torch/nn/init.py:425 in public function `kaiming_normal_`:
        D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:462 in public function `orthogonal_`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:462 in public function `orthogonal_`:
        D400: First line should end with a period (not 's')
torch/nn/init.py:462 in public function `orthogonal_`:
        D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
torch/nn/init.py:507 in public function `sparse_`:
        D205: 1 blank line required between summary line and description (found 0)
torch/nn/init.py:507 in public function `sparse_`:
        D400: First line should end with a period (not 'e')
torch/nn/init.py:507 in public function `sparse_`:
        D401: First line should be in imperative mood (perhaps 'Fill', not 'Fills')
38
```

### After the change -> 0 errors
```
╭─user@pc ~/Path/to/pytorch  ‹fix/docstring_init*›
╰─➤  pydocstyle torch/nn/init.py --count
0
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112617
Approved by: https://github.com/mikaylagawarecki
2023-11-02 23:42:17 +00:00
Aaron Gokaslan
660e8060ad [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.

I also enabled RUF017 which looks for accidental quadratic list summation. Luckily, seems like there are no instances of it in our codebase, so enabling it so that it stays like that. :)
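For context, the pattern RUF017 looks for and one linear-time alternative can be sketched as below. The variable names are illustrative, and `itertools.chain` is just one possible replacement, not necessarily the fix ruff itself suggests.

```python
import itertools

chunks = [[1, 2], [3], [4, 5, 6]]

# Flagged pattern: each `+` copies the accumulated list, so total work is quadratic.
flat_quadratic = sum(chunks, [])

# Linear alternative: iterate over every element exactly once.
flat_linear = list(itertools.chain.from_iterable(chunks))
```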

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-22 23:16:38 +00:00
PyTorch MergeBot
d59a6864fb Revert "[BE]: Update ruff to 0.285 (#107519)"
This reverts commit 88ab3e4322.

Reverted https://github.com/pytorch/pytorch/pull/107519 on behalf of https://github.com/ZainRizvi due to Sorry, but this PR breaks internal tests. @ezyang, can you please help them get unblocked? It seems like one of the strings was prob accidentally modified ([comment](https://github.com/pytorch/pytorch/pull/107519#issuecomment-1688833480))
2023-08-22 19:53:32 +00:00
Aaron Gokaslan
88ab3e4322 [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.

I also enabled RUF017 which looks for accidental quadratic list summation. Luckily, seems like there are no instances of it in our codebase, so enabling it so that it stays like that. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-20 01:36:18 +00:00
Justin Chu
4cc1745b13 [BE] f-stringify torch/ and scripts (#105538)
This PR is a follow up on the pyupgrade series to convert more strings to use f-strings using `flynt`.

- https://docs.python.org/3/reference/lexical_analysis.html#f-strings
- https://pypi.org/project/flynt/

Command used:

```
flynt torch/ -ll 120
flynt scripts/ -ll 120
flynt tools/ -ll 120
```

and excluded `collect_env.py`
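The kind of rewrite this conversion performs looks roughly like the following. The strings and variable names are made up for illustration and are not taken from the diff.

```python
name, count = "init.py", 38

# Older styles that get rewritten.
old_percent = "%s has %d errors" % (name, count)
old_format = "{} has {} errors".format(name, count)

# Equivalent f-string form.
new_fstring = f"{name} has {count} errors"
```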

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105538
Approved by: https://github.com/ezyang, https://github.com/malfet
2023-07-21 19:35:24 +00:00
Justin Chu
79c5e33349 [BE] Enable ruff's UP rules and autoformat nn/ mps/ and torch/ (#105436)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105436
Approved by: https://github.com/malfet, https://github.com/albanD
2023-07-21 07:38:46 +00:00
ts
74dc2a53f6 Thread generator through trunc_normal_ (#100810)
This will solve @albertz's issue as described in #98200, threading the generator argument through the trunc_normal_ function. I'm still working on #99796 (and won't let it stall out), but this fix doesn't trigger any JIT issues, so I think it might be helpful to get it merged now.

Would be happy to iterate on this if there are any issues.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100810
Approved by: https://github.com/Skylion007, https://github.com/albanD
2023-05-12 01:04:08 +00:00
loganthomas
c848a777e8 DOC: Various typo fixes (#97095)
Various typos found while browsing documentation/source code.

Thank you for a wonderful deep-learning library!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97095
Approved by: https://github.com/mikaylagawarecki, https://github.com/kit1980
2023-03-20 20:46:04 +00:00
joncrall
b136f3f310 More doctest refinements. (#83317)
Follow up to #82797

Now that the doctests themselves are in a better state, we should be able to enable xdoctest on the CI so they stay that way.

@ezyang @vadimkantorov
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83317
Approved by: https://github.com/ezyang
2022-08-22 20:07:26 +00:00
joncrall
4618371da5 Integrate xdoctest - Rebased (#82797)
This is a new version of #15648 based on the latest master branch.

Unlike the previous PR where I fixed a lot of the doctests in addition to integrating xdoctest, I'm going to reduce the scope here. I'm simply going to integrate xdoctest, and then I'm going to mark all of the failing tests as "SKIP". This will let xdoctest run on the dashboards, provide some value, and still let the dashboards pass. I'll leave fixing the doctests themselves to another PR.

In my initial commit, I do the bare minimum to get something running with failing dashboards. The few tests that I marked as skip are causing segfaults. Running xdoctest results in 293 failed, 201 passed tests. The next commits will be to disable those tests. (unfortunately I don't have a tool that will insert the `#xdoctest: +SKIP` directive over every failing test, so I'm going to do this mostly manually.)
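For readers unfamiliar with the directive, a docstring example marked to be skipped might look like this. The function `halve` is a hypothetical stand-in, not one of the doctests touched by the PR.

```python
def halve(x: float) -> float:
    """Return half of ``x``.

    Example:
        >>> # xdoctest: +SKIP
        >>> halve(3.0)
        1.5
    """
    return x / 2
```

The directive is an ordinary comment inside the example, so it has no effect on the function itself; it only tells the xdoctest runner not to execute the lines that follow it.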

Fixes https://github.com/pytorch/pytorch/issues/71105

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82797
Approved by: https://github.com/ezyang
2022-08-12 02:08:01 +00:00
ProGamerGov
8def154e00 Fix multiple docstring type mistakes (#82474)
### Description

* Docstrings using `(tuple of ints)` shows up as `(tuple of python:ints)`, so I fixed them by making the `int` no longer plural. Example: https://pytorch.org/docs/stable/generated/torch.permute.html#torch.permute
* A docstring type in JIT had one of its types incorrectly highlighted as code. Example: https://pytorch.org/docs/stable/generated/torch.jit.script.html#torch.jit.script
* I found some docstring type usages of `string` that had not yet been converted to `str` after #82410
* Some docstrings incorrectly listed their defaults inside the docstring types.
* I also found a docstring that was missing its type

### Testing
No testing should be required.

---

In the developer guidelines, there should probably be standards listed for the docstring types.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82474
Approved by: https://github.com/albanD
2022-07-29 17:45:37 +00:00
Adam J. Stewart
dfde877c0b Add type hints for a few random functions/classes
Adds type hints for a few functions/classes that we use in [TorchGeo](https://github.com/microsoft/torchgeo).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74171
Approved by: https://github.com/jbschlosser, https://github.com/anjali411
2022-05-04 13:53:00 +00:00
PyTorch MergeBot
80fe96c860 Revert "Add type hints for a few random functions/classes"
This reverts commit cdb40eb528.

Reverted https://github.com/pytorch/pytorch/pull/74171 on behalf of https://github.com/zengk95
2022-04-21 21:07:15 +00:00
Adam J. Stewart
cdb40eb528 Add type hints for a few random functions/classes
Adds type hints for a few functions/classes that we use in [TorchGeo](https://github.com/microsoft/torchgeo).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74171
Approved by: https://github.com/jbschlosser
2022-04-21 20:09:40 +00:00
Jake Tae
b5b296c4cf Fix: Make nn.init.orthogonal_ no-op for empty input
Fixes #73503.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75553
Approved by: https://github.com/albanD
2022-04-13 15:48:31 +00:00
Brian Hirsh
f87f753bb9 avoiding adding some functions to the public python API before 1.11 release (#72543)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72543

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D34085724

Pulled By: bdhirsh

fbshipit-source-id: 941d5a90a6fa5328268d623e0e2b01577e4132ca
(cherry picked from commit 6676a0c79a)
2022-02-14 19:49:01 +00:00
Pritam Damania
8c505bbc86 Make ShardedTensor ctor more inline with torch.Tensor ctor (#72164)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72164

torch.Tensor ctor creates an empty tensor and this PR makes
ShardedTensor on par with that.

In particular we remove TensorInitParams and instead always create an empty
tensor and then fill it in for things like ones, zeros, full etc. This is
in line with torch.ones etc. as well since even for those APIs we first create
an empty tensor and then fill it out.
ghstack-source-id: 148318045

Test Plan: waitforbuildbot

Reviewed By: wanchaol

Differential Revision: D33934603

fbshipit-source-id: 5655bbd726f29e74600ebe9f33f9dc5952b528f4
(cherry picked from commit 78b301c78c)
2022-02-04 01:16:25 +00:00
Pritam Damania
b199e3c842 Provide functionality to write custom ShardedTensor ops. (#69874)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69874

We have a handful of ops supported for ShardedTensor via
``__torch_function__`` dispatch. However, we currently can't cover all torch
operators and having a way for users to extend this functionality will make
this functionality much more general.

In this PR, I've introduced a custom_sharded_op decorator which can be used to
register a custom sharded op implementation.
ghstack-source-id: 145841141

Test Plan: waitforbuildbot

Reviewed By: wanchaol

Differential Revision: D33078587

fbshipit-source-id: 5936b7ac25582e613653c19afa559219719ee54b
2021-12-16 12:40:13 -08:00
Bo Wang
3596e13d45 Add torch.nn.init.normal_ and torch.nn.init.kaiming_uniform_ ops to ShardedTensor (#67057)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67057

Extend ShardedTensor with torch.nn.init.[normal_, and kaiming_uniform_] ops
Follow up from https://github.com/pytorch/pytorch/pull/63997

Test Plan:
a) Unit Test
(pytorch) ... $ python test/distributed/_sharded_tensor/ops/test_init.py TestShardedTensorNNInit --v

or b) Manual run: Instruction here: https://docs.google.com/document/d/1_m1Hdo5w51-hhPlZ_F8Y6PIWrN7UgJZqiSpARYvhsaE/edit#
s/uniform_/normal_ or kaiming_uniform_

Imported from OSS

Reviewed By: pritamdamania87

Differential Revision: D31845654

fbshipit-source-id: e7aedc0972539da59f7b84bbbf617caf6b206d52
2021-10-25 19:14:30 -07:00
Bo Wang
b6df043f1f Add torch.nn.init.uniform_ operator to ShardedTensor. (#63997)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63997

Use torch_function to extend torch.nn.init.uniform_
The init is done in SPMD fashion. Note that ideally we want to aggregate sharded tensors into a global tensor, init it and reshard. It's fine to run it SPMD since the uniform samples are i.i.d. (independent and identically distributed).
Also enable unit test for test_linear.py for OSS test

Test Plan:
a) Unit Test
(pytorch) ... $ python test/distributed/_sharded_tensor/ops/test_init.py TestShardedTensorNNInit --v
(pytorch) ... $ python test/distributed/_sharded_tensor/ops/test_linear.py --v (before this change, running this command was a no-op)

or b) Manual run: Instruction here: https://docs.google.com/document/d/1_m1Hdo5w51-hhPlZ_F8Y6PIWrN7UgJZqiSpARYvhsaE/edit#

Imported from OSS

Reviewed By: pritamdamania87, anjali411

Differential Revision: D30563017

fbshipit-source-id: d1859f7682235bcb44515efc69ca92bc5e34fce1
2021-10-21 00:17:13 -07:00
lezcano
24087d07ca Deprecate QR (#57745)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57745

Reviewed By: bdhirsh

Differential Revision: D28318164

Pulled By: mruberry

fbshipit-source-id: b8e3cb9d7ab33f30c8653ec39f932a8af8bd2a50
2021-05-10 22:56:37 -07:00
beningodfrey4
df1dfd879e Fix errors when initializing Linear with 0 in_features (#56505)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/48152

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56505

Reviewed By: malfet

Differential Revision: D27919590

Pulled By: jbschlosser

fbshipit-source-id: 462ca280051f63c31ff588c38a9e436116c0f336
2021-04-21 20:42:32 -07:00
mrTsjolder
a7c7fc96ff Add doc warnings for default SELU gain (#54057)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/24991 and provides the alternative solution suggested in https://github.com/pytorch/pytorch/issues/53694. Also related to https://github.com/pytorch/pytorch/issues/54055

Attempt to make people aware of the difference between paper and implementation of SELU gain.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/54057

Reviewed By: ailzhang

Differential Revision: D27292060

Pulled By: jbschlosser

fbshipit-source-id: e0e303595e6a7d05d11dfb68735e1839f55987a2
2021-03-25 11:21:02 -07:00
Jean Kossaifi
70a43425e0 Simplify init._calculate_fan_in_and_fan_out (#53522)
Summary:
This uses the shape of the tensor instead of directly indexing it. This is useful when extending PyTorch's tensor class, e.g. for lazy access. Since the `init` sub-module doesn't check for `torch_function`, it is not possible to override its functions. Explicitly indexing the tensor will force a call to tensor() and reconstruct the full tensor/explicitly access the elements. Simply using the shape avoids that.

Fixes https://github.com/pytorch/pytorch/issues/53540
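The shape-only computation described in this commit can be sketched without touching any tensor data. The function below operates on a plain shape tuple and is an illustrative approximation, not the code from the PR; the name `fan_in_and_fan_out` is an assumption.

```python
import math
from collections.abc import Sequence


def fan_in_and_fan_out(shape: Sequence[int]) -> tuple[int, int]:
    """Compute (fan_in, fan_out) from a weight shape laid out as (out, in, *kernel)."""
    if len(shape) < 2:
        raise ValueError(
            "Fan in and fan out can not be computed for fewer than 2 dimensions"
        )
    num_output_fmaps, num_input_fmaps = shape[0], shape[1]
    # Product of the trailing (kernel) dimensions; math.prod(()) is 1 for 2-D weights.
    receptive_field_size = math.prod(shape[2:])
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out
```

Since only the shape is read, a lazily materialized tensor subclass never has to reconstruct its elements to answer this query.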

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53522

Reviewed By: anjali411

Differential Revision: D26947794

Pulled By: jbschlosser

fbshipit-source-id: 80cd65efed16383f21363cee2eb404c9bc05971c
2021-03-10 11:57:17 -08:00
AJ San Joaquin
e9b369c25f Add SELU Activation to calculate_gain (#50664)
Summary:
Fixes [#24991](https://github.com/pytorch/pytorch/issues/24991)

I used a value of 0.75 as suggested in the forums by Thomas: https://discuss.pytorch.org/t/calculate-gain-tanh/20854/6

I verified that the value keeps the gradient stable for a 100-layer network.

Code to reproduce (from [jpeg729](https://discuss.pytorch.org/t/calculate-gain-tanh/20854/4)):
```python
import torch
import torch.nn.functional as F

# Track activation and gradient magnitudes through 100 SELU layers.
a = torch.randn(1000, 1000, requires_grad=True)
b = a
print(f"in: {a.std().item():.4f}")
for i in range(100):
    l = torch.nn.Linear(1000, 1000, bias=False)
    torch.nn.init.xavier_normal_(l.weight, torch.nn.init.calculate_gain("selu"))
    b = F.selu(l(b))
    if i % 10 == 0:
        print(f"out: {b.std().item():.4f}", end=" ")
        a.grad = None
        b.sum().backward(retain_graph=True)
        print(f"grad: {a.grad.abs().mean().item():.4f}")
```
Output:
```
in: 1.0008
out: 0.7968 grad: 0.6509
out: 0.3127 grad: 0.2760
out: 0.2404 grad: 0.2337
out: 0.2062 grad: 0.2039
out: 0.2056 grad: 0.1795
out: 0.2044 grad: 0.1977
out: 0.2005 grad: 0.2045
out: 0.2042 grad: 0.2273
out: 0.1944 grad: 0.2034
out: 0.2085 grad: 0.2464
```
I included the necessary documentation change, and it passes the _test_calculate_gain_nonlinear_ unittest.
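To show how the chosen gain feeds into the initialization used in the reproduction script, here is a small sketch of the Xavier normal standard deviation, `gain * sqrt(2 / (fan_in + fan_out))`. The helper name and the constant `SELU_GAIN` are illustrative; only the value 0.75 comes from this PR.

```python
import math


def xavier_normal_std(fan_in: int, fan_out: int, gain: float = 1.0) -> float:
    """Standard deviation used by Xavier (Glorot) normal initialization."""
    return gain * math.sqrt(2.0 / (fan_in + fan_out))


SELU_GAIN = 0.75  # value adopted in this PR for "selu"

# Same layer size as the 1000x1000 Linear layers in the script above.
std_selu = xavier_normal_std(1000, 1000, SELU_GAIN)
std_linear = xavier_normal_std(1000, 1000)
```

The gain simply rescales the standard deviation, so the SELU weights here start at three quarters of the spread a gain of 1 would give.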

Pull Request resolved: https://github.com/pytorch/pytorch/pull/50664

Reviewed By: mruberry

Differential Revision: D25942217

Pulled By: ngimel

fbshipit-source-id: 29ff1be25713484fa7c516df71b12fdaecfb9af8
2021-01-18 23:01:18 -08:00
Richard Barnes
b89827b73f Drop unused imports (#49972)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49972

From
```
./python/libcst/libcst codemod remove_unused_imports.RemoveUnusedImportsWithGlean --no-format caffe2/
```

Test Plan: Standard sandcastle tests

Reviewed By: xush6528

Differential Revision: D25727352

fbshipit-source-id: 6b90717e161aeb1da8df30e67d586101d35d7d5f
2021-01-13 12:26:17 -08:00
Richard Barnes
f83d57f99e [Don't review] Clean up type annotations in caffe2/torch/nn (#50079)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50079

Test Plan: Sandcastle tests

Reviewed By: xush6528

Differential Revision: D25718694

fbshipit-source-id: f535fb879bcd4cb4ea715adfd90bbffa3fcc1150
2021-01-07 15:39:20 -08:00
Xiang Gao
20ac736200 Remove py2 compatible future imports (#44735)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44735

Reviewed By: mruberry

Differential Revision: D23731306

Pulled By: ezyang

fbshipit-source-id: 0ba009a99e475ddbe22981be8ac636f8a1c8b02f
2020-09-16 12:55:57 -07:00