Commit Graph

14 Commits

Author SHA1 Message Date
Brian Wignall
f326045b37 Fix typos, via a Levenshtein-type corrector (#31523)
Summary:
Should be non-semantic.

Uses https://en.wikipedia.org/wiki/Wikipedia:Lists_of_common_misspellings/For_machines to find likely typos, with https://github.com/bwignall/typochecker to help automate the checking.

Uses an updated version of the tool used in https://github.com/pytorch/pytorch/pull/30606 .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31523

Differential Revision: D19216749

Pulled By: mrshenli

fbshipit-source-id: 7fd489cb9a77cd7e4950c1046f925d57524960ea
2020-01-17 16:03:19 -08:00
Heungsub Hans Lee
fa251cfd97 Fully deprecate variadic inputs of checkpoint_sequential (#25985)
Summary:
Support for variadic inputs to `checkpoint_sequential` was deprecated in https://github.com/pytorch/pytorch/issues/21006. That case was meant to emit a `DeprecationWarning` in PyTorch 1.2 and to fail outright with a `TypeError` from PyTorch 1.3 on. This patch removes the PyTorch 1.2 `DeprecationWarning` path.
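
For illustration, a minimal sketch of the calling convention after this change (the module, shapes, and segment count are arbitrary examples):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

# Supported: a single input, mirroring nn.Sequential's calling convention.
out = checkpoint_sequential(model, 2, x)

# No longer accepted: variadic inputs. PyTorch 1.2 only emitted a
# DeprecationWarning here; after this change the call is expected to fail
# with a TypeError.
# checkpoint_sequential(model, 2, x, torch.randn(4, 16))
```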
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25985

Differential Revision: D18809875

Pulled By: albanD

fbshipit-source-id: e84dd8629c04979c4b2dc63e8ada94292e8cedd0
2019-12-05 09:23:28 -08:00
Brian Wignall
e7fe64f6a6 Fix typos (#30606)
Summary:
Should be non-semantic.

Uses https://en.wikipedia.org/wiki/Wikipedia:Lists_of_common_misspellings/For_machines to find likely typos.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30606

Differential Revision: D18763028

Pulled By: mrshenli

fbshipit-source-id: 896515a2156d062653408852e6c04b429fc5955c
2019-12-02 20:17:42 -08:00
Hans Lee
ffdce79078 Deprecate variadic inputs of checkpoint_sequential (#21006)
Summary:
I've reported an inconsistency between `checkpoint_sequential` and `nn.Sequential` at https://github.com/pytorch/pytorch/issues/19260. Both should provide the same input signature, but they don't. I think this consistency is important, and I agree with apaszke that `nn.Sequential`'s semantics should be kept rather than `checkpoint_sequential`'s.

I hope `checkpoint_sequential` will raise a `TypeError` on variadic arguments starting with PyTorch 1.2.0. But for now, it's okay to just warn with a `DeprecationWarning`. I've discussed this approach with soumith.

Please review this pull request. Any comments are welcome.
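
To make the intended consistency concrete, a rough sketch (module and shapes are arbitrary examples):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

layers = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
x = torch.randn(2, 8, requires_grad=True)

# nn.Sequential's forward takes exactly one input...
y_ref = layers(x)

# ...so checkpoint_sequential should keep the same convention: one input plus
# the number of checkpointed segments. For now, variadic inputs only warn.
y_ckpt = checkpoint_sequential(layers, 2, x)

assert torch.allclose(y_ref, y_ckpt)
```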
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21006

Differential Revision: D15530801

Pulled By: soumith

fbshipit-source-id: 0ceb2cc6a17dcc547d0d00ebaf9df8603be53183
2019-05-28 21:33:45 -07:00
Choongwoo Han
40074d647c Allow None for checkpoint (#17969)
Summary:
Currently, we cannot run a checkpointed function with a None argument.

```python
out = torch.utils.checkpoint.checkpoint(run_fn, input_var, None)
```

```
  File "/home/tunz/anaconda3/envs/torchdev/lib/python3.7/site-packages/torch/utils/checkpoint.py", line 14, in detach_variable
    x = inp.detach()
AttributeError: 'NoneType' object has no attribute 'detach'
```

This PR makes the checkpoint function handle None arguments safely.
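
A minimal sketch of the behavior this enables (the function and shapes are just illustrative):

```python
import torch
from torch.utils.checkpoint import checkpoint

def run_fn(x, flag):
    # `flag` may be None; None (and other non-tensor) arguments should be
    # passed through untouched instead of being detach()-ed.
    return x * 2 if flag is None else x * flag

x = torch.randn(3, requires_grad=True)
out = checkpoint(run_fn, x, None)  # previously died with the AttributeError above
out.sum().backward()
```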
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17969

Differential Revision: D14475148

Pulled By: ezyang

fbshipit-source-id: 9afe9e9aac511a6df1e1620e9ac341536890d451
2019-03-15 07:38:41 -07:00
Michael Carilli
5d3a347685 Stashing checkpointing RNG states based on devices of arg tensors (#14518)
Summary:
This PR intends to address apaszke's concerns in https://github.com/pytorch/pytorch/pull/14253#issuecomment-441740016.  Preserving the RNG state is now controlled by a kwarg rather than by global state, hopefully in a Python 2.7-compatible way.

Additionally, the checkpointing function stashes and restores the RNG states of
1. devices associated with all input tensor args to run_fn as well as
2. the current device.

I could easily change this to save and restore only the RNG states associated with 1. alone.  This would simplify the logic that creates a [deduplicated, ordered](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R37) list of devices considered active.

I'm wondering if the [get_device_states](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R32) and [set_device_states](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R47) functions are general enough to reside elsewhere (presumably torch/random.py).  I'm also wondering if the check on [torch.cuda._initialized](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R47) would be better placed within `get_device_states`.
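
For orientation, a rough sketch of what the per-device stashing described above could look like; this is illustrative only, and the merged `get_device_states`/`set_device_states` helpers may differ in detail:

```python
import torch

def get_device_states(*args):
    # Deduplicated, ordered list of CUDA devices appearing among the tensor
    # arguments, plus a snapshot of each device's RNG state.
    devices = sorted({arg.get_device() for arg in args
                      if torch.is_tensor(arg) and arg.is_cuda})
    states = []
    for device in devices:
        with torch.cuda.device(device):
            states.append(torch.cuda.get_rng_state())
    return devices, states

def set_device_states(devices, states):
    # Restore each snapshot on its device before re-running run_fn in backward.
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)
```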
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14518

Differential Revision: D13356210

Pulled By: ezyang

fbshipit-source-id: afa4cc21ce7862142d5cb1dec3750018df222039
2018-12-11 09:48:45 -08:00
Andy Chen
33ea7eafef Make checkpoint_sequential work with multiple arguments (#14278)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14278

In this commit, we make checkpoint_sequential work for models with multiple tensor inputs. Previously, it only processed the first tensor and ignored the rest.

We introduce a new test in test/test_utils.py that replicates the issue referenced in this [GitHub issue](https://github.com/pytorch/pytorch/issues/11093), and we make sure that the test passes by changing the behavior of checkpoint_sequential to process all input tensors.
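
For context, the multi-input pattern the new test exercises looks roughly like this (shown with `torch.utils.checkpoint.checkpoint`, which takes the same variadic form; note that the variadic form of checkpoint_sequential is later deprecated, see the commits above):

```python
import torch
from torch.utils.checkpoint import checkpoint

def run_fn(a, b):
    return a * b + a

a = torch.randn(4, 8, requires_grad=True)
b = torch.randn(4, 8, requires_grad=True)

# Every tensor input takes part in the recomputation, and every one of them
# receives a gradient, rather than only the first being processed.
out = checkpoint(run_fn, a, b)
out.sum().backward()
assert a.grad is not None and b.grad is not None
```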

Reviewed By: ezyang

Differential Revision: D13144672

fbshipit-source-id: 24f58233a65a0f5b80b89c8d8cbced6f814004f7
2018-12-04 18:47:43 -08:00
Michael Carilli
c36156eded Option to preserve bitwise accuracy of gradient checkpointed vs non-checkpointed dropout (#14253)
Summary:
This issue was noticed, and fix proposed, by raulpuric.

Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward.  This can result in the RNG state advancing more than it would without checkpointing, which can cause checkpoints that include dropout invocations to lose end-to-end bitwise accuracy as compared to non-checkpointed passes.

The present PR contains optional logic to juggle the RNG states such that checkpointed passes containing dropout achieve bitwise accuracy with non-checkpointed equivalents.**  The user requests this behavior by supplying `preserve_rng_state=True` to `torch.utils.checkpoint` or `torch.utils.checkpoint_sequential`.

Currently, `preserve_rng_state=True` may incur a moderate performance hit because restoring MTGP states can be expensive.  However, restoring Philox states is dirt cheap, so syed-ahmed's [RNG refactor](https://github.com/pytorch/pytorch/pull/13070#discussion_r235179882), once merged, will make this option more or less free.

I'm a little wary of the [def checkpoint(function, *args, preserve_rng_state=False):](https://github.com/pytorch/pytorch/pull/14253/files#diff-58da227fc9b1d56752b7dfad90428fe0R75) argument-passing method (specifically, putting a kwarg after a variable argument list).  Python 3 seems happy with it.
Edit: It appears Python 2.7 is NOT happy with a [kwarg after *args](https://travis-ci.org/pytorch/pytorch/builds/457706518?utm_source=github_status&utm_medium=notification).  `preserve_rng_state` also needs to be communicated in a way that doesn't break any existing usage.  I'm open to suggestions (a global flag, perhaps?).

**Batchnorm may still be an issue, but that's a battle for another day.
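
A small sketch of the intended effect (CPU dropout for simplicity; the seed, module, and shapes are arbitrary):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

drop = nn.Dropout(p=0.5)
x_ref = torch.randn(8, 8).requires_grad_()
x_ckpt = x_ref.detach().clone().requires_grad_()

# Same RNG state before each forward so the two runs are directly comparable.
torch.manual_seed(0)
out_ref = drop(x_ref)
out_ref.sum().backward()

torch.manual_seed(0)
out_ckpt = checkpoint(drop, x_ckpt, preserve_rng_state=True)
out_ckpt.sum().backward()

# The recomputed forward in backward() reuses the stashed RNG state, so the
# dropout mask, and therefore the gradients, match the uncheckpointed run.
assert torch.equal(out_ref, out_ckpt)
assert torch.equal(x_ref.grad, x_ckpt.grad)
```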
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14253

Differential Revision: D13166665

Pulled By: soumith

fbshipit-source-id: 240cddab57ceaccba038b0276151342344eeecd7
2018-11-23 08:09:43 -08:00
Yangqing Jia
c47f680086 arc lint torch/utils (#13141)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13141

This is an example diff to show what lint rules are being applied.

Reviewed By: mingzhe09088

Differential Revision: D10858478

fbshipit-source-id: cbeb013f10f755b0095478adf79366e7cf7836ff
2018-10-25 14:59:03 -07:00
Thomas Viehmann
3799b10c44 various documentation formatting (#9359)
Summary:
This is a grab-bag of documentation formatting fixes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9359

Differential Revision: D8831400

Pulled By: soumith

fbshipit-source-id: 8dac02303168b2ea365e23938ee528d8e8c9f9b7
2018-07-13 02:48:25 -07:00
vfdev
2dc177ac50 Update checkpoint.py (#6943) 2018-04-25 08:43:58 -04:00
Priya Goyal
7d32f6fdc3 Adding runtime warning for checkpointing inputs to have requires_grad=True (#6883)
* Add a runtime warning suggesting that inputs to the checkpointed function should have requires_grad=True (see the sketch below)

* fix bug
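
For context, a rough sketch of the situation the new warning covers (module and shapes are arbitrary):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(8, 8)
x = torch.randn(2, 8)  # note: requires_grad is False

# With no input requiring grad, the checkpointed output is detached from the
# autograd graph, so backward never re-runs the segment and gradients for the
# block's parameters are silently skipped; the runtime warning flags this case.
out = checkpoint(layer, x)
```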
2018-04-23 22:43:35 -04:00
Tongzhou Wang
c2187790e3 Improve utils.checkpoint docs (#6526)
* improve util.checkpoint docs

* change volatile to no_grad, and add more explanation

* address comments
2018-04-12 16:59:06 -04:00
Priya Goyal
e3196e0ea8
[Re-checkpointing] Autograd container for trading compute for memory (#6467)
* Autograd container for trading compute for memory (basic usage sketched after this list)

* add a unit test for checkpoint

* address comments

* address review comments

* adding some docs for the checkpoint api

* more comments

* more comments

* repro bug

* Fix a subtle bug/apply some review comments

* Update checkpoint.py

* Run everything in grad mode

* fix flake and chunk=1

* use imperative backward as per discussion

* remove Variable and also add models and test for models

* Add a simple thread local variable to check for autograd grad mode

* remove models and models test after debugging

* address review comments

* address more comments

* address more comments
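
For readers landing here, a minimal usage sketch of the API introduced by this commit (module and shapes are arbitrary examples):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# A block whose intermediate activations we would rather recompute during
# backward() than keep alive for the whole forward pass.
block = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
x = torch.randn(32, 64, requires_grad=True)

# Forward runs the block without storing intermediates; backward re-runs it
# to rebuild them, trading extra compute for a smaller memory footprint.
out = checkpoint(block, x)
out.sum().backward()
```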
2018-04-10 15:26:24 -04:00