Commit Graph

134 Commits

Author SHA1 Message Date
Aaron Gokaslan
b7d08defe9 [BE]: Type previously untyped decorators (#153726)
This fixes decorator typing, which unmasks a lot of typing issues in the codebase

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153726
Approved by: https://github.com/albanD
2025-05-21 15:56:19 +00:00
Aaron Gokaslan
ffd49d538e [BE][Ez]: Improve typing in torch/modules/container.py (#153728)
Adds some missing type annotations

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153728
Approved by: https://github.com/albanD
2025-05-21 07:15:00 +00:00
zeshengzong
d457b4492d Optimize Sequential methods description (#147304)
Fixes #146892

Add method descriptions and examples to the [`Sequential` documentation](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html)
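
For reference, a small usage sketch (not taken from the PR) of the kinds of methods the new docs describe:

```python
import torch.nn as nn

seq = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
seq.append(nn.Linear(8, 2))     # add a module at the end
seq.insert(1, nn.Dropout(0.1))  # insert a module at an index
seq.extend([nn.Sigmoid()])      # append several modules at once
print(seq)
```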

## Test Result

### Before

![image](https://github.com/user-attachments/assets/3121a06f-02ed-4362-ad0a-f055bb43d469)

### After

![image](https://github.com/user-attachments/assets/66f6bb55-5298-4062-8f7f-7a7f4c1e16d9)
![image](https://github.com/user-attachments/assets/a5275a4c-4214-4518-b7a2-dff21954f368)
![image](https://github.com/user-attachments/assets/9c40d1fb-114a-4d14-a3c4-1143a131660e)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147304
Approved by: https://github.com/mikaylagawarecki
2025-05-02 19:18:58 +00:00
Aaron Gokaslan
cccfc146fe [BE][Easy]: Simplify ModuleList reversed method (#151673)
Removes unnecessary `list` calls now that we are on Python 3.9 and dict `KeysView` objects implement `__reversed__` directly.
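
The underlying language feature, for context (illustrative snippet, not code from the PR):

```python
# Since Python 3.8, dict views support reversed() directly,
# so wrapping the keys in list() first is unnecessary.
modules = {"0": "Linear", "1": "ReLU", "2": "Conv2d"}
assert list(reversed(modules.keys())) == ["2", "1", "0"]
assert list(reversed(modules.values())) == ["Conv2d", "ReLU", "Linear"]
```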
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151673
Approved by: https://github.com/albanD
2025-04-18 18:39:32 +00:00
zeshengzong
1a48382a4c [Easy] Optimize container.py typing (#151653)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151653
Approved by: https://github.com/albanD
2025-04-18 17:33:43 +00:00
cyy
d87aad6877 [5/N] Apply Ruff fixes and pyupgrade to Python 3.9 (#144205)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144205
Approved by: https://github.com/albanD
2025-01-15 04:00:47 +00:00
Matthew Hoffman
7ea8374c0e nn.ModuleList.__getitem__ overloads (#132834)
Overloads so that you can get more specific type info based on how you are indexing.

```python
from torch import nn

module_list = nn.ModuleList(32 * [nn.Linear(2, 2)])

# before:
reveal_type(module_list[0])  # Type of "module_list[0]" is "Module | ModuleList"
reveal_type(module_list[:1])  # Type of "module_list[: 1]" is "Module | ModuleList"

# now:
reveal_type(module_list[0])  # Type of "module_list[0]" is "Module"
reveal_type(module_list[:1])  # Type of "module_list[: 1]" is "ModuleList"
```
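A simplified sketch of the overloads that produce this behavior (the class name here is hypothetical; the real stubs live on `nn.ModuleList`):

```python
from typing import Union, overload

from torch.nn import Module

class ModuleListSketch(Module):
    @overload
    def __getitem__(self, idx: int) -> Module: ...

    @overload
    def __getitem__(self, idx: slice) -> "ModuleListSketch": ...

    def __getitem__(self, idx: Union[int, slice]) -> Union[Module, "ModuleListSketch"]:
        raise NotImplementedError  # sketch only; the real method indexes self._modules
```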
Co-authored-by: Skylion007 <Skylion007@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132834
Approved by: https://github.com/Skylion007, https://github.com/albanD
2024-08-07 19:25:23 +00:00
Oguz Ulgen
72d2dba992 Add None return type to init (#132335)
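The change is mechanical; for illustration:

```python
from torch import nn

class MyModule(nn.Module):
    def __init__(self) -> None:  # explicit None return type, as this PR adds
        super().__init__()
```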
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132335
Approved by: https://github.com/albanD
2024-08-01 15:26:45 +00:00
PyTorch MergeBot
609447a626 Revert "[BE] typing for decorators - _jit_internal (#131573)"
This reverts commit f0f20f7e97.

Reverted https://github.com/pytorch/pytorch/pull/131573 on behalf of https://github.com/clee2000 due to breaking lint internally D60265575 ([comment](https://github.com/pytorch/pytorch/pull/131572#issuecomment-2254328359))
2024-07-28 03:29:32 +00:00
Aaron Orenstein
f0f20f7e97 [BE] typing for decorators - _jit_internal (#131573)
See #131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131573
Approved by: https://github.com/oulgen, https://github.com/zou3519
ghstack dependencies: #131568, #131569, #131570, #131571, #131572
2024-07-25 22:24:19 +00:00
Aaron Orenstein
5a0068cc69 [BE] mypy: disallow untyped decorators (#131428)
Untyped decorators strip the types from the function they decorate, so even if the underlying function is fully typed, callers get no benefit from its annotations.
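
For context, a minimal sketch (not from the PR) of a typed decorator that avoids this problem:

```python
from typing import Callable, TypeVar
from typing_extensions import ParamSpec

P = ParamSpec("P")
R = TypeVar("R")

def logged(fn: Callable[P, R]) -> Callable[P, R]:
    # ParamSpec preserves the wrapped signature, so mypy still sees
    # the decorated function's real parameter and return types.
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        print(f"calling {fn.__name__}")
        return fn(*args, **kwargs)
    return wrapper

@logged
def add(x: int, y: int) -> int:
    return x + y

print(add(1, 2))  # still typed as int; an untyped decorator would erase this to Any
```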

Step 1 - Enable the error and override in all the offending files.

#131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428
Approved by: https://github.com/justinchuby, https://github.com/oulgen
2024-07-23 21:50:55 +00:00
Xuehai Pan
62ccf6d7cd [BE] enable UFMT for torch/nn/modules (#128594)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128594
Approved by: https://github.com/mikaylagawarecki
2024-06-23 05:37:57 +00:00
Mashrur Morshed
9103b40a47 Fix small typo in docstring in ParameterList (#129193)
In the docstring of `nn.ParameterList`, ParameterDict.append/extend was being used, which is most likely a typo.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129193
Approved by: https://github.com/mikaylagawarecki
2024-06-21 20:53:52 +00:00
PyTorch MergeBot
d4022b4658 Revert "[BE] enable UFMT for torch/nn/modules (#128594)"
This reverts commit 95ac2d6482.

Reverted https://github.com/pytorch/pytorch/pull/128594 on behalf of https://github.com/fbgheith due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/128594#issuecomment-2181788935))
2024-06-21 00:50:08 +00:00
Xuehai Pan
95ac2d6482 [BE] enable UFMT for torch/nn/modules (#128594)
Part of #123062

- #123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128594
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #128596
2024-06-17 16:29:25 +00:00
Aaron Orenstein
27f9d3b0a1 Flip default value for mypy disallow_untyped_defs [8/11] (#127845)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127845
Approved by: https://github.com/oulgen
ghstack dependencies: #127842, #127843, #127844
2024-06-08 18:49:56 +00:00
Xuehai Pan
67ef2683d9 [BE] wrap deprecated function/class with typing_extensions.deprecated (#127689)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.
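
A minimal sketch of both patterns (illustrative names):

```python
import warnings
from typing_extensions import deprecated

@deprecated("`old_api` is deprecated, use `new_api` instead", category=FutureWarning)
def old_api() -> None:
    ...

# fallback for warnings that cannot use the decorator: make the
# category explicit instead of the implicit UserWarning
warnings.warn("`old_api` is deprecated, use `new_api` instead", category=FutureWarning)
```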

Resolves #126888

- #126888

This PR is split from PR #126898.

- #126898

------

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127689
Approved by: https://github.com/Skylion007
2024-06-02 12:30:43 +00:00
PyTorch MergeBot
033e733021 Revert "[BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)"
This reverts commit 749a132fb0.

Reverted https://github.com/pytorch/pytorch/pull/126898 on behalf of https://github.com/fbgheith due to switching typing-extensions from 4.3.0 to 4.9.0 causing internal failures ([comment](https://github.com/pytorch/pytorch/pull/126898#issuecomment-2142884456))
2024-05-31 19:47:24 +00:00
Xuehai Pan
749a132fb0 [BE] wrap deprecated function/class with typing_extensions.deprecated (#126898)
Use `typing_extensions.deprecated` for deprecation annotation if possible. Otherwise, add `category=FutureWarning` to `warnings.warn("message")` if the category is missing.

Note that only warnings whose messages contain `[Dd]eprecat(ed|ion)` are updated in this PR.

UPDATE: Use `FutureWarning` instead of `DeprecationWarning`.

Resolves #126888

- #126888

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126898
Approved by: https://github.com/albanD
2024-05-29 12:09:27 +00:00
markstur
5540d276ce Fix docstring errors in container.py, _functions.py, transformer.py, comm.py, parallel_apply.py, data_parallel.py, scatter_gather.py (#113250)

Fixes #112603

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113250
Approved by: https://github.com/mikaylagawarecki
2023-11-10 21:07:25 +00:00
isdanni
382327bd0e [BE] Enable Ruff's Flake8 PYI034 (#111105)
Enable [non-self-return-type (PYI034)](https://docs.astral.sh/ruff/rules/non-self-return-type/#non-self-return-type-pyi034)
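
The rule flags methods like `__new__` and `__enter__` that annotate their return type as the class name instead of `Self`; a small illustration (hypothetical class):

```python
from typing_extensions import Self

class Session:
    # PYI034: __enter__ should return Self rather than "Session",
    # so that subclasses keep their own type
    def __enter__(self) -> Self:
        return self

    def __exit__(self, *exc) -> None:
        return None
```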

Link: #110950

**EDIT**: to newly added reviewers, please ignore the request, it's due to a rebase error 😅

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111105
Approved by: https://github.com/Skylion007
2023-10-13 21:19:53 +00:00
Aaron Gokaslan
660e8060ad [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.

I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, there seem to be no instances of it in our codebase, so I am enabling the rule to keep it that way. :)
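
What RUF017 catches, for context (illustrative):

```python
import itertools

lists = [[1, 2], [3], [4, 5]]

flat_slow = sum(lists, [])  # RUF017: quadratic, copies the accumulator on every step
flat_fast = list(itertools.chain.from_iterable(lists))  # linear time
assert flat_slow == flat_fast == [1, 2, 3, 4, 5]
```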

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-22 23:16:38 +00:00
PyTorch MergeBot
d59a6864fb Revert "[BE]: Update ruff to 0.285 (#107519)"
This reverts commit 88ab3e4322.

Reverted https://github.com/pytorch/pytorch/pull/107519 on behalf of https://github.com/ZainRizvi due to Sorry, but this PR breaks internal tests. @ezyang, can you please help them get unblocked? It seems like one of the strings was probably accidentally modified ([comment](https://github.com/pytorch/pytorch/pull/107519#issuecomment-1688833480))
2023-08-22 19:53:32 +00:00
Aaron Gokaslan
88ab3e4322 [BE]: Update ruff to 0.285 (#107519)
This updates ruff to 0.285, which is faster, better, and fixes a bunch of false negatives with regard to f-strings.

I also enabled RUF017, which looks for accidental quadratic list summation. Luckily, there seem to be no instances of it in our codebase, so I am enabling the rule to keep it that way. :)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107519
Approved by: https://github.com/ezyang
2023-08-20 01:36:18 +00:00
Alexander Pivovarov
28a4fc8d8a Fix some typos (#105869)
### Description:
- Fix typos in comments
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105869
Approved by: https://github.com/mikaylagawarecki, https://github.com/Skylion007
2023-07-26 16:23:57 +00:00
Justin Chu
79c5e33349 [BE] Enable ruff's UP rules and autoformat nn/ mps/ and torch/ (#105436)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105436
Approved by: https://github.com/malfet, https://github.com/albanD
2023-07-21 07:38:46 +00:00
Nikita Shulga
5837e95d30 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP 484 violations (i.e., a default arg set to None without the type being annotated as Optional)
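
Illustrative example of the violation and the fix (not code from the PR):

```python
from typing import Optional

def resize(size: int = None):  # PEP 484 violation: implicit Optional
    ...

def resize_fixed(size: Optional[int] = None):  # what these fixes change it to
    ...
```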
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in follow-up PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`

Unrelated, to bypass CI failures due to the gcc9 dependency update in Ubuntu-18.04:
- Add a hack to squash the older libstdc++ from the conda environment in favor of the one from the OS to `.ci/docker/install_conda.sh`
- Update bazel cuda builds to focal, as with libstdc++-6.0.32 bazel builds lose the ability to catch exceptions (probably because they link with cupti statically, but I could not find where that is done)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-15 20:30:20 +00:00
PyTorch MergeBot
15fd1ea118 Revert "[Reland] Update mypy to 1.4.1 (#105227)"
This reverts commit c9c4f8efc3.

Reverted https://github.com/pytorch/pytorch/pull/105227 on behalf of https://github.com/atalman due to trying to mitigate ci sev #105248 ([comment](https://github.com/pytorch/pytorch/pull/105227#issuecomment-1636510935))
2023-07-14 22:28:35 +00:00
Nikita Shulga
c9c4f8efc3 [Reland] Update mypy to 1.4.1 (#105227)
This PR re-lands
- [Typing] Fix PEP 484 Violation (#105022)
- Update mypy to 1.4.1 (#91983)

That were reverted due to the conflict with internal source repo.

Mostly fixes for PEP 484 violations (i.e., a default arg set to None without the type being annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
  - Add assert in `torch/optim/optimizer.py` that the Optional list is not None
TODO (in follow-up PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105227
Approved by: https://github.com/atalman, https://github.com/albanD, https://github.com/Skylion007
2023-07-14 20:45:12 +00:00
PyTorch MergeBot
3c5a494d7a Revert "Update mypy to 1.4.1 (#91983)"
This reverts commit 634659e262.

Reverted https://github.com/pytorch/pytorch/pull/91983 on behalf of https://github.com/malfet due to its dependent change being reverted, so reverting this one as well to keep CI clean ([comment](https://github.com/pytorch/pytorch/pull/91983#issuecomment-1636059709))
2023-07-14 15:59:16 +00:00
Nikita Shulga
634659e262 Update mypy to 1.4.1 (#91983)
Mostly fixes for PEP 484 violations (i.e., a default arg set to None without the type being annotated as Optional)
Plus a few real fixes:
  - Add missing `_get_upgraders_entry_map` to `torch/_C/__init__.pyi`
  - Add missing return statement to `torch._export.deserialize_graph`
  - Fix error message in `torch.ao.ns.fx.weight_utils.get_lstm_mod_weights`
TODO (in follow-up PR):
  - Fix erroneous `isinstance` check in `torch/ao/quantization/_pt2e/qat_utils.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91983
Approved by: https://github.com/kit1980, https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/thiagocrepaldi, https://github.com/aaronenyeshi
2023-07-13 16:30:36 +00:00
shibo19
58feefa4ed add custom device support for special nn.modules (#103419)
Fixes #103818
1. Some special nn.Modules have checks that only support cuda, so I added a `privateuse1` check.
2. Getting the device type for `privateuse1` via `torch._C._get_privateuse1_backend_name()` raises an error under `torch.jit.script`, so I added a global variable to avoid this.
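
A rough sketch of the pattern described (names assumed, not the actual diff):

```python
import torch

# Cache the backend name at import time so torch.jit.script never has
# to script the torch._C binding call itself.
_PRIVATEUSE1_BACKEND_NAME = torch._C._get_privateuse1_backend_name()

def _device_type_supported(device_type: str) -> bool:
    # previously such checks allowed only "cuda"
    return device_type in ("cuda", _PRIVATEUSE1_BACKEND_NAME)
```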
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103419
Approved by: https://github.com/albanD
2023-06-26 00:58:29 +00:00
Xuehai Pan
5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.
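
The typical call-style rewrite looks like this (illustrative):

```python
from torch import nn

# before (Python 2 style)
class MyModuleOld(nn.Module):
    def __init__(self):
        super(MyModuleOld, self).__init__()

# after (Python 3 style; behavior is unchanged)
class MyModuleNew(nn.Module):
    def __init__(self):
        super().__init__()
```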

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00
PyTorch MergeBot
f8e641bad4 Revert "Make ModuleList derive from Sequence[T] and type it appropriately (#89135)"
This reverts commit d0bfd79f3d.

Reverted https://github.com/pytorch/pytorch/pull/89135 on behalf of https://github.com/albanD because it is actually breaking user code
2023-01-12 22:04:02 +00:00
Alex Silverstein
d0bfd79f3d Make ModuleList derive from Sequence[T] and type it appropriately (#89135)
I see https://github.com/pytorch/pytorch/issues/53103 says this might be problematic, but I'm a bit confused at this point because it looks like ModuleList does in fact already adhere to the Sequence API.

The big win here is that for homogeneous ModuleLists, you now get typing for individual members, e.g.
`ModuleList([Linear(), Linear(), Linear()])[1]` properly has type `Linear`

If this looks good, I can do a follow-up PR to do the same for `ModuleDict` and `Parameter[List,Dict]`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89135
Approved by: https://github.com/albanD
2023-01-11 21:21:32 +00:00
Vlad Lialin
0b255b3f80 Better __repr__ for ModuleList (#90452)
## Problem
When models have a lot of complex repeated layers, `print(module)` output becomes impractical to work with. For example, the current `__repr__` output for `t5-small` is `715` lines long.

## Solution
With the better `__repr__` it becomes `135`. For `t5-large`, the current `__repr__` prints `1411` lines, while the better `__repr__` prints `135`, the same number as for t5-small, because most of the layers are just repeated. For `EleutherAI/gpt-j-6B` the number of lines drops from `483` to just `24`.

Here's how it works: when ModuleList items have exactly the same `__repr__`, instead of printing each of them it prints `f"{N} x {repr(item)}"`. The code supports cases where the same ModuleList contains multiple distinct repeating runs, which is especially useful when the first/last layer of a block differs from the rest of them.

The better `__repr__` should make model printouts smaller, more readable, and significantly more useful by highlighting the difference between repeated blocks instead of losing it in a wall of text.
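
A minimal demonstration on a stock module (output shown as produced after this change):

```python
import torch.nn as nn

stack = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
print(stack)
# ModuleList(
#   (0-2): 3 x Linear(in_features=4, out_features=4, bias=True)
# )
```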

## Motivating real-life example.

You can try it out in this [colab notebook](https://colab.research.google.com/drive/1PscpX_K1UemIDotl2raC4QMy_pTqDq7p?usp=sharing).

The current `__repr__` output of gpt-j-6b is too big to include in full in this PR description:
```
GPTJModel(
  (wte): Embedding(50400, 4096)
  (drop): Dropout(p=0.0, inplace=False)
  (h): ModuleList(
    (0): GPTJBlock(
      (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
      (attn): GPTJAttention(
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (mlp): GPTJMLP(
        (fc_in): Linear(in_features=4096, out_features=16384, bias=True)
        (fc_out): Linear(in_features=16384, out_features=4096, bias=True)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.0, inplace=False)
      )
    )
    (1): GPTJBlock(
      (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
      (attn): GPTJAttention(
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (mlp): GPTJMLP(
        (fc_in): Linear(in_features=4096, out_features=16384, bias=True)
        (fc_out): Linear(in_features=16384, out_features=4096, bias=True)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.0, inplace=False)
      )
    )
    (2): GPTJBlock(
...
```

Better `__repr__` output looks like this:
```
GPTJModel(
  (wte): Embedding(50400, 4096)
  (drop): Dropout(p=0.0, inplace=False)
  (h): ModuleList(
    (0-27): 28 x GPTJBlock(
      (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
      (attn): GPTJAttention(
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (mlp): GPTJMLP(
        (fc_in): Linear(in_features=4096, out_features=16384, bias=True)
        (fc_out): Linear(in_features=16384, out_features=4096, bias=True)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90452
Approved by: https://github.com/albanD
2022-12-26 17:05:14 +00:00
Fabian Ricardo Latorre Gomez
9c9f424817 modify the signature of method __getitem__ from ModuleList (#83799)
The `idx` parameter can be either a `slice` or an `int`. The same applies to the `Sequential` class.

Fixes #83797

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83799
Approved by: https://github.com/malfet, https://github.com/albanD
2022-08-22 19:48:49 +00:00
ProGamerGov
357b7d589c Fix docstring inconsistencies: string -> str, boolean -> bool (#82410)
### Description

Throughout the PyTorch docs and codebase, the `string` type in docstrings is referred to by two separate names. This leads to inconsistent docs, like you can see here: https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html#torch.nn.Conv3d

This PR fixes the issue by ensuring that all mentions of the string type in docstrings use the same format that Sphinx generates hyperlinks for.
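
For instance, an illustrative docstring (not from the PR) in the consistent format:

```python
def pad(value: str, width: int) -> str:
    """Pad ``value`` with spaces.

    Args:
        value (str): text to pad. Writing ``str`` rather than ``string``
            lets Sphinx hyperlink the builtin type.
        width (int): target width.
    """
    return value.ljust(width)
```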

### Testing
No testing should be required for this change

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82410
Approved by: https://github.com/jbschlosser
2022-07-28 21:29:57 +00:00
Khushi Agrawal
050aec1805 [nn] add pop to sequential and ModuleList (#81601)
Follows #71329

cc @kshitij12345!
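
A brief usage sketch of the new method (illustrative):

```python
import torch.nn as nn

seq = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))
removed = seq.pop(1)      # removes and returns the ReLU
print(removed, len(seq))  # ReLU() 2
```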
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81601
Approved by: https://github.com/albanD
2022-07-25 19:32:32 +00:00
Ansh Radhakrishnan
110cd724fc [nn] Add support for +=, * and *= operations for nn.Sequential objects (#81279)
Fixes #71329
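
A brief usage sketch (operator semantics assumed from the PR title):

```python
import torch.nn as nn

a = nn.Sequential(nn.Linear(2, 2))
b = nn.Sequential(nn.ReLU())

a += b     # in-place concatenation, like list +=
c = b * 2  # repetition, like list *
b *= 2     # in-place repetition
print(len(a), len(c), len(b))  # 2 2 2
```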

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81279
Approved by: https://github.com/albanD
2022-07-25 15:48:47 +00:00
Khushi Agrawal
dced803339 [nn] add insert method to sequential class (#81402)
Follows #71329

cc @kshitij12345
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81402
Approved by: https://github.com/albanD
2022-07-20 14:45:52 +00:00
Khushi Agrawal
2c0b11b43b [nn] implement extend method to sequential class (#81179)
Follows #71329

cc @kshitij12345 :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81179
Approved by: https://github.com/albanD
2022-07-20 05:33:41 +00:00
Khushi Agrawal
3da8c909da [nn] add + operator for torch.nn.Sequential to concatenate (#81170)
Fixes #78512

#### TODO
- [x] add tests

cc @kshitij12345!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81170
Approved by: https://github.com/albanD
2022-07-11 17:49:58 +00:00
anjali411
bda04e9f5e Add __all__ for torch.optim and torch.nn.modules modules (#80237)
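For illustration, the kind of declaration this adds (the exact entries here are assumed):

```python
# torch/nn/modules/container.py (entries assumed for illustration)
__all__ = ["Sequential", "ModuleList", "ModuleDict", "ParameterList", "ParameterDict"]
```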
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80237
Approved by: https://github.com/albanD
2022-06-24 21:34:10 +00:00
Edward Z. Yang
c20969c40c Fix ParameterList printing meta tensor
Fixes https://github.com/pytorch/pytorch/issues/78250

There are actually two bugs.  First, the crash is caused
by TensorOptions::backend incorrectly reporting noexcept when
it can fail.  Second, ParameterList is using torch.tensortype
for no good reason; we can just print the dtype instead.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78529

Approved by: https://github.com/albanD
2022-06-01 00:46:52 +00:00
纪少敏
29de7924a9 Fix parameterlist dir func error (#74404)
Fixes [#74404](https://github.com/pytorch/pytorch/issues/74404)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74997
Approved by: https://github.com/albanD
2022-04-04 20:13:11 +00:00
Alban Desmaison
7035738b50 Change ParameterList and ParameterDict to be able to contain any kind of objects (#70499)
Summary:
The only difference from a plain list/dict now is that nn.Parameters are
handled specially and registered as parameters properly.
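
A brief sketch of the resulting behavior (illustrative):

```python
import torch
from torch import nn

class M(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # nn.Parameter entries are registered as parameters;
        # other values are simply stored
        self.stuff = nn.ParameterList([nn.Parameter(torch.randn(2)), torch.zeros(2)])

m = M()
print(len(list(m.parameters())))  # 1: only the nn.Parameter is registered
```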

test_nn and parametrization tests pass locally.
Will see in CI if DP is fixed as well.

Tentative fix for https://github.com/pytorch/pytorch/issues/36035

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70499

Reviewed By: jbschlosser, alexeib

Differential Revision: D34005332

Pulled By: albanD

fbshipit-source-id: 7e76b0873d0fec345cb537e2a6ecba0258e662b9
(cherry picked from commit dc1e6f8d86)
2022-02-09 18:52:29 +00:00
Jake Tae
ca61292465 Add append method for nn.Sequential (#71326)
Summary:
Partially addresses https://github.com/pytorch/pytorch/issues/71249, and potentially supersedes https://github.com/pytorch/pytorch/pull/20274.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71326

Reviewed By: cpuhrsch

Differential Revision: D33855047

Pulled By: jbschlosser

fbshipit-source-id: a3a682e206f93b4c52bc3405e2f7b26aea6635ea
(cherry picked from commit c0b27bbf2a)
2022-01-31 16:54:12 +00:00
Jake Tae
eac3decf93 ModuleList concatenation (#70887)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/70441.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70887

Reviewed By: ejguan

Differential Revision: D33555431

Pulled By: albanD

fbshipit-source-id: ce42459ee46a611e98e89f02686acbac16b6b668
2022-01-13 15:31:07 -08:00
Pascal
276253b164 Fixed wrong return type in ModuleList getitem (#69083)
Summary:
Fixes typing error:
`Expected type ‘Iterable’ (matched generic type ‘Iterable[_T1]’), got ‘Module’ instead.`

See https://discuss.pytorch.org/t/modulelist-typing-error-not-an-iterable/138137/5:

To reproduce (e.g. with mypy/pycharm):

```python
import torch
import torch.nn as nn
class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.module_list = nn.ModuleList(
            [nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 1)]
        )

    def forward(self, batch):
        for i in self.module_list[1:4]:
            pass
        return batch
model = Model()
out = model(torch.randn(1, 1))
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69083

Reviewed By: davidberard98

Differential Revision: D33279114

Pulled By: jbschlosser

fbshipit-source-id: 90d74e76602163586b6ff4c49613a2694a9af37c
2021-12-22 11:38:17 -08:00