Commit Graph

93 Commits

Author SHA1 Message Date
Khushi Agrawal
2c0b11b43b [nn] implement extend method to sequential class (#81179)
Follows #71329

cc @kshitij12345 :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81179
Approved by: https://github.com/albanD
2022-07-20 05:33:41 +00:00
Khushi Agrawal
3da8c909da [nn] add + operator for torch.nn.Sequential to concatenate (#81170)
Fixes #78512

#### TODO
- [x] add tests

cc @kshitij12345!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81170
Approved by: https://github.com/albanD
2022-07-11 17:49:58 +00:00
anjali411
bda04e9f5e Add __all__ for torch.optim and torch.nn.modules modules (#80237)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80237
Approved by: https://github.com/albanD
2022-06-24 21:34:10 +00:00
Edward Z. Yang
c20969c40c Fix ParameterList printing meta tensor
Fixes https://github.com/pytorch/pytorch/issues/78250

There are actually two bugs.  First, the crash is caused
by TensorOptions::backend incorrectly reporting noexcept when
it can failed.  Second, ParameterList is using torch.tensortype
for no good reason; we can just print the dtype instead.

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78529

Approved by: https://github.com/albanD
2022-06-01 00:46:52 +00:00
纪少敏
29de7924a9 Fix parameterlist dir func error (#74404)
Fixes #[74404](https://github.com/pytorch/pytorch/issues/74404)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74997
Approved by: https://github.com/albanD
2022-04-04 20:13:11 +00:00
Alban Desmaison
7035738b50 Change ParameterList and ParameterDict to be able to contain any kind of objects (#70499)
Summary:
The only difference with plain list/dict now is that nn.Parameters are
handled specially and registered as parameters properly.

test_nn and parametrization works locally.
Will see in CI if DP is fixed as well.

Tentative fix for https://github.com/pytorch/pytorch/issues/36035

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70499

Reviewed By: jbschlosser, alexeib

Differential Revision: D34005332

Pulled By: albanD

fbshipit-source-id: 7e76b0873d0fec345cb537e2a6ecba0258e662b9
(cherry picked from commit dc1e6f8d86)
2022-02-09 18:52:29 +00:00
Jake Tae
ca61292465 Add append method for nn.Sequential (#71326)
Summary:
Partially addresses https://github.com/pytorch/pytorch/issues/71249, and potentially supersedes https://github.com/pytorch/pytorch/pull/20274.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71326

Reviewed By: cpuhrsch

Differential Revision: D33855047

Pulled By: jbschlosser

fbshipit-source-id: a3a682e206f93b4c52bc3405e2f7b26aea6635ea
(cherry picked from commit c0b27bbf2a)
2022-01-31 16:54:12 +00:00
Jake Tae
eac3decf93 ModuleList concatenation (#70887)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/70441.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70887

Reviewed By: ejguan

Differential Revision: D33555431

Pulled By: albanD

fbshipit-source-id: ce42459ee46a611e98e89f02686acbac16b6b668
2022-01-13 15:31:07 -08:00
Pascal
276253b164 Fixed wrong return type in ModuleList getitem (#69083)
Summary:
Fixes typing error:
`Expected type ‘Iterable’ (matched generic type ‘Iterable[_T1]’), got ‘Module’ instead.
`

see: https://discuss.pytorch.org/t/modulelist-typing-error-not-an-iterable/138137/5 :

To reproduce (e.g. with mypy/pycharm):

```python
import torch.nn as nn
class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.module_list = nn.ModuleList(
            [nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 1)]
        )

    def forward(self, batch):
        for i in self.module_list[1:4]:
            pass
        return batch
model = Model()
out = model(torch.randn(1, 1))
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69083

Reviewed By: davidberard98

Differential Revision: D33279114

Pulled By: jbschlosser

fbshipit-source-id: 90d74e76602163586b6ff4c49613a2694a9af37c
2021-12-22 11:38:17 -08:00
Albert Liang
0d06616c47 Add dict methods to ParameterDict (#69403)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/68476

We implemented all of the following `dict` methods for `ParameterDict`
- `get `
- `setdefault`
- `popitem`
- `fromkeys`
- `copy`
- `__or__`
- `__ior__`
- `__reversed__`
- `__ror__`

The behavior of these new methods matches the expected behavior of python `dict` as defined by the language itself: https://docs.python.org/3/library/stdtypes.html#typesmapping

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69403

Reviewed By: albanD

Differential Revision: D33187111

Pulled By: jbschlosser

fbshipit-source-id: ecaa493837dbc9d8566ddbb113b898997e2debcb
2021-12-17 10:15:47 -08:00
Nikita Shulga
4e94e84f65 Type annotate torch.nn.Module ctor (#61334)
Summary:
Annotate generic types
Fix some type violations
Override `_modules` and `_parameters` in `Sequential`, `ModuleList`, `ModuleDict`, etc

Fixes https://github.com/pytorch/pytorch/issues/45497

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61334

Reviewed By: albanD

Differential Revision: D29579533

Pulled By: malfet

fbshipit-source-id: 5cd8ca918b260ca35cfdd873dee8851d39d17de2
2021-07-16 13:59:06 -07:00
Shai Bagon
a583b9cd86 Fixing "naive" forward of ModuleList and `ModuleDict (#48785)
Summary:
**Goal:** Making sure "calling"/"forwarding" a `ModuleList` or `ModuleDict` produce the intended `NotImpmentedError`.

**Current behavior:**
Currently, when naively calling `forward`  user ends up with the confusing error message:
```python
TypeError: forward() takes 1 positional argument but 2 were given
```
Instead of the intended `NotImplementedError.`
This minor issue was brought up by vadimkantorov in issue https://github.com/pytorch/pytorch/issues/37718 [here][1], also by a confused stackoverflow user [here][2].

**What this PR includes:**
Remove `forward` altogether from `ModuleList` and `ModuleDict` to fall back on the `_forward_unimplemented` of `Module` that properly throws `NotImplementedError` regardless of input arguments.

Appropriate test was added to `test_nn.py`

Fixes previous PR https://github.com/pytorch/pytorch/issues/48698 and PR https://github.com/pytorch/pytorch/issues/48783 (third time's a charm? I'm really sorry for the mess)

Test added according to ngimel [request][3].

[1]: https://github.com/pytorch/pytorch/issues/37718#issuecomment-736333345
[2]: https://stackoverflow.com/q/65096679/1714410
[3]: https://github.com/pytorch/pytorch/pull/48698#issuecomment-737398693

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48785

Reviewed By: zhangguanheng66

Differential Revision: D25359759

Pulled By: jbschlosser

fbshipit-source-id: 28f82386f2e9a2a9b0b0b81b16dba6b79398bd34
2021-04-21 10:43:07 -07:00
Ansley Ussery
b032316c41 Improve nn.Sequential documentation (#53380)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53380

Test Plan: Imported from OSS

Reviewed By: nikithamalgifb

Differential Revision: D26849861

Pulled By: ansley

fbshipit-source-id: 2add8c73ae421332ed1c03340806e25656bafabb
2021-03-24 13:02:43 -07:00
Sam Estep
8c798e0622 Forbid trailing whitespace (#53406)
Summary:
Context: https://github.com/pytorch/pytorch/pull/53299#discussion_r587882857

These are the only hand-written parts of this diff:
- the addition to `.github/workflows/lint.yml`
- the file endings changed in these four files (to appease FB-internal land-blocking lints):
  - `GLOSSARY.md`
  - `aten/src/ATen/core/op_registration/README.md`
  - `scripts/README.md`
  - `torch/csrc/jit/codegen/fuser/README.md`

The rest was generated by running this command (on macOS):
```
git grep -I -l ' $' -- . ':(exclude)**/contrib/**' ':(exclude)third_party' | xargs gsed -i 's/ *$//'
```

I looked over the auto-generated changes and didn't see anything that looked problematic.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/53406

Test Plan:
This run (after adding the lint but before removing existing trailing spaces) failed:
- https://github.com/pytorch/pytorch/runs/2043032377

This run (on the tip of this PR) succeeded:
- https://github.com/pytorch/pytorch/runs/2043296348

Reviewed By: walterddr, seemethere

Differential Revision: D26856620

Pulled By: samestep

fbshipit-source-id: 3f0de7f7c2e4b0f1c089eac9b5085a58dd7e0d97
2021-03-05 17:22:55 -08:00
Chester Liu
58eb23378f Clean up usage of torch._six partially (#49785)
Summary:
See https://github.com/pytorch/pytorch/issues/42919

Pull Request resolved: https://github.com/pytorch/pytorch/pull/49785

Reviewed By: mruberry

Differential Revision: D25963833

Pulled By: bugra

fbshipit-source-id: 11c90d6b8d3f206c9d0a4d8621b773beb10c6ba2
2021-02-08 13:58:34 -08:00
Guilherme Leobas
a9e46f1413 add type annotations to torch.nn.modules.container (#48969)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/48968

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48969

Reviewed By: mrshenli

Differential Revision: D25728987

Pulled By: walterddr

fbshipit-source-id: 02c3aa2078f4ed6cc6edd90ffe1177d789c328a9
2021-01-19 15:12:17 -08:00
Samuel Marks
e6779d4357 [*.py] Rename "Arguments:" to "Args:" (#49736)
Summary:
I've written custom parsers and emitters for everything from docstrings to classes and functions. However, I recently came across an issue when I was parsing/generating from the TensorFlow codebase: inconsistent use of `Args:` and `Arguments:` in its docstrings.

```sh
(pytorch#c348fae)$ for name in 'Args:' 'Arguments:'; do
    printf '%-10s %04d\n' "$name" "$(rg -IFtpy --count-matches "$name" | paste -s -d+ -- | bc)"; done
Args:      1095
Arguments: 0336
```

It is easy enough to extend my parsers to support both variants, however it looks like `Arguments:` is wrong anyway, as per:

  - https://google.github.io/styleguide/pyguide.html#doc-function-args @ [`ddccc0f`](https://github.com/google/styleguide/blob/ddccc0f/pyguide.md)

  - https://chromium.googlesource.com/chromiumos/docs/+/master/styleguide/python.md#describing-arguments-in-docstrings @ [`9fc0fc0`](https://chromium.googlesource.com/chromiumos/docs/+/9fc0fc0/styleguide/python.md)

  - https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html @ [`c0ae8e3`](https://github.com/sphinx-contrib/napoleon/blob/c0ae8e3/docs/source/example_google.rst)

Therefore, only `Args:` is valid. This PR replaces them throughout the codebase.

PS: For related PRs, see tensorflow/tensorflow/pull/45420

PPS: The trackbacks automatically appearing below are sending the same changes to other repositories in the [PyTorch](https://github.com/pytorch) organisation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/49736

Reviewed By: albanD

Differential Revision: D25710534

Pulled By: soumith

fbshipit-source-id: 61e8ff01abb433e9f78185c2d1d0cbd7c22c1619
2020-12-28 09:34:47 -08:00
CedricPicron
dc7ab46dcc Fix incorrect warnings in ParameterList/Dict (#48315)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/46983.

The solution is based of two components:

1. The introduction of the `_initialized` attribute. This will be used during ParameterList/Dict creation methods `__init__` (introduced in https://github.com/pytorch/pytorch/issues/47772) and  `__setstate__` to not trigger warnings when setting general `Module` attributes.
2. The introduction of the `not hasattr(self, key)` check to avoid triggering warnings when changing general `Module` attributes such as `.training` during the `train()` and `eval()` methods.

Tests related to the fix are added.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48315

Reviewed By: mrshenli

Differential Revision: D25130217

Pulled By: albanD

fbshipit-source-id: 79e2abf1eab616f5de74f75f370c2fe149bed4cb
2020-12-01 07:08:33 -08:00
albanD
233192be73 Make sure valid ParameterList/Dict don't warn on creation (#47772)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/46983

Pull Request resolved: https://github.com/pytorch/pytorch/pull/47772

Reviewed By: zou3519

Differential Revision: D24991341

Pulled By: albanD

fbshipit-source-id: 0fa21192f529a016048e3eef88c5a8f3cbb3c235
2020-11-16 13:16:59 -08:00
albanD
e155fbe915 add warning when ParameterList/Dict is used with DataParallel (#44405)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44405

Test Plan: Imported from OSS

Reviewed By: agolynski

Differential Revision: D23783987

Pulled By: albanD

fbshipit-source-id: 5018b0d381cb09301d2f88a98a910854f740ace1
2020-09-22 08:58:00 -07:00
wudenggang
9600ed9af3 typo fixes (#41632)
Summary:
typo fixes

Pull Request resolved: https://github.com/pytorch/pytorch/pull/41632

Reviewed By: ezyang

Differential Revision: D22617827

Pulled By: mrshenli

fbshipit-source-id: c2bfcb7cc36913a8dd32f13fc9adc3aa0a9b682f
2020-07-20 07:23:00 -07:00
Liu
54d7a1e3f4 Fix module dict key ordering (#40905)
Summary:
fix https://github.com/pytorch/pytorch/issues/40227
Removed the sorting operation both in ModuleDict class, updated the docstring.
Also remove a sort operation in corresponding unit test, which will lead to unit test fail.

BC Note: Python version after 3.6, the plain dict will preserve the order of keys.
example:
For a python 3.6+ user, if he is initial a ModuleDict instance using plain python dict:
{
"b": torch.nn.MaxPool2d(3),
"a": torch.nn.MaxPool2d(3)
}
, he will get a ModuleDict which preserve the order:
ModuleDict(
(b): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
(a): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
)

For a python 3.5 user, if we maintain the same input, then the output ModuleDict could be:
ModuleDict(
(a): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
(b): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40905

Differential Revision: D22357480

Pulled By: albanD

fbshipit-source-id: 0e2502769647bb64f404978243ca1ebe5346d573
2020-07-06 06:40:48 -07:00
Edward Yang
eace053398 Move all torch.nn.modules type annotations inline (#38211)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38211

Just because the annotations are inline doesn't mean the files type
check; most of the newly annotated files have type errors and I
added exclusions for them in mypy.ini.  The payoff of moving
all of these modules inline is I can delete the relevant code
generation logic for the pyi files (which was added ignore
annotations that weren't actually relevant anymore.)

For the most part the translation was completely mechanical, but there
were two hairy issues.  First, I needed to work around a Python 3.6 and
earlier bug where Generic has a nontrivial metaclass.  This fix is in
torch/jit/__init__.py.  Second, module.py, we need to apply the same
fix for avoiding contravariance checks that the pyi file used to have;
this is done by declaring forward as a variable (rather than a
function), which appears to be sufficient enough to get mypy to not
contravariantly check input arguments.

Because we aren't actually typechecking these modules in most
cases, it is inevitable that some of these type annotations are wrong.
I slavishly copied the old annotations from the pyi files unless there
was an obvious correction I could make.  These annotations will probably
need fixing up later.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D21497397

Pulled By: ezyang

fbshipit-source-id: 2b08bacc152c48f074e7edc4ee5dce1b77d83702
2020-06-11 15:59:57 -07:00
Robert Wang
2b2d2168e8 Issue #27441 Fix: Bug in updating ModuleDict & ParameterDict (#27814)
Summary:
Fix a bug in `nn.ModuleDict.update` and `nn.ParameterDict.update` when passing another same dictionary as input.
Related issue: [Issue https://github.com/pytorch/pytorch/issues/27441](https://github.com/pytorch/pytorch/issues/27441)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27814

Differential Revision: D21518099

Pulled By: ezyang

fbshipit-source-id: 9e6bb6fcc26c8070e137e2e52c65f69a1fcaab37
2020-05-14 08:01:41 -07:00
songyouwei
e5218e3e12 Add missing error messages for container modules (#29991)
Summary:
Container `Module`s, including `ModuleList`, `ParameterList` and `ParameterDict`, should not be called like a regular `Module`.
This PR add error messages for these special modules.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29991

Differential Revision: D19698535

Pulled By: ezyang

fbshipit-source-id: fe156a0bbb033041086734b38f8c6fde034829bf
2020-02-13 21:34:27 -08:00
Alban Desmaison
81048c41ab remove simple .data from torch/nn
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/31482

Test Plan: Imported from OSS

Differential Revision: D19303243

Pulled By: albanD

fbshipit-source-id: 5afdfeb4b8382c09b9ec65acd545148ed76d4285
2020-01-15 12:40:38 -08:00
Elias Ellison
fbe90b65fa Cleanup special handling of Containers, allowing custom forwards (#28988)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28988

Make ModuleList, Sequential, ModuleDict go through the same pathway as other modules, cleaning up a bunch of code and allowing them to define custom forwards and other methods.

EDIT: Previously, we would ignore an nn.Sequential attribute if it was not in `__constants__` ("did you forget to add it to Constants"). This PR scripts it even if it is not in `__constants__`. Is that what we want?

Test Plan: Imported from OSS

Differential Revision: D18402821

Pulled By: eellison

fbshipit-source-id: dd4f28fb0df0d1ba4ad1b3bc34ba141959a433f7
2019-11-12 14:10:38 -08:00
Elias Ellison
3175f5543a Make nn.Sequential iterable (#28987)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28987

We have `__iter__` defined on nn.ModuleList. Chainer's `Sequential` defines `__iter__`. This will also be helpful in modules which extend `nn.Sequential` and define a custom forward, because they can use the `for x in self` syntax that is supported in both python & TorchScript.

Test Plan: Imported from OSS

Differential Revision: D18402822

Pulled By: eellison

fbshipit-source-id: 1ece0f891a9d37f401e232320f58b056d5481856
2019-11-12 14:10:34 -08:00
Elias Ellison
8f7020bbdb add support for ModuleDict (#25715)
Summary:
Add support for nn.ModuleDict in script. This is needed to support torchvision.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25715

Differential Revision: D17301826

Pulled By: eellison

fbshipit-source-id: 541b5477e980f519a8c3bbb1be91dac227f6d00f
2019-09-10 18:43:49 -07:00
Tongzhou Wang
6b8771a7a6 fix nn.Sequential doc
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19597

Differential Revision: D15042383

Pulled By: soumith

fbshipit-source-id: f912ed2a726a17fcc25795ff66b73ae4caacd247
2019-04-23 14:58:16 -07:00
Tongzhou Wang
1d827b7271 Further improvements of nn.container docs
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17731

Differential Revision: D14401894

Pulled By: soumith

fbshipit-source-id: cebb25859f78589cc4f4f8afb1e84c97f82b6962
2019-03-10 18:30:39 -07:00
Tongzhou Wang
0ed1b9fb98 Update ModuleDict doc about order
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/17717

Differential Revision: D14346557

Pulled By: ezyang

fbshipit-source-id: 2484c7d8105f9aa8bce5567d1fa2d4f587cc9cc2
2019-03-06 13:09:46 -08:00
ZhuBaohe
acf5ec07af Correct conv and pooling docstrings in nn module (#17052)
Summary:
This PR fix conv and pooling docstrings in nn module
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17052

Differential Revision: D14068566

Pulled By: ezyang

fbshipit-source-id: 3ec1de232ff6334b6a544dadefbb0ee6193d443a
2019-02-15 06:58:02 -08:00
Sasha Rush
dbe6a7a9ff Unify the shape notation for all of the pytorch modules (#15741)
Summary:
PR to update the shape notation for all of the torch.nn modules to take a unified form. The goal is to make these definitions machine-readable and those checkable by unifying the style across all of the different modules.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15741

Differential Revision: D13709601

Pulled By: ezyang

fbshipit-source-id: fb89a03903fdf0cd0dcf76f3e469b8582b2f3634
2019-01-17 10:32:14 -08:00
lyuwenyu
1b1cdd944c Keep ModuleList consistent with python list in __setitem__ function. (#13102)
Summary:
`ModuleList` class function `__setitem__` has implicit rist
```
In [26]: mlist = nn.ModuleList([nn.ReLU(), nn.Conv2d(10, 10, 3, 1)])

In [27]: mlist
Out[27]:
ModuleList(
  (0): ReLU()
  (1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
)

In [28]: mlist[-1] = nn.ReLU()

In [29]: mlist
Out[29]:
ModuleList(
  (0): ReLU()
  (1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1))
  (-1): ReLU()
)

In [30]: mlist[-1]
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-30-229d1b6823a0> in <module>()
----> 1 mlist[-1]

~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/container.py in __getitem__(self, idx)
    134             return ModuleList(list(self._modules.values())[idx])
    135         else:
--> 136             return self._modules[self._get_abs_string_index(idx)]
    137
    138     def __setitem__(self, idx, module):

KeyError: '2'

```

modified as
```
    def __setitem__(self, idx, module):
        idx = self._get_abs_string_index(idx)
        return setattr(self, str(idx), module)
```
to fix it.

```
In [31]: class NewModuleList(nn.ModuleList):
    ...:     def __setitem__(self, idx, module):
    ...:         idx = self._get_abs_string_index(idx)
    ...:         return setattr(self, str(idx), module)
    ...:

In [32]: mlist = NewModuleList([nn.ReLU(), nn.Conv2d(10, 10, 2, 1)])

In [33]: mlist[-1] = nn.ReLU()

In [34]: mlist
Out[34]:
NewModuleList(
  (0): ReLU()
  (1): ReLU()
)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13102

Differential Revision: D13092480

Pulled By: ezyang

fbshipit-source-id: 7ff7688f66e44bbd263a10d2d09db7bb0df4b749
2018-11-16 07:39:26 -08:00
Xingdong Zuo
e2bc95e1bd add ModuleList.insert (#11664)
Summary:
fixes #11652
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11664

Differential Revision: D9892845

Pulled By: ezyang

fbshipit-source-id: 2c910d6bc0b28a999e25beca6e398fd0f35535c5
2018-09-18 07:41:28 -07:00
nehz
91b6458e2d Container __getitem__ slicing for subclasses (#11694)
Summary:
Simple change to allow ModuleList subclasses's `__getitem__(slice)` to return class of subclass rather than ModuleList
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11694

Differential Revision: D9892824

Pulled By: ezyang

fbshipit-source-id: b75e9c196487f55cb93f0dab6c20d850e8e759ff
2018-09-18 01:26:18 -07:00
Jeff Smith
05e06f7de2 migrating deprecated calls without abc module for containers (#11515)
Summary:
Implementing #10540.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11515

Reviewed By: apaszke

Differential Revision: D9771045

Pulled By: jeffreyksmithjr

fbshipit-source-id: 85ea39abaa9b465805a969f122b626b11fc85ef6
2018-09-13 15:09:22 -07:00
Tongzhou Wang
de460c7ad3 Improvements on conv/pool/fold/stft/ParamDict docs (#11106)
Summary:
Also fixes some incorrect formula rendering.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11106

Differential Revision: D9752433

Pulled By: SsnL

fbshipit-source-id: 535fc8498638e8b645757fc7535d8771992b7d21
2018-09-11 08:56:21 -07:00
Changmao Cheng
7b375ed362 fix ParameterDict doc
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/9918

Differential Revision: D9026402

Pulled By: soumith

fbshipit-source-id: d0459dcda631e8921ab39725b9045e03960da5c9
2018-07-27 01:10:50 -07:00
Karan Dwivedi
97008a64a1 Add ModuleDict and ParameterDict containers (#8463)
Summary:
Addresses:

https://github.com/pytorch/pytorch/issues/4048 and https://github.com/pytorch/pytorch/pull/5297#issuecomment-394924139
Pull Request resolved: https://github.com/pytorch/pytorch/pull/8463

Reviewed By: SsnL

Differential Revision: D8689291

Pulled By: ezyang

fbshipit-source-id: 47e67d9bae1b64ec10771a2c00c56229463b1598
2018-07-15 17:40:52 -07:00
Karan Dwivedi
37e526e1a8 Better print of nn Containers (#8939)
Summary:
Fix https://github.com/pytorch/pytorch/issues/8900

Waiting on https://github.com/pytorch/pytorch/pull/8463

1. Remove extra Line
2. ...
Closes https://github.com/pytorch/pytorch/pull/8939

Reviewed By: soumith

Differential Revision: D8687730

Pulled By: ezyang

fbshipit-source-id: 81c57a03683875704d537cb4585b11838f70df56
2018-06-29 08:24:09 -07:00
Kaiyu Shi
605307f8f3 Add support for printing extra information in Module and refactor redundant codes (#5936)
This PR enables users to print extra information of their subclassed nn.Module.
Now I simply insert the user-defined string at the ending of module name, which should be discussed in this PR.

Before this PR, users should redefine the __repr__ and copy&paste the source code from Module.

* Add support for extra information on Module

* Rewrite the repr method of Module

* Fix flake8

* Change the __repr__ to get_extra_repr in Linear

* Fix extra new-line for empty line

* Add test for __repr__ method

* Fix bug of block string indent

* Add indent for multi-line repr test.

* Address review comments

* Update tutorial for creating nn.Module

* Fix flake8, add extra_repr of bilinear

* Refactor DropoutNd

* Change to extra_repr in some Modules

* Fix flake8

* Refactor padding modules

* Refactor pooling module

* Fix typo

* Change to extra_repr

* Fix bug for GroupNorm

* Fix bug for LayerNorm
2018-04-02 13:52:33 -04:00
Sam Gross
82bdc51dd1
Use operator.index to convert indices to Python int (#5582)
This makes ParameterList, ModuleList, and Sequential convert PyTorch and
NumPy scalars to integers. This matches the behavior of Python lists.
2018-03-06 12:41:23 -05:00
Vishwak Srinivasan
318ae2085a Include __delitem__ for Sequential (#5233) 2018-02-14 13:04:27 +01:00
Kaiyu Shi
f796080781 Add assignment support for Sequential (#4931) 2018-02-07 02:22:25 +01:00
Stefan Otte
409b1c8319 Improve wording of Sequential docs (#4790) 2018-01-22 21:18:23 -05:00
Vishwak Srinivasan
123f49badb Add Slicing capabilities for Sequential, ModuleList and ParameterList (#4491) 2018-01-06 13:01:17 +01:00
David Pollack
47fadc3138 improvements to extend in ModuleList and ParameterList (#3505) 2017-11-29 20:46:39 +01:00
chenyuntc
9b54f8e59c ignore digit in container's __dir__ 2017-11-07 22:08:32 +01:00