Summary:
This PR allows Timer to collect deterministic instruction counts for (some) snippets. Because of the intrusive nature of Valgrind (effectively replacing the CPU with an emulated one) we have to perform our measurements in a separate process. This PR writes a `.py` file containing the Timer's `setup` and `stmt`, and executes it within a `valgrind` subprocess along with a plethora of checks and error handling. There is still a bit of jitter around the edges due to the Python glue that I'm using, but the PyTorch signal is quite good and thus this provides a low friction way of getting signal. I considered using JIT as an alternative, but:
A) Python specific overheads (e.g. parsing) are important
B) JIT might do rewrites which would complicate measurement.
Consider the following bit of code, related to https://github.com/pytorch/pytorch/issues/44484:
```
from torch.utils._benchmark import Timer
counts = Timer(
"x.backward()",
setup="x = torch.ones((1,)) + torch.ones((1,), requires_grad=True)"
).collect_callgrind()
for c, fn in counts[:20]:
print(f"{c:>12} {fn}")
```
```
812800 ???:_dl_update_slotinfo
355600 ???:update_get_addr
308300 work/Python/ceval.c:_PyEval_EvalFrameDefault'2
304800 ???:__tls_get_addr
196059 ???:_int_free
152400 ???:__tls_get_addr_slow
138400 build/../c10/core/ScalarType.h:c10::typeMetaToScalarType(caffe2::TypeMeta)
126526 work/Objects/dictobject.c:_PyDict_LoadGlobal
114268 ???:malloc
101400 work/Objects/unicodeobject.c:PyUnicode_FromFormatV
85900 work/Python/ceval.c:_PyEval_EvalFrameDefault
79946 work/Objects/typeobject.c:_PyType_Lookup
72000 build/../c10/core/Device.h:c10::Device::validate()
70000 /usr/include/c++/8/bits/stl_vector.h:std::vector<at::Tensor, std::allocator<at::Tensor> >::~vector()
66400 work/Objects/object.c:_PyObject_GenericGetAttrWithDict
63000 ???:pthread_mutex_lock
61200 work/Objects/dictobject.c:PyDict_GetItem
59800 ???:free
58400 work/Objects/tupleobject.c:tupledealloc
56707 work/Objects/dictobject.c:lookdict_unicode_nodummy
```
Moreover, if we backport this PR to 1.6 (just copy the `_benchmarks` folder) and load those counts as `counts_1_6`, then we can easily diff them:
```
print(f"Head instructions: {sum(c for c, _ in counts)}")
print(f"1.6 instructions: {sum(c for c, _ in counts_1_6)}")
count_dict = {fn: c for c, fn in counts}
for c, fn in counts_1_6:
_ = count_dict.setdefault(fn, 0)
count_dict[fn] -= c
count_diffs = sorted([(c, fn) for fn, c in count_dict.items()], reverse=True)
for c, fn in count_diffs[:15] + [["", "..."]] + count_diffs[-15:]:
print(f"{c:>8} {fn}")
```
```
Head instructions: 7609547
1.6 instructions: 6059648
169600 ???:_dl_update_slotinfo
101400 work/Objects/unicodeobject.c:PyUnicode_FromFormatV
74200 ???:update_get_addr
63600 ???:__tls_get_addr
46800 work/Python/ceval.c:_PyEval_EvalFrameDefault
33512 work/Objects/dictobject.c:_PyDict_LoadGlobal
31800 ???:__tls_get_addr_slow
31700 build/../aten/src/ATen/record_function.cpp:at::RecordFunction::RecordFunction(at::RecordScope)
28300 build/../torch/csrc/utils/python_arg_parser.cpp:torch::FunctionSignature::parse(_object*, _object*, _object*, _object**, bool)
27800 work/Objects/object.c:_PyObject_GenericGetAttrWithDict
27401 work/Objects/dictobject.c:lookdict_unicode_nodummy
24115 work/Objects/typeobject.c:_PyType_Lookup
24080 ???:_int_free
21700 work/Objects/dictobject.c:PyDict_GetItemWithError
20700 work/Objects/dictobject.c:PyDict_GetItem
...
-3200 build/../c10/util/SmallVector.h:at::TensorIterator::binary_op(at::Tensor&, at::Tensor const&, at::Tensor const&, bool)
-3400 build/../aten/src/ATen/native/TensorIterator.cpp:at::TensorIterator::resize_outputs(at::TensorIteratorConfig const&)
-3500 /usr/include/c++/8/x86_64-redhat-linux/bits/gthr-default.h:std::unique_lock<std::mutex>::unlock()
-3700 build/../torch/csrc/utils/python_arg_parser.cpp:torch::PythonArgParser::raw_parse(_object*, _object*, _object**)
-4207 work/Objects/obmalloc.c:PyMem_Calloc
-4500 /usr/include/c++/8/bits/stl_vector.h:std::vector<at::Tensor, std::allocator<at::Tensor> >::~vector()
-4800 build/../torch/csrc/autograd/generated/VariableType_2.cpp:torch::autograd::VariableType::add__Tensor(at::Tensor&, at::Tensor const&, c10::Scalar)
-5000 build/../c10/core/impl/LocalDispatchKeySet.cpp:c10::impl::ExcludeDispatchKeyGuard::ExcludeDispatchKeyGuard(c10::DispatchKey)
-5300 work/Objects/listobject.c:PyList_New
-5400 build/../torch/csrc/utils/python_arg_parser.cpp:torch::FunctionParameter::check(_object*, std::vector<pybind11::handle, std::allocator<pybind11::handle> >&)
-5600 /usr/include/c++/8/bits/std_mutex.h:std::unique_lock<std::mutex>::unlock()
-6231 work/Objects/obmalloc.c:PyMem_Free
-6300 work/Objects/listobject.c:list_repeat
-11200 work/Objects/listobject.c:list_dealloc
-28900 build/../torch/csrc/utils/python_arg_parser.cpp:torch::FunctionSignature::parse(_object*, _object*, _object**, bool)
```
Remaining TODOs:
* Include a timer in the generated script for cuda sync.
* Add valgrind to CircleCI machines and add a unit test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44717
Reviewed By: soumith
Differential Revision: D24010742
Pulled By: robieta
fbshipit-source-id: df6bc765f8efce7193893edba186cd62b4b23623
Summary:
This PR cleans up some of the rough edges around `Timer` and `Compare`
* Moves `Measurement` to be dataclass based
* Adds a bunch of type annotations. MyPy is now happy.
* Allows missing entries in `Compare`. This is one of the biggest usability issues with `Compare` right now, both from an API perspective and because the current failure mode is really unpleasant.
* Greatly expands the testing of `Compare`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45361
Test Plan: Changes to Timer are covered under existing tests, changes to `Compare` are covered by the expanded `test_compare` method.
Reviewed By: bwasti
Differential Revision: D23966816
Pulled By: robieta
fbshipit-source-id: 826969f73b42f72fa35f4de3c64d0988b61474cd
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45188
This is a symbolically traceable alternative to Python's `assert`.
It should be useful to allow people who want to use FX to also
be able to assert things.
A bunch of TODO(before) land are inline - would love thoughts
on where is the best place for this code to live, and what this
function should be called (since `assert` is reserved).
Test Plan:
```
python test/test_fx.py TestFX.test_symbolic_trace_assert
```
Imported from OSS
Reviewed By: jamesr66a
Differential Revision: D23861567
fbshipit-source-id: d9d6b9556140faccc0290eba1fabea401d7850de
Summary:
I noticed that the recently introduced adaptive_autorange tests occasionally timeout CI, and I've been meaning to improve the Timer tests for a while. This PR allows unit tests to swap the measurement portion of `Timer` with a deterministic mock so we can thoroughly test behavior without having to worry about flaky CI measurements. It also means that the tests can be much more detailed and still finish very quickly.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45173
Test Plan: You're lookin' at it.
Reviewed By: ezyang
Differential Revision: D23873548
Pulled By: robieta
fbshipit-source-id: 26113e5cea0cbf46909b9bf5e90c878c29e87e88
Summary:
Fixes https://github.com/pytorch/pytorch/issues/43622
- Moves the model loading part of `torch.hub.load()` into a new `torch.hub.load_local()` function that takes in a path to a local directory that contains a `hubconf.py` instead of a repo name.
- Refactors `torch.hub.load()` so that it now calls `torch.hub.load_local()` after downloading and extracting the repo.
- Updates `torch.hub` docs to include the new function + minor fixes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44204
Reviewed By: malfet
Differential Revision: D23817429
Pulled By: ailzhang
fbshipit-source-id: 788fd83c87a94f487b558715b2809d346ead02b2
Summary:
Fixes https://github.com/pytorch/pytorch/issues/44219
Rebasing https://github.com/pytorch/pytorch/pull/44288 and fixing the git history.
This allows users to bencmark code without having to specify how long to run the benchmark. It runs the benchmark until the variance (IQR / Median) is low enough that we can be confident in the measurement.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44607
Test Plan: There are unit tests, and we manually tested using Examples posted in git.
Reviewed By: robieta
Differential Revision: D23671208
Pulled By: bitfort
fbshipit-source-id: d63184290b88b26fb81c2452e1ae701c7d513d12
Summary:
Move the timing utils to `torch.utils._benchmark`. I couldn't figure out how to get setuptools to pick it up and put it under `torch` unless it is in the `torch` directory. (And I think it has to be for `setup.py develop` anyway.)
I also modified the record function benchmark since `Timer` and `Compare` should always be available now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41506
Reviewed By: ngimel
Differential Revision: D22601460
Pulled By: robieta
fbshipit-source-id: 9cea7ff1dcb0bb6922c15b99dd64833d9631c37b
Summary:
`HTTPError` are raised when server is overloaded, while `URLError` is
raised when network is not available
And since `HTTPError` is an extension of `URLError`, `URLError` should catch both exceptions
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39477
Differential Revision: D21873560
Pulled By: malfet
fbshipit-source-id: 11806671b768705465f562087521ad4887fd20f7
Summary:
Invoke `Popen.communicate` with `timeout` argument and kill the process in `TimeoutExpired` handler
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39191
Differential Revision: D21773510
Pulled By: malfet
fbshipit-source-id: 52b94315f8aa4d6c330dd5c9a8936100e49aef2d
Summary:
This updates assertEqual and assertEqual-like functions to either require both or neither of atol and rtol be specified. This should improve clarity around handling precision in the test suite, and it allows us to remove the legacy positional atol argument from assertEqual. In addition, the "message" kwarg is replace with a kwarg-only "msg" argument whose name is consistent with unittest's assertEqual argument.
In the future we could make "msg" an optional third positional argument to be more consistent with unittest's assertEqual, but requiring it be specified should be clear, and we can easily update the signature to make "msg" an optional positional argument in the future, too.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38872
Differential Revision: D21740237
Pulled By: mruberry
fbshipit-source-id: acbc027aa1d7877a49664d94db9a5fff91a07042
Summary:
Fixes https://github.com/pytorch/pytorch/issues/38401
* `torch.hub.load_state_dict_from_url()` now also downloads to `$TORCH_HOME/hub/checkpoints` instead of `$TORCH_HOME/checkpoints` like `torch.hub.load()` and others.
* Make `hub_dir` private, add and use `get_dir()` instead.
Also updated docs. Did not see a need for additional unit tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38969
Differential Revision: D21725880
Pulled By: ailzhang
fbshipit-source-id: 58cc6b32ddbda91e58c1c1433cc3916223556ea1
Summary:
This updates assertEqual and assertEqual-like functions to either require both or neither of atol and rtol be specified. This should improve clarity around handling precision in the test suite, and it allows us to remove the legacy positional atol argument from assertEqual. In addition, the "message" kwarg is replace with a kwarg-only "msg" argument whose name is consistent with unittest's assertEqual argument.
In the future we could make "msg" an optional third positional argument to be more consistent with unittest's assertEqual, but requiring it be specified should be clear, and we can easily update the signature to make "msg" an optional positional argument in the future, too.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38872
Differential Revision: D21717199
Pulled By: mruberry
fbshipit-source-id: 9feb856f94eee911b44f6c7140a1d07c1b026d3a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35615
Python 2 has reached end-of-life and is no longer supported by PyTorch.
Now we can clean up a lot of cruft that we put in place to support it.
These changes were all done manually, and I skipped anything that seemed
like it would take more than a few seconds, so I think it makes sense to
review it manually as well (though using side-by-side view and ignoring
whitespace change might be helpful).
Test Plan: CI
Differential Revision: D20842886
Pulled By: dreiss
fbshipit-source-id: 8cad4e87c45895e7ce3938a88e61157a79504aed
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34258
This PR allows both atol and rtol to be specified, uses defaults based on the prior analysis (spreadsheet attached to https://github.com/pytorch/pytorch/pull/32538), but retains the absolute tolerance behavior in cases where precision was previously specified explicitly.
Test Plan: Imported from OSS
Differential Revision: D21110255
Pulled By: nairbv
fbshipit-source-id: 57b3a004c7d5ac1be80ee765f03668b1b13f4a7e
Summary:
TensorBoard tests using SummaryWriter() may fail with a pandas import
complaint if TensorFlow packages are installed in the same python
environment as PyTorch:
Traceback (most recent call last):
File "test_tensorboard.py", line 212, in test_writer
with self.createSummaryWriter() as writer:
File "test_tensorboard.py", line 64, in createSummaryWriter
return SummaryWriter(temp_dir)
...
File "[...]/site-packages/pandas/core/arrays/categorical.py", line 52, in <module>
import pandas.core.algorithms as algorithms
AttributeError: module 'pandas' has no attribute 'core'
The exact failure may depend on the pandas version. We've also seen:
File "[...]/site-packages/pandas/core/arrays/categorical.py", line 9, in <module>
import pandas.compat as compat
AttributeError: module 'pandas' has no attribute 'compat'
The module import chain leading to the failure is tensorboard imports
tensorflow imports tensorflow_estimator imports pandas. pandas includes
a submodule named 'bottleneck', whose name collides with the PyTorch
'test/bottleneck/' subdirectory.
So IF tensorboard, tensorflow, tensorflow_estimator, and pandas are
installed in the python environment AND IF testing is run from within
PyTorch's 'test/' directory (or maybe just with 'test/' in PYTHONPATH,
etc.), then TensorBoard tests using SummaryWriter() will fail.
Rename the 'bottleneck/' directory slightly to avoid the name collision.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29650
Differential Revision: D19698638
Pulled By: ezyang
fbshipit-source-id: cb59342ed407cb37aefc833d67f768a8809129ac
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30445
Create distributed and rpc directories under caffe/test for better management
of unit tests.
Differential Revision: D18702786
fbshipit-source-id: e9daeed0cfb846ef68806f6decfcb57c0e0e3606
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31230
A major issue with distributed autograd currently is that we block an
RPC thread when we call Engine::execute_with_graph_task.
To resolve this issue, I've made modifications to the local autograd engine
such that `execute_with_graph_task` returns a Future instead. The `execute()`
methods for Engine::execute() and DistEngine::execute() still wait() on this
Future which ensures there is no change in behavior yet.
In follow up PRs we can modify the distributed autograd engine to take
advantage of this Future.
Closes#26359
ghstack-source-id: 96298057
Test Plan: waitforbuildbot
Differential Revision: D18999709
fbshipit-source-id: 388f54467fd2415a0acb7df17bd063aedc105229
Summary:
To support variadic inputs of `checkpoint_sequential` was deprecated at https://github.com/pytorch/pytorch/issues/21006. This case should be warned with `DeprecationWarning` for PyTorch 1.2, but it should be simply failed with `TypeError` since PyTorch 1.3. This patch removes the `DeprecationWarning` for PyTorch 1.2.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25985
Differential Revision: D18809875
Pulled By: albanD
fbshipit-source-id: e84dd8629c04979c4b2dc63e8ada94292e8cedd0
Summary:
Resubmit of https://github.com/pytorch/pytorch/pull/25980.
Our old serialization was in tar (like `resnet18-5c106cde.pth` was in this format) so let's only support automatically unzip if checkpoints are zipfiles.
We can still manage to get it work with tarfile, but let's delay it when there's an ask.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26723
Differential Revision: D17551795
Pulled By: ailzhang
fbshipit-source-id: 00b4e7621f1e753ca9aa07b1fe356278c6693a1e
Summary:
This PR does a few small improvements to hub:
- add support `verbose` option in `torch.load`. Note that this mutes hitting cache message but keeps the message of first download as suggested. fixes https://github.com/pytorch/pytorch/issues/24791
- add support loading state dict from tar file or zip file in `torch.hub.load_state_dict_from_url`.
- add `torch.hub.download_url_to_file` as public API, and add BC bit for `_download_url_to_file`.
- makes hash check in filename optional through `check_hash`, many users don't have control over the naming, relaxing this constraint could potentially avoid duplicating download code on user end.
- move pytorch CI off `pytorch/vision` and use `ailzhang/torchhub_example` as a dedicated test repo. fixes https://github.com/pytorch/pytorch/issues/25865
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25980
Differential Revision: D17495679
Pulled By: ailzhang
fbshipit-source-id: 695df3e803ad5f9ca33cfbcf62f1a4f8cde0dbbe
Summary:
This should pass once https://github.com/pytorch/vision/pull/971 is merged.
To remove torchvision as baseline, we just compare to sum of all param.sum() in pretrained resnet18 model, which means we need to manually update the number only when that pretrained weights are changed, which is generally rare.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21132
Differential Revision: D15563078
Pulled By: ailzhang
fbshipit-source-id: f28c6874149a1e6bd9894402f6847fd18f38b2b7
Summary:
I've reported inconsistency between `checkpoint_sequential` and `nn.Sequential` at https://github.com/pytorch/pytorch/issues/19260. Both should provide the same input signature but they don't. I think the consistency is important and I agree with apaszke that `nn.Sequential`'s semantics should be kept instead of `checkpoint_sequential`.
I hope `checkpoint_sequential` raises `TypeError` on variadic arguments since PyTorch 1.2.0. But for now, it's okay just to warn as `DeprecationWarning`. I've talked about this approach with soumith.
Please review this pull request. Any comment will be my pleasure.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21006
Differential Revision: D15530801
Pulled By: soumith
fbshipit-source-id: 0ceb2cc6a17dcc547d0d00ebaf9df8603be53183
Summary:
A few improvements while doing bert model
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19247
Differential Revision: D14989345
Pulled By: ailzhang
fbshipit-source-id: f4846813f62b6d497fbe74e8552c9714bd8dc3c7
Summary:
* `torch.hub.list('pytorch/vision')` - show all available hub models in `pytorch/vision`
* `torch.hub.show('pytorch/vision', 'resnet18')` - show docstring & example for `resnet18` in `pytorch/vision`
* Moved `torch.utils.model_zoo.load_url` to `torch.hub.load_state_dict_from_url` and deprecate `torch.utils.model_zoo`
* We have too many env to control where the cache dir is, it's not very necessary. I actually want to unify `TORCH_HUB_DIR`, `TORCH_HOME` and `TORCH_MODEL_ZOO`, but haven't done it. (more suggestions are welcome!)
* Simplify `pytorch/vision` example in doc, it was used to show how how hub entrypoint can be written so had some confusing unnecessary args.
An example of hub usage is shown below
```
In [1]: import torch
In [2]: torch.hub.list('pytorch/vision', force_reload=True)
Downloading: "https://github.com/pytorch/vision/archive/master.zip" to /private/home/ailzhang/.torch/hub/master.zip
Out[2]: ['resnet18', 'resnet50']
In [3]: torch.hub.show('pytorch/vision', 'resnet18')
Using cache found in /private/home/ailzhang/.torch/hub/vision_master
Resnet18 model
pretrained (bool): a recommended kwargs for all entrypoints
args & kwargs are arguments for the function
In [4]: model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)
Using cache found in /private/home/ailzhang/.torch/hub/vision_master
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18758
Differential Revision: D14883651
Pulled By: ailzhang
fbshipit-source-id: 6db6ab708a74121782a9154c44b0e190b23e8309
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18598
ghimport-source-id: c74597e5e7437e94a43c163cee0639b20d0d0c6a
Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18598 Turn on F401: Unused import warning.**
This was requested by someone at Facebook; this lint is turned
on for Facebook by default. "Sure, why not."
I had to noqa a number of imports in __init__. Hypothetically
we're supposed to use __all__ in this case, but I was too lazy
to fix it. Left for future work.
Be careful! flake8-2 and flake8-3 behave differently with
respect to import resolution for # type: comments. flake8-3 will
report an import unused; flake8-2 will not. For now, I just
noqa'd all these sites.
All the changes were done by hand.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: D14687478
fbshipit-source-id: 30d532381e914091aadfa0d2a5a89404819663e3
Summary:
Currently, we cannot run a checkpointed function with None argument.
```python
out = torch.utils.checkpoint.checkpoint(run_fn, input_var, None)
```
```
File "/home/tunz/anaconda3/envs/torchdev/lib/python3.7/site-packages/torch/utils/checkpoint.py", line 14, in detach_variable
x = inp.detach()
AttributeError: 'NoneType' object has no attribute 'detach'
```
This PR makes checkpoint function to safely handle None argument.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17969
Differential Revision: D14475148
Pulled By: ezyang
fbshipit-source-id: 9afe9e9aac511a6df1e1620e9ac341536890d451
Summary:
This is the first round of enabling unit tests that work on ROCm 2.1 in my tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16871
Differential Revision: D13997662
Pulled By: bddppq
fbshipit-source-id: d909a3f7dd5fc8f85f126bf0613751c8e4ef949f
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14278
In this commit, we make checkpoint_sequential work for models with multiple tensor inputs. Previously, it only processed the first tensor and ignored the rest.
We introduce a new test in test/test_utils.py that replicates the issue referenced in this [GitHub issue](https://github.com/pytorch/pytorch/issues/11093), and we make sure that the test passes by changing the behavior of checkpoint_sequential to process all input tensors.
Reviewed By: ezyang
Differential Revision: D13144672
fbshipit-source-id: 24f58233a65a0f5b80b89c8d8cbced6f814004f7
Summary:
This issue was noticed, and fix proposed, by raulpuric.
Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward. This can result in the RNG state advancing more than it would without checkpointing, which can cause checkpoints that include dropout invocations to lose end-to-end bitwise accuracy as compared to non-checkpointed passes.
The present PR contains optional logic to juggle the RNG states such that checkpointed passes containing dropout achieve bitwise accuracy with non-checkpointed equivalents.** The user requests this behavior by supplying `preserve_rng_state=True` to `torch.utils.checkpoint` or `torch.utils.checkpoint_sequential`.
Currently, `preserve_rng_state=True` may incur a moderate performance hit because restoring MTGP states can be expensive. However, restoring Philox states is dirt cheap, so syed-ahmed's [RNG refactor](https://github.com/pytorch/pytorch/pull/13070#discussion_r235179882), once merged, will make this option more or less free.
I'm a little wary of the [def checkpoint(function, *args, preserve_rng_state=False):](https://github.com/pytorch/pytorch/pull/14253/files#diff-58da227fc9b1d56752b7dfad90428fe0R75) argument-passing method (specifically, putting a kwarg after a variable argument list). Python 3 seems happy with it.
Edit: It appears Python 2.7 is NOT happy with a [kwarg after *args](https://travis-ci.org/pytorch/pytorch/builds/457706518?utm_source=github_status&utm_medium=notification). `preserve_rng_state` also needs to be communicated in a way that doesn't break any existing usage. I'm open to suggestions (a global flag perhaps)?
**Batchnorm may still be an issue, but that's a battle for another day.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14253
Differential Revision: D13166665
Pulled By: soumith
fbshipit-source-id: 240cddab57ceaccba038b0276151342344eeecd7