Commit Graph

20 Commits

Author SHA1 Message Date
Matthew Hoffman
6cabc105bb Merge type stubs for torch.nn.parallel (#101528)
Fixes #91648

As explained in the tracking issue, the incomplete type stubs in `torch/nn/parallel` mask `DataParallel` methods relevant for subclassing and also mask type issues present in the code as well.

One notable change here is the addition of [`allow_redefinition = True`](https://mypy.readthedocs.io/en/stable/config_file.html#confval-allow_redefinition) in `mypy.ini`, which allows for a common pattern:

> Allows variables to be redefined with an arbitrary type, as long as the redefinition is in the same block and nesting level as the original definition.

This is added specifically to allow for the type narrowing of `device_ids` in `torch.nn.parallel.data_parallel.data_parallel` from `Sequence[Union[int, torch.device]]` to `Sequence[int]`.

Other than this, there are various renamings and `type: ignore` comments added to bypass errors that arose from the merging.

@ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101528
Approved by: https://github.com/ezyang
2023-05-24 16:52:13 +00:00
Kim,Won-Joong
c47cf9bc7f Update parallel_apply.py for assertion error when len(modules) != len(inputs) (#94671)
Print the result why it is wrong.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94671
Approved by: https://github.com/ngimel, https://github.com/kit1980
2023-03-21 17:46:23 +00:00
Aaron Gokaslan
1e2d82b8e4 [BE] Merge isinstance calls together (#94419)
Simplify and speeds up isinstance calls by checking for multiple types at the same time.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94419
Approved by: https://github.com/ezyang
2023-02-09 00:47:26 +00:00
Jeff Daily
04838696b0 parallel_apply should forward current streams to worker threads (#78824)
#71033 moved test_data_parallel_module et al under `instantiate_device_type_tests`.  This had the side effect of now running the tests on a non-default stream.  The parallel_apply creates new threads, one per device, but does not forward the thread local current streams from the parent thread.  This defaults the new per-device threads to use the null stream.  The null stream will not sync with the non-default non-blocking streams, resulting in errors when these tests assert tensors are equal.

CC @janeyx99
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78824
Approved by: https://github.com/pruthvistony, https://github.com/janeyx99
2022-07-20 01:34:21 +00:00
Alexander Grund
5b0f400488 Replace list(map(...)) constructs by list comprehensions (#46461)
Summary:
As discussed in https://github.com/pytorch/pytorch/issues/46392 this makes the code more readable and possibly more performant.

It also fixes a bug detected by this where the argument order of `map` was confused: 030a24906e (diff-5bb26bd3a23ee3bb540aeadcc0385df2a4e48de39f87ed9ea76b21990738fe98L1537-R1537)

Fixes https://github.com/pytorch/pytorch/issues/46392

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46461

Reviewed By: ailzhang

Differential Revision: D24367015

Pulled By: ezyang

fbshipit-source-id: d55a67933cc22346b00544c9671f09982ad920e7
2020-10-19 18:42:49 -07:00
Hongfei XU
f02753fabb Support AMP in nn.parallel (#43102)
Summary:
Take care of the state of autocast in `parallel_apply`, so there is no need to decorate model implementations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43102

Reviewed By: ngimel

Differential Revision: D23294610

Pulled By: mrshenli

fbshipit-source-id: 0fbe0c79de976c88cadf2ceb3f2de99d9342d762
2020-08-25 08:38:49 -07:00
Jan Schlüter
0bc90194fb Catch and print exception traceback in parallel_apply() workers (#18055)
Summary:
When an exception occurs in one of the modules passed to `parallel_apply()`, it is caught and re-raised in the main thread. This preserves the original exception type and message, but has the traceback point at the position where it's re-raised, rather than the original point of failure.

This PR saves the exception information required to generate the traceback, and includes the original traceback in the message of the exception raised in the main thread.

Before:
```
  ...
  File ".../torch/nn/parallel/data_parallel.py", line 153, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File ".../torch/nn/parallel/parallel_apply.py", line 84, in parallel_apply
    raise output
RuntimeError: expected type torch.FloatTensor but got torch.cuda.FloatTensor
```

After:
```
  ...
  File ".../torch/nn/parallel/data_parallel.py", line 153, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File ".../torch/nn/parallel/parallel_apply.py", line 88, in parallel_apply
    ''.join(traceback.format_exception(*exc_info)))
RuntimeError: Caught exception in replica 0. Original traceback and message:
Traceback (most recent call last):
  ...
  File "../models/foo.py", line 319, in bar
    baz = asdf / ghij[:, np.newaxis]
RuntimeError: expected type torch.FloatTensor but got torch.cuda.FloatTensor
```

I took care to raise an exception of the original type (in case the main code checks for that), but replaced the message. It helped me find a bug that did not occur outside `data_parallel()`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18055

Differential Revision: D16444972

Pulled By: zhangguanheng66

fbshipit-source-id: ec436c9d4677fad18106a8046cfa835a20a101ce
2019-07-26 11:41:22 -07:00
Wei Yang
54107ae8cf convert output_device at data_parallel from torch.device to index (#10189)
Summary:
- fixes #9984
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10189

Differential Revision: D9545390

Pulled By: weiyangfb

fbshipit-source-id: 3a6a705437553ba319e9fd4b7f676ff73857a27e
2018-09-11 20:27:07 -07:00
Tongzhou Wang
35f08b930d
Allow parallel_apply to take in list[Tensor] (#8047) 2018-06-06 13:49:52 -04:00
Tongzhou Wang
1c01eabd3c
Codemod to update our codebase to 0.4 standard (#6641)
* Codemod to update our codebase to 0.4 standard

* Update some of the test scri[ts

* remove Variable in test_clip_grad_value

* fix _symbolic_override_wrapper_maker
2018-04-17 22:06:54 -04:00
Sam Gross
d605058212
Replace Variable.volatile with torch.no_grad() (#3970)
This removes volatile from Variable. The functionality is mostly
replaced by a global (thread-local) flag, which is controlled by
torch.set_grad_enabled() and the context manager torch.no_grad().

In C++, the flag is exposed through GradMode::is_enabled() and GradMode::set_enabled()

Fixes #3627
2017-12-18 15:46:13 -05:00
Adam Paszke
4af40e3471 Let parallel_apply accept arbitrary inputs 2017-07-20 01:45:57 -04:00
Christian Sarofeen
3748b6d3eb Data parallel fix for https://github.com/pytorch/pytorch/issues/1857 (#1880)
* Data parallel fix for https://github.com/pytorch/pytorch/issues/1857
searches recursively for variable in input

* parallel_apply.py lint
2017-07-05 11:46:00 -04:00
Nick Hynes
274b5c9003 Allow unhashable inputs to parallel_apply 2017-04-01 20:11:20 +02:00
Sam Gross
e50a1f19b3 Use streams in scatter to overlap copy with compute 2017-03-14 22:46:07 +01:00
Christian Sarofeen
b1ae7f90d5 Added functionality for data parallel table (#843) 2017-03-05 02:35:46 +01:00
Adam Paszke
876202503f Support multiple inputs in data parallel 2017-02-20 23:28:31 -08:00
Luke Yeager
e7c1e6a8e3 [pep8] Fix most lint automatically with autopep8
Here's the command I used to invoke autopep8 (in parallel!):

    git ls-files | grep '\.py$' | xargs -n1 -P`nproc` autopep8 -i

Several rules are ignored in setup.cfg. The goal is to let autopep8
handle everything which it can handle safely, and to disable any rules
which are tricky or controversial to address. We may want to come back
and re-enable some of these rules later, but I'm trying to make this
patch as safe as possible.

Also configures flake8 to match pep8's behavior.

Also configures TravisCI to check the whole project for lint.
2017-01-28 01:15:51 +01:00
Adam Paszke
80a827d3da Fix data_parallel bugs 2016-11-23 18:48:41 +01:00
Adam Paszke
3eac7164f4 Add data parallel functions to nn 2016-09-27 15:45:45 -07:00