Summary:
# Goals
Do the following things during a distributed backward pass.
1. Accumulate the gradient of a variable to RPC context once the gradient is ready instead of at the very end of the backward pass.
2. Run post/pre hooks installed in`AccumulateGrad` nodes once the gradient is ready for the variable. Currently, the hooks in `AccumulateGrad` are not executed just because the function `AccumulateGrad` itself is not even evaluated by the local engine.
3. Make it extensible to support post hooks installed by DDP's reducer.
# Introduce GradCapturePreHook
## Why do we need this?
### Root issue:
* dist engine uses the autograd.grad-like API on the vanilla engine and then in the Future callback populates the context with the gradients. This is a bad emulation of the .backward() call on the vanilla engine.
### Practical issue:
* The leaf’s hook are not called (because associated with the AccumulateGrad that is not call in the autograd.grad-like API). Modules like DDP rely on these hooks.
* The Future is marked as completed before the context is actually populated with the grads leading to unexpected behavior on the user side.
* The Future callback is only called at the complete end of the backward and so too late for DDP if they want to overlap compute/transfert.
### Proposed solution:
* Provide hooks in the autograd.grad-like API that will allow the distributed engine to populate the context and call the hooks to better emulate the .backward call.
## Who can install a grad capture pre-hook?
This will be an internal hook at C++ level and it won’t be exposed to PyThon code. Only call-sites directly interacting with the local engine can install such hooks.
## Signature
The returned `grad` will be captured.
```
virtual const torch::Tensor& grad operator()(const torch::Tensor& grads) = 0;
```
## Where are hooks installed?
Grad capture pre-hooks are install in GraphTask::ExecInfo::Capture. ExecInfo is per node. Every backward run will have its own GraphTask instance.
## When/How will hooks be called?
When the local engine captures the grads for a node, all grad capture pre hooks are called one by one in the order they are added. The output grads of the hooks will replace the original grads.
The output of the last hook will be used for grad capturing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34501
Test Plan:
All existing tests should pass.
```
python setup.py develop
python test/distributed/rpc/test_dist_autograd_spawn.py DistAutogradTestWithSpawn.test_post_hooks
```
Differential Revision: D20953673
Pulled By: hczhu
fbshipit-source-id: 543b3844823330ea9f9856bab7c5cb2679290a53
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36856
Previously, we could early-exit mark_graph_task_completed() without the future
actually being fully complete - we were only guaranteeing that it was at least
in the process of being marked complete.
This seems to be triggering an assert graph_task->future_result_->completed()
This change simply adds a 1-line waitNoThrow() call to ensure that the future
has been marked complete before exiting the mark_graph_task_completed() function.
The cost is relatively reasonable, since this isn't the common path.
ghstack-source-id: 102423589
Test Plan: buck test mode/dev-nosan caffe2/test/,,,
Differential Revision: D21104121
fbshipit-source-id: 51c1554618880fe80d52d5eb96716abc15f6be8a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36745
As we hold a mutex for our custom C++ Node, when calling reentrant
backward from custom C++ function, we will cocurrently holding many
mutexes up to MAX_DEPTH. TSAN only allow 65 mutexes at once, otherwise
it will complain. This PR lower the limit according to TSAN.
TSAN Reference: https://github.com/google/sanitizers/issues/950
Test Plan: Imported from OSS
Differential Revision: D21072604
Pulled By: wanchaol
fbshipit-source-id: 99cd1acab41a203d834fa4947f4e6f0ffd2e70f2
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36640
We had the following race when two threads entered
'mark_graph_task_completed'.
1) Thread 1 grabs the graph_task mutex first and moves captured_vars_ to its
local 'vars'.
2) Thread 1 releases the lock.
3) Thread 2 grabs the mutex and moves an empty captured_vars_ to its local
'vars'.
4) Thread 2 now proceeds to call 'markCompleted' with empty grads.
5) Thread 1 which actually has the right grads never gets to set the grads on
the future since future_completed_ is set to True by Thread 2.
Discovered this while running our RNN example:
https://github.com/pytorch/examples/tree/master/distributed/rpc/rnn and
verified this PR fixes the race.
ghstack-source-id: 102237850
Test Plan: waitforbuildbot
Differential Revision: D21035196
fbshipit-source-id: 1963826194d466b93f19e8016b38e4f9cad47720
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35101
TSAN is noting lock-order-inversion in context of dist autograd because
we're holding lock when GraphTask calls markCompleted() on the relevant futureResult_.
Add an atomic bool to make it possible to protect this without holding the mutex,
and also fix alignment of a few struct vars.
ghstack-source-id: 101805283
Test Plan: buck test mode/opt-tsan //caffe2/test/distributed/rpc:dist_autograd_spawn_thrift
Differential Revision: D20553517
fbshipit-source-id: 446e3718dd68876bd312166ecceed1d92868ce4e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35523
In this PR we extend ThreadLocalState to cover dispatch keys and
ThreadLocalDebugInfo and move it from JIT interpreter down to
thread management (at::launch) and autograd (backward threads) code
Test Plan: unit tests (CI)
Reviewed By: dzhulgakov
Differential Revision: D20615714
fbshipit-source-id: 16a9fc96a25cb6c2629230b1187fbf78786ac565
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35599
We don't check if the ready queue was empty before
https://github.com/pytorch/pytorch/pull/33157 because the CPU worker's
queue might not be empty, but after #33157, we try to check if the owner
thread's ready_queue empty after inline exeuction.
This might not always hold true, imagine the following case:
The CPU thread that calls backward() and the GPU device thread, the Graph is like:
GraphRoot(CPU) -> ComputeNode(GPU)
in both thread_main, they are decrementing `--local_graph_task->outstanding_tasks_` to zero together, and then both thread will enter `if (graph_task_completed(local_graph_task))`, CPU thread will break out and finish and check if local_ready_queue is empty, the GPU thread will send a dummy task to CPU thread ready queue as it think the graph_task finished on its own thread (it actually finished on both threads together). So there will be cases that there's a dummy task remains in the queue.
This happens very rare and non-deterministic, but it might get triggered when we run many jobs in the CI. Remove the check to fix the flakiness
Test Plan: Imported from OSS
Differential Revision: D20739778
Pulled By: wanchaol
fbshipit-source-id: 75a671762650a188f44720625d53f0873617c684
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33157
This PR enables graph level thread parallelism on CPU for the Autograd
Engine. It replace https://github.com/pytorch/pytorch/pull/29574 for the
reason of task level parallelism drawbacks with the existing autograd
system.
Fixes https://github.com/pytorch/pytorch/issues/18333
The graph level parallelism on CPU design:
1. Remove the single CPU thread that init in the Engine itself and allow
the owning thread (which calls Engine::execute) to drive the Engine
execution so that we could let outer threading to enable thread
parallelism.
2. Maintain a separate ReadyQueue per CPU thread, and stash the
ReadyQueue for different devices/threads into the thread local
shared_ptr, the Engine itself will memorize the shared_ptr of the
ReadyQueue to different devices (other than CPU)
3. The CPU thread local ReadyQueue is initialized per CPU thread
Engine::execute call (or `backward()`, `grad()` call), and memorized
the shared_ptr into the GraphTask since every `backward()` call have
its own GraphTask
4. Cross device NodeTask push is accomplished by 2 and 3. we can refer
to device's ReadyQueue from Engine, and CPU's ReadyQueue from
GraphTask, which means if we can push to a different ReadyQueue
according to the device
5. Termination of the CPU thread: if we mark the graph_task as
completed, we will exit the while loop and terminate the current
backward execution, because it's guranteed that all other NodeTasks
is finished before we mark a GraphTask as complete
6. re-entrant thread logic keeps the same, reentrant thread detection is
similar as before, we set the worker_device to NO_DEVICE initially
and set to CPU afterward to detect if this is a reentrant call or not.
7. we still have the reentrant thread pool that create new threads if it's
a deep reentrant case, and reuse the ReadyQueue with the parent thread
for performance.
Since we introduce the thread parallelism on CPU, we have to ensure the
thread safety of the GraphTask. This is not a problem if we execute all
forward in different threads since we will build separate GraphTask in
different threads, and each GraphTask is a separate instance that share
nothing, i.e. Hogwild training on CPU should be fine on this case.
But there might be case that user would like to do some part of the task in
a single thread, and do the rest of work in several threads
concurrently, so thread safety is crucial in those cases. The thread
safety strategy for the multithread autograd is as follows:
1. Add a mutex to protect thread safety in Autograd Node/Function, and
hold the lock for different data racing cases
2. Lock the mutex during Node::apply(), this is to ensure Node that
writing to the shared variable are not racing across threads (i.e.
AccumulateGrad and custom C++ Autograd Node if writing to shared
variables )
3. Lock the mutex during Node::release_variables(), this serve the
purpose that when we release saved_variables from one thread, no
other threads can call the Node::apply(), this ensures the variable
references from other threads aren't dangling.
4. If we don't release any variables and no shared data read/write in
the Node i.e. purely functional, we don't lock the mutex
This way we could protect the thread safety on Autograd Node, but we
could still not protect the thread safety on Node pre/post C++ hooks
(python hooks are automatically thread safe), we rely on the user to
write thread safe C++ hooks if they want the hook to be correctly
applied in multithreading environment.
**User visiable changes**:
There're not too much user visiable changes, since we use the owning
thread to drive the autograd execution, user could write their own
threading code and does not block on the Autograd engine, some behaviors
that user should be aware of:
**Non-determinism**:
if we are calling backward() on multiple thread concurrently but with
shared inputs (i.e. Hogwild CPU training). Since parameters are automatically shared across threads, gradient accumulation might become non-deterministic on backward calls across threads, because two backward calls might access and try to accumulate the same .grad attribute. This is technically not safe, and it might result in racing condition and the result might be invalid to use.
But this is expected pattern if user are using the multithreading
approach to drive the whole training process but using shared
parameters, user who use multithreading should have the threading model
in mind and should expect this to happen. User should use the functional
interface `torch.autograd.grad()` to calculate the gradients instead of
`backward()` on loss.
**Graph retaining**:
If part of the autograd graph is shared between threads, i.e. run first
part of forward single thread, then run second part in multiple threads,
then the first part of graph is shared. In this case different threads execute grad() or backward() on the same graph might
have issue of destroying the graph on the fly of one thread, and the
other thread will crash in this case. We will error out to the user
similar to what call `backward()` twice with out `retain_graph=True`, and let the user know they should use `retain_graph=True`.
**TODOs**:
[ ] benchmark the PR with example models and datasets to demonstrate
the performance gain in CPU training
[ ] ensure that we don't regress the single thread autograd performance
**Follow ups**:
[ ] a correct and tight integration with distributed autograd
[ ] try to unify the thread pool between JIT and Autograd, and see if
there's unifying pattern that we could apply universally
Test Plan: Imported from OSS
Differential Revision: D20236771
Pulled By: wanchaol
fbshipit-source-id: 1e0bd4eec14ffebeffdb60b763b8d6f0e427eb64
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35066Closes#24965
Prior to this commit, final_callbacks_ are cleared on exit of ANY
backward. When using reentrant backward, the last backward would
remove all callbacks from the engine. However, this might lead to
unexpected behavior. For example, the application could install
a final callback after forward, and expecting this callback to fire
when all gradients are ready. If there is a renentrant backward on
a subgraph, it would fire the callback and delete it on exit,
meaning that when fired, not all gradients are ready.
**Failed Attempt**
The 1st attempt was trying to move the callback to the GraphTask
in engine::execute(). However, this failed because more callbacks
could be installed during backward pass.
**Current Solution**
Final callbacks are stored as a member variable in the GraphTask.
* Insertion: use the thread_local current_graph_task to find the
target GraphTask, and append final callback.
* Deletion: final callbacks have the same lifetime as a GraphTask
* Execution: Use the GraphTask provided in the argument to find
final callbacks.
Test Plan: Imported from OSS
Differential Revision: D20546474
Pulled By: mrshenli
fbshipit-source-id: d3f3449bb5af9f8703bcae63e6b52056cd535f11
Summary:
Because `this` must be valid while `Engine::main_thread` is running, at least for non-reentrant worker threads
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34529
Test Plan: Run `test_api --gtest-filter=ModulesTest.InstanceNorm1d` in a loop
Differential Revision: D20552717
Pulled By: malfet
fbshipit-source-id: a0197671db1b7b1499dda675e43e0826f368bf0d
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34638
Fixes: https://github.com/pytorch/pytorch/issues/27643
This PR manages notifying workers in the event of a failure during distributed autograd. Gracefully handles propagating errors across all nodes in the backward pass and sets state in the local autograd engines accordingly.
(Note: this ignores all push blocking failures!)
Test Plan: Added 2 new tests checking errors when they are thrown in an intermediate node during distributed autograd. Ensured that all existing distributed autograd tests pass.
Differential Revision: D20164420
fbshipit-source-id: 3d4ed74230969ac70bb763f1b5b1c16d979f66a2
Summary:
Make sure that there could not be more than one instance of either `torch::autograd::Engine` or `torch::autograd::python::PythonEngine`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34567
Test Plan: CI
Differential Revision: D20390622
Pulled By: malfet
fbshipit-source-id: c90595032afc88f552dee52901361b58b282dc1a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33875Fixes#33675.
I added a `current_node_name` argument to AnomalyMetadata::print_stack.
This is a mandatory arg because I found only one callsite and making it
a default arg on a virtual function can be confusing.
Test Plan:
- Tested locally:
https://gist.github.com/zou3519/09937387c83efc76e1700374d5c9c9d9
- I don't know how to add a test for this: the message is printed to
stderr but it isn't an exception nor a warning. I considered capturing
the stderr of a subprocess but that seems like asking for flakiness.
Differential Revision: D20349399
Pulled By: zou3519
fbshipit-source-id: 7585ddffe2bf9e1081f4028a9c44de783978a052
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33214
Distributed autograd had some custom logic in terms of how we
accumulated gradients. This was mostly done early on to enable basic
functionality. Although, in the long term we should merge this logic with what
we have in the local autograd engine. A lot of work has gone into ensuring we
accumulate grads correctly and efficiently and we should reuse that as a
starting point.
We can investigate if we need further custom logic for distributed autograd
later on if we need additional optimizations.
In this PR I've merged the gradient accumulation logic and also the gradient
hooks. As a result, now gradient hooks are called in distributed autograd as
well.
ghstack-source-id: 99838019
Test Plan: waitforbuildbot
Differential Revision: D19843284
fbshipit-source-id: 7923d7e871fb6afd3e98dba7de96606264dcb5f3
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33885Fixes: #32835Fixes: #5834
Can not combine with CUDA's implementation as each of them requires individual `std::once_flag` as well as different `forked_autograd_child` functions. CUDA version relays to python module, autograd uses TORCH_CHECK to report error to python and cpp.
Test Plan: Imported from OSS
Differential Revision: D20144024
Pulled By: VitalyFedyunin
fbshipit-source-id: e7cf30568fff5110e9df7fe5b23f18ed992fa17f
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33156
When dist_autograd_spawn_thrift's 'test_backward_node_failure_python_udf' test is
run, it was encountering a TSAN error related to holding the mutex while the
underlying datastructure was being dealloced.
In this change, we simply get a shared_ptr<> reference to the future, and
set_exception() without having the lock held, to avoid deallocing underneath
the lock.
ghstack-source-id: 98303434
Test Plan: buck test mode/opt-tsan //caffe2/test/distributed/rpc:dist_autograd_spawn_thrift -- 'test_backward_node_failure_python_udf \(test_dist_autograd_spawn\.DistAutogradTestWithSpawn\)'
Differential Revision: D19821362
fbshipit-source-id: 82f735e33f8e608552418ae71592400fa3621e40
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32506
In this PR, we've introduced a `retain_graph` parameter to distributed
autograd similar to `torch.autograd.backward`.
In terms of design, this parameter is sent over RPC to all nodes and is used to
create the GraphTask on the local nodes. This enables us to run
`dist_autograd.backward()` multiple times in the same context.
The use case currently for this is to benchmark only the backward pass for
distributed autograd. We'd like to measure the QPS for the backward pass and as
a result, running a single forward pass and multiple backward passes in a loop
is one way to benchmark backward pass performance.
ghstack-source-id: 97868900
Test Plan: waitforbuildbot
Differential Revision: D19521288
fbshipit-source-id: 7ad8521059fd400d7b5a6ab77ce56e1927ced90a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31995Fixes#31906.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Differential Revision: D19331259
Pulled By: ezyang
fbshipit-source-id: 5d24bf3555e632211a9b6f8e50ff241603c18b3d
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31508
This PR builds on top of https://github.com/pytorch/pytorch/pull/31230
to ensure that distributed autograd doesn't block an RPC thread anymore during
the backward pass.
I've also added a unit test where all ranks hammer rank 0 without about 60
backward calls (which would cause a deadlock earlier), but now such a test
passes without any issues.
ghstack-source-id: 96345097
Test Plan: waitforbuildbot
Differential Revision: D19188749
fbshipit-source-id: b21381b38175699afd0f9dce1ddc8ea6a220f589
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31909https://github.com/pytorch/pytorch/pull/31230 introduced a bug where
we would end up calling `graph_task_post_processing` twice for reentrant
backward calls (once when we mark the future completed and then we we called
graph_task_post_processing in execute_with_graph_task).
This PR fixes the issues by verifying the future we return in that case is
completed and we remove the call to graph_task_post_processing.
In addition to that I added a test that reproduced the problem and verified it
is fixed by this PR.
ghstack-source-id: 96349102
Test Plan: waitforbuildbot
Differential Revision: D19296363
fbshipit-source-id: dc01a4e95989709ad163bb0357b1d191ef5a4fb2
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31230
A major issue with distributed autograd currently is that we block an
RPC thread when we call Engine::execute_with_graph_task.
To resolve this issue, I've made modifications to the local autograd engine
such that `execute_with_graph_task` returns a Future instead. The `execute()`
methods for Engine::execute() and DistEngine::execute() still wait() on this
Future which ensures there is no change in behavior yet.
In follow up PRs we can modify the distributed autograd engine to take
advantage of this Future.
Closes#26359
ghstack-source-id: 96298057
Test Plan: waitforbuildbot
Differential Revision: D18999709
fbshipit-source-id: 388f54467fd2415a0acb7df17bd063aedc105229
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30642
Adding a couple of basic metrics for distributed autograd which would
help in determining stuckness.
ghstack-source-id: 95156189
Test Plan: waitforbuildbot
Differential Revision: D18776478
fbshipit-source-id: a0556ad6fe2b7c3cd0082ee2350c1c78cafaaec5
Summary:
Fixes https://github.com/pytorch/pytorch/issues/29161.
I looked a bit at the code changes related to this and think I have all of the use cases of `DeprecatedTypeProperties` covered in the message, but suggestions from someone with more context on this would be very much appreciated :)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30281
Differential Revision: D18830818
Pulled By: ezyang
fbshipit-source-id: 1a7fcee15354ae09e6644577e7fa33bd26acfe20
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27940
1) If we receive an error for outstanding rpcs, we enqueue an appropriate error
on the local autograd engine.
2) Add an `exit_on_error` mode for the local autograd engine, where the
computation stops if we see an error.
ghstack-source-id: 92603377
Test Plan: Added unit tests to test failures.
Differential Revision: D17916844
fbshipit-source-id: 199a7832f1033c36a9bbcc1e80d86576c04965d0
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27022
This change implements the "FAST" mode distributed autograd backward
pass as described in https://github.com/pytorch/pytorch/issues/23110.
At a high level the backward pass works as follows:
1. We start by computing dependencies on the node that calls
`torch.distributed.backward`.
2. This node computes the dependencies starting from the root nodes provided in
the backward call and all the 'send' functions present in the current autograd
context. The "FAST" mode assumes all 'send' functions are part of the autograd
computation.
3. Once the dependency computation is done, the distributed autograd engine
calls the local autograd engine to execute the autograd graph. Note that the
autograd graph on a single node is not necessarily connected because of
inter-node communication. As a result, we have special handling to ensure the
local autograd engine ensures we execute the entire graph starting from the
provided roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async
RPC to send the gradients over to the appropriate node and stores a future in
the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and
enqueued on the local autograd engine. If this is the first time the node is
hearing about this autograd context id on the backward pass, then the node
computes dependencies for the local autograd engine.
6. As part of compute dependencies, the distributed autograd engine discovers
all leaf nodes and ensures those are passed as 'outputs' to the local autograd
engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then actually accumulated in
`DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine
to complete and also waits for all the 'Futures' (stored in 4.) for respective
RPCs to finish.
We have made the following changes to the local autograd engine for this
purpose:
1. Expose GraphTask and NodeTask so that the distributed autograd engine can
use them.
2. Expose a `execute_with_graph_task` API which gives the distributed engine
to build a GraphTask and pass it to the local autograd engine.
3. Expose a `enqueue_on_cpu` API, which allows the distributed engine to build
a `NodeTask` for a 'send' function and enqueue it on the local autograd engine.
In addition to this a few general improvements:
1. Added a `PropagateGradients` RPC call for the 'recv' function to pass
gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If Future.wait(), contains a message type EXCEPTION, we throw an appropriate
exception instead of just returning the message. This is inline with what most
Future.wait() APIs do.
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map
from Tensor to respective gradient for the provided context_id on the local
node.
ghstack-source-id: 91794926
Test Plan: unit tests.
Differential Revision: D17652615
fbshipit-source-id: 96f65c52adb2706ee29f4b49e1655afaa0a3bec3
Summary:
This PR addresses issue https://github.com/pytorch/pytorch/issues/7601.
Currently models that use streams explicitly in forward have to do a lot of extra work to make backwards respect those streams. This PR extends the (recently added) input tracing (see TypeAndShape) to record the devices and streams of inputs. The autograd engine then uses this metadata to enact the expected stream parallelism without extra work from the user.
For example, a model with forward declared like (original example courtesy of ngimel):
```
def forward(self,x):
x0 = x.clone()
torch._C._cuda_setStream(self.stream1._cdata)
y0 = self.fc1(x0)
self.event1.record(stream = torch.cuda.current_stream())
torch._C._cuda_setStream(self.stream2._cdata)
y1 = self.fc2(x)
self.event2.record(stream = torch.cuda.current_stream())
self.stream2.wait_event(self.event1)
return y0 + y1
```
currently will backward on a single stream. With this change the kernels will go on the streams they are assigned in forward and both forward and backward will (for appropriate sizes) run the fc1 and fc2 kernels simultaneously.
The crux of this change is, as mentioned, an expansion of the TypeAndShape tracing and a relatively simple change to the autograd engine to use cuda events for stream synchronization. To make this efficient I also added a new AutoGPUAndStream class, exposed getting and setting streams on devices, and removed InputBuffer's AutoGPU (it's now redundant). While making these modifications I also fixed AutoGPU to check before setting the GPU when it's destroyed and to use THCudaCheck instead of its custom error handler. These changes mean that an often excessive cudaSetDevice() is not being called when inputs are added to a buffer.
In addition to allowing users to easily set and use streams that are respected in both forward and backward, this change may encourage modules to do the same and the expanded tracing might allow further optimizations in the autograd engine. (apaszke, for example, now after initial enumeration we know the number of devices that will be used by a graph task, which might help provide a sense of the "level of parallelism" we should expect.)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/8354
Test Plan: Two tests were added specifically for this behavior.
Differential Revision: D17275980
Pulled By: mruberry
fbshipit-source-id: 92bd50ac782ffa973b159fcbbadb7a083802e45d
Summary:
Improve handling of mixed-type tensor operations.
This PR affects the arithmetic (add, sub, mul, and div) operators implemented via TensorIterator (so dense but not sparse tensor ops).
For these operators, we will now promote to reasonable types where possible, following the rules defined in https://github.com/pytorch/pytorch/issues/9515, and error in cases where the cast would require floating point -> integral or non-boolean to boolean downcasts.
The details of the promotion rules are described here:
https://github.com/nairbv/pytorch/blob/promote_types_strict/docs/source/tensor_attributes.rst
Some specific backwards incompatible examples:
* now `int_tensor * float` will result in a float tensor, whereas previously the floating point operand was first cast to an int. Previously `torch.tensor(10) * 1.9` => `tensor(10)` because the 1.9 was downcast to `1`. Now the result will be the more intuitive `tensor(19)`
* Now `int_tensor *= float` will error, since the floating point result of this operation can't be cast into the in-place integral type result.
See more examples/detail in the original issue (https://github.com/pytorch/pytorch/issues/9515), in the above linked tensor_attributes.rst doc, or in the test_type_promotion.py tests added in this PR:
https://github.com/nairbv/pytorch/blob/promote_types_strict/test/test_type_promotion.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22273
Reviewed By: gchanan
Differential Revision: D16582230
Pulled By: nairbv
fbshipit-source-id: 4029cca891908cdbf4253e4513c617bba7306cb3
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22397
Test Plan:
Added test for reentrant backwards with checkpoint and a test for a recursive backwards function (which should fail if we run all the reentrant tasks recursively in the same thread) and for testing priority of reentrant tasks.
~~Will add a test for priority of reentrant tasks in future pr.~~
Imported from OSS
Differential Revision: D16131955
fbshipit-source-id: 18301d45c1ec9fbeb566b1016dbaf7a84a09c7ac
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17991
changes:
-Breaks bc: Tensor::type() now returns DeprecatedTypeProperties& rather than Type&.
-Added DeprecatedTypeProperties, it serves as a temporary replacement for Type as the return value of Tensor::type(). This contributes to making Type just for dispatch purposes so that we can make it dtype agnostic.
-Tensor::dispatch_type() now returns Type& like Tensor::type() used to do.
-Changed callsites of Tensor::type() appropriately.
Reviewed By: ezyang
Differential Revision: D14443117
fbshipit-source-id: 239ccb7a09626279a71d1a37f8f82e7f57bf7d9e
Summary:
Allow the comparison function used in ReadyQueue to handle the empty FunctionTasks created by the reentrant autograd.
Fix#11732
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15791
Differential Revision: D13598006
Pulled By: soumith
fbshipit-source-id: 0bfdf28a735fbfe44f0fdbaf8b74a6198e6a1984
Summary:
This PR adds the final set of clang-tidy checks we should add for our codebase: a last set of performance-related checks. Most fixes here are around changing `auto` to `const auto&` in a few places where unnecessary copies were made, and adding `reserve()` calls before loops doing repeated `push_back()`. Also a few cases of calling `std::string::find` with a single-character string literal instead of a single char, which uses a less efficient string search algorithm meant for searching larger substrings.

ezyang apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15198
Differential Revision: D13468797
Pulled By: goldsborough
fbshipit-source-id: 2bed1ea1c7c162b7f3e0e1026f17125e88c4d5b2
Summary:
This PR fixes around 250 places in the codebase where we were making unnecessary copies of objects (some large, some small).
ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15026
Differential Revision: D13458784
Pulled By: goldsborough
fbshipit-source-id: be5148b2ce09493588d70952e6f6d6ff5ec5199b
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14248
This diff also introduces a horrifying hack to override CUDA's DeviceGuardImpl
with a HIPGuardImplMasqueradingAsCUDA, to accommodate PyTorch's current
behavior of pretending CUDA is HIP when you build with ROCm enabled.
Reviewed By: bddppq
Differential Revision: D13145293
fbshipit-source-id: ee0e207b6fd132f0d435512957424a002d588f02
Summary:
```
This diff changes the HIPification of ATen to be out-of-place.
We now have the following mappings:
- ATen/cuda => ATen/hip
- ATen/native/cuda => ATen/native/hip
- ATen/native/sparse/cuda => ATen/native/sparse/hip
- THC => THH
- THCUNN => THHUNN
The build system is adjusted to know about these new build paths,
and HIPify is taught how to adjust include paths and
THC_GENERIC_FILE appropriately. ATen_hip is now built as
the ATen_hip library, rather than reusing ATen_cuda.
However, despite these new filepaths, none of the identifiers in ATen
have actually changed. So, e.g., THHGeneral.h still defines functions
named THC_blahblah, and HIP still shows up as CUDA in PyTorch itself.
We'll tackle this in a subsequent PR; this diff is just to get the files
out-of-place.
Minor extra improvements:
- Don't edit tmp_install when hipifying
- HIP no longer builds native_cudnn_cpp; it was unnecessary
- Caffe2_HIP_INCLUDES is now Caffe2_HIP_INCLUDE, for consistency
with all the other variables.
- HIP build now properly respects ATEN_CUDA_FILES_GEN_LIB (it
did not previously.)
- You can now override file extension matching in pyHIPIFY
by explicitly specifying its full name in the matching list.
This is used so we can HIPify CMakeLists.txt in some situations.
A little bit of string and ceiling wax:
- gen.py grows a --rocm flag so that it knows to generate CUDA
files which actually refer to the HIP headers (e.g., THH.h)
We'll get rid of this eventually and generate real HIP files,
but not for this PR.
- Management of HIP dependencies is now completely deleted
from the ATen CMakeLists.txt. The old code was dead (because
it was shoveled in ATen_CUDA_DEPENDENCY_LIBS and promptly
ignored by the Caffe2 build system) and didn't actually work.
```
Stacked on https://github.com/pytorch/pytorch/pull/14849 review last commit only
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14866
Differential Revision: D13419475
Pulled By: ezyang
fbshipit-source-id: cb4c843df69a1d8369314c9fab1b7719520fa3db
Summary:
Anywhere we used #include "foo.h", we now say #include <foo.h>
Paths are adjusted to be rooted out of aten/src, torch/lib, or
the root level directory.
I modified CMakeLists.txt by hand to remove TH and THC from
the include paths.
I used the following script to do the canonicalization:
```
import subprocess
import re
import os.path
files = subprocess.check_output(['git', 'ls-files']).decode('utf-8').rstrip().split('\n')
for fn in files:
if not any(fn.endswith(suff) for suff in ['.cu', '.cpp', '.in', '.h', '.hpp', '.cu', '.cuh', '.cc']):
continue
if not any(fn.startswith(pref) for pref in ["aten/", "torch/"]):
continue
with open(fn, 'r') as f:
c = f.read()
def fmt(p):
return "#include <{}>".format(p)
def repl(m):
p = m.group(1)
if p in ["dlfcn.h", "unistd.h", "nvrtc.h", "cuda.h", "cuda_runtime.h", "cstdint", "cudnn.h", "Python.h", "cusparse.h", "cuda_runtime_api.h", "cuda_fp16.h", "cublas_v2.h", "stdint.h", "curand_kernel.h"]:
return fmt(p)
if any(p.startswith(pref) for pref in ["torch/csrc", "c10/", "ATen/", "caffe2/", "TH/", "THC/", "Eigen/", "gtest/", "zdl/", "gloo/", "onnx/", "miopen/"]):
return fmt(p)
for root in ["aten/src", "torch/lib", ""]:
for bad_root in [os.path.dirname(fn), "aten/src/TH", "aten/src/THC", "torch/csrc"]:
new_p = os.path.relpath(os.path.join(bad_root, p), root)
if not new_p.startswith("../") and (os.path.exists(os.path.join(root, new_p)) or os.path.exists(os.path.join(root, new_p + ".in"))):
return fmt(new_p)
print("ERROR: ", fn, p)
return m.group(0)
new_c = re.sub(r'#include "([^"]+)"', repl, c)
if new_c != c:
print(fn)
with open(fn, 'w') as f:
f.write(new_c)
```
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14849
Reviewed By: dzhulgakov
Differential Revision: D13363445
Pulled By: ezyang
fbshipit-source-id: 52361f878a672785f9306c9e9ab2513128092b68
Summary:
Previously symbolic AD formulas assumed that no broadcasting happened,
and would return gradients of incorrect shapes (possibly leading to
silent errors later).
Fixes a few bugs (known and unknown):
- #11736
- ArgumentSpec didn't compute the input types correctly [(it didn't advance the offset for non-tensor args)](https://github.com/pytorch/pytorch/pull/14485/files#diff-4fd3157a056596aefb8cdf41022a208bR153)
- Symbolic AD could suffer from use after free (dangling pointers in grad map), because [`EliminateDeadCode` could have removed nodes](https://github.com/pytorch/pytorch/pull/14485/files#diff-25d33ad1ed6855684dec79d927ca6142L781) that referenced gradients of certain values.
- Undefined behavior in `aten::size`
During my tests I've also found a few new problems, and I have opened issues for them:
- FusionGroup seems to think that cat nodes broadcast their inputs (#14483)
- `prim::ConstantChunk` derivative formula doesn't handle undefined inputs (#14484)
This patch unfortunately deoptimizes some of our code (Fusion doesn't happen past chunk nodes, and outputs more tensors only because we have to get their size). I know how to fix those issues, but wanted to fix this terrible bug quickly.
cc zou3519 zdevito ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14485
Reviewed By: eellison
Differential Revision: D13312888
Pulled By: suo
fbshipit-source-id: ad46bfb4d0a306ad9451002f8270f7a790f72d58
Summary:
Previously symbolic AD formulas assumed that no broadcasting happened,
and would return gradients of incorrect shapes (possibly leading to
silent errors later).
Fixes a few bugs (known and unknown):
- #11736
- ArgumentSpec didn't compute the input types correctly [(it didn't advance the offset for non-tensor args)](https://github.com/pytorch/pytorch/pull/14485/files#diff-4fd3157a056596aefb8cdf41022a208bR153)
- Symbolic AD could suffer from use after free (dangling pointers in grad map), because [`EliminateDeadCode` could have removed nodes](https://github.com/pytorch/pytorch/pull/14485/files#diff-25d33ad1ed6855684dec79d927ca6142L781) that referenced gradients of certain values.
- Undefined behavior in `aten::size`
During my tests I've also found a few new problems, and I have opened issues for them:
- FusionGroup seems to think that cat nodes broadcast their inputs (#14483)
- `prim::ConstantChunk` derivative formula doesn't handle undefined inputs (#14484)
This patch unfortunately deoptimizes some of our code (Fusion doesn't happen past chunk nodes, and outputs more tensors only because we have to get their size). I know how to fix those issues, but wanted to fix this terrible bug quickly.
cc zou3519 zdevito ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14485
Differential Revision: D13280899
Pulled By: soumith
fbshipit-source-id: 80cc5ec9331be80e1bb9ddfe85b81c2b997e0b0c
Summary:
Rebased version of https://github.com/pytorch/pytorch/pull/13337.
I don't think the lint errors in the original PR had to do with files I touched, so hopefully the rebase fixes them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14587
Differential Revision: D13277428
Pulled By: soumith
fbshipit-source-id: f04c186b1dd4889b4250597eef87f9e9bf7b2426
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13342
This PR introduces a few new concepts:
- DeviceGuardImplInterface, and implementations for CPU and CUDA, which
provide a generic interface for interfacing with device and stream state,
without requiring a direct dependency on the code in question.
- InlineDeviceGuard, a general template for generating both specialized
and dynamically dispatched device guard implementations. Dynamic
dispatch is done by specializing it on a VirtualGuardImpl.
- Provide a device-independent DeviceGuard class, which can be used even
from CPU code. It uses the aforementioned dynamic dispatch.
- CUDA-specialized CUDAGuard class, which doesn't have a dynamic dispatch
but can only be used from CUDA.
- StreamGuard, which is the same as above, but for streams rather than
devices.
- Optional variants of all the aforementioned guards, which are a no-op if
no device/stream is specified
- CUDAMultiStreamGuard, specifically for the case when we want to set
a device on every guard.
There are some subtle semantic changes, which have been thoroughly documented
in the class definition.
BC-breaking changes:
- Move constructor/assignment have been removed from all device guard
implementations.
- In some cases where you previously wrote 'set_device' (or 'set_stream'), you now must write
'reset_device', because if you switch devices/device types, the stream/device on the
previous device is unset. This is different from previous behavior.
- CUDAGuard no longer handles streams, or multiple streams. Use CUDAStreamGuard
or CUDAMultiStreamGuard as appropriate for your use case.
Reviewed By: dzhulgakov
Differential Revision: D12849620
fbshipit-source-id: f61956256f0b12be754b3234fcc73c2abc1be04e
Summary:
Enables almost all `modernize-*` checks in clang-tidy. This warns against things such as:
- Use of `const std::string&` instead of new-style `std::string` + move,
- Using old-style loops instead of range-for loops,
- Use of raw `new`
- Use of `push_back` instead of `emplace_back`
- Use of `virtual` together with `override` (`override` is sufficient)
ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13196
Differential Revision: D12891837
Pulled By: goldsborough
fbshipit-source-id: 4d0f782a09eb391ee718d3d66f74c095ee121c09
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13232
DeviceGuard should be device agnostic, which means that it shouldn't
assume that int64_t means select the CUDA device.
Reviewed By: gchanan
Differential Revision: D10858024
fbshipit-source-id: b40e8337e4046906fd8f83a95e6206367fb29dbe
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12792
This is a follow up diff after D10238910.
Only non-codemod change is the removal of ATen/Error.h and ATen/core/Error.h. Other files are basically changing the inclusion path + clang format for inclusion order.
Reviewed By: bddppq
Differential Revision: D10437824
fbshipit-source-id: 7f885f80ab5827468d1351cfb2765d0e3f555a69
Summary:
Linting `torch/csrc/` (non-recursive) and `torch/csrc/autograd` (non-recursive).
Fixed things like:
- `typedef` vs `using`
- Use `.empty()` instead of comparing with empty string/using `.size() == 0`
- Use range for loops instead of old style loops (`modernize-`)
- Remove some `virtual` + `override`
- Replace `stdint.h` with `cstdint`
- Replace `return Type(x, y)` with `return {x, y}`
- Use boolean values (`true`/`false`) instead of numbers (1/0)
- More ...
ezyang apaszke cpuhrsch
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11050
Differential Revision: D9597505
Pulled By: goldsborough
fbshipit-source-id: cb0fb4793ade885a8dbf4b10484487b84c64c7f2
Summary:
This PR extends the existing type and shape metadata tracing and verification done in autograd with device information. This expansion of tracing is required for #8354, is likely useful in other scenarios, and is a healthy sanity check, just like type and shape tracing.
The precise changes are:
- TypeAndShape -> InputMetadata, now includes device()
- Creating InputMetadata is simplified to just require a tensor, and callers were updated to use this simpler invocation wherever possible
- The gradient accumulator of a variable is now reset when set_data() is called if either the type or device changes, and this reset now locks to avoid contention with acquiring the gradient accumulator
- Mismatched devices during backward() will throw a runtime error, just like mismatched type and shape
- (Bonus!) Two uninitialized pointers in THCReduce are now initialized (to nullptr) to prevent build warnings
fyi colesbury
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9796
Reviewed By: goldsborough
Differential Revision: D9119325
Pulled By: ezyang
fbshipit-source-id: 76d1861b8d4f74db0575ff1f3bd965e18f9463de
Summary:
More clang tidy cleanups in `torch/csrc`. This time:
1. `hicpp-use-equals-default` recommends `= default` instead of `{}` for constructors/destructors. This is better practice because it expresses the intent better (https://stackoverflow.com/questions/6502828/what-does-default-mean-after-a-class-function-declaration)
2. `readability-inconsistent-declaration-parameter-name` enforces that parameter names in the declaration match parameter names in the definition. This is just generally useful and can prevent confusion and bugs.
Also updated my script a little bit.
apaszke ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9737
Differential Revision: D9069069
Pulled By: goldsborough
fbshipit-source-id: f7b3f3a4eb4c9fadc30425a153566d3b613a41ae
Summary:
```
This adds TensorIterator, a helper class for computing element-wise
operations that's intended to replace the CPU and CUDA apply utils
functions.
CPU kernels are implemented as functions that operate on strided 1-d
tensors compared to CPUApplyUtils which operated individual elements. This
allows the kernels to handle vectorization, while TensorIterator handles
parallelization and non-coalesced dimensions.
GPU kernels continue to operate on elements, but the number of
specializations is reduced. The contiguous case remains the same. The
non-contiguous case uses a single (reduced) shape for all operands and
the fast integer division from THCIntegerDivider. To avoid extra
specializations for indexing with 64-bits, large operations are split
into smaller operations that can be indexed with 32-bits.
Major semantic changes:
- No more s_add, s_mul, s_div, or s_sub. Broadcasting is handled by
TensorIterator. The autograd engine performs the reduction assuming
standard broadcasting if the gradient shape does not match the
expected shape. Functions that do not use standard broadcasting rules
should either continue to trace the expand calls or handle the
reduction in their derivative formula.
- Use ONNX v7, which supports broadcasting ops.
Performance impact:
- Small increased fixed overhead (~0.5 us)
- Larger overhead for wrapped numbers (~2.5 us)
- No significant change for ops on contiguous tensors
- Much faster worst-case performance for non-contiguous GPU tensors
- Faster CPU bias addition (~2x)
- Faster GPU bias addition (~30% faster)
Future work:
- Decrease overhead, especially for wrapping numbers in Tensors
- Handle general inter-type operations
- Extend to unary ops and reductions
- Use buffering for compute-bound operations on non-contiguous tensors
(pull in from CPUApplyUtils)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/8919
Differential Revision: D8677600
Pulled By: colesbury
fbshipit-source-id: 61bc9cc2a36931dfd00eb7153501003fe0584afd
* Created TensorOptions
Storing the type in TensorOptions to solve the Variable problem
Created convenience creation functions for TensorOptions and added tests
Converted zeros to TensorOptions
Converted rand to TensorOptions
Fix codegen for TensorOptions and multiple arguments
Put TensorOptions convenience functions into torch namespace too
All factory functions except *_like support TensorOptions
Integrated with recent JIT changes
Support *_like functions
Fix in place modification
Some cleanups and fixes
Support sparse_coo_tensor
Fix bug in Type.cpp
Fix .empty calls in C++ API
Fix bug in Type.cpp
Trying to fix device placement
Make AutoGPU CPU compatible
Remove some auto_gpu.h uses
Fixing some headers
Fix some remaining CUDA/AutoGPU issues
Fix some AutoGPU uses
Fixes to dispatch_tensor_conversion
Reset version of new variables to zero
Implemented parsing device strings
Random fixes to tests
Self review cleanups
flake8
Undo changes to variable.{h,cpp} because they fail on gcc7.2
Add [cuda] tag to tensor_options_cuda.cpp
Move AutoGPU::set_index_from into .cpp file because Windows is stupid and sucks
Fix linker error in AutoGPU.cpp
Fix bad merge conflict in native_functions.yaml
Fixed caffe2/contrib/aten
Fix new window functions added to TensorFactories.cpp
* Removed torch::TensorOptions
Added code to generate wrapper functions for factory methods
Add implicit constructor from Backend to TensorOptions
Remove Var() from C++ API and use torch:: functions
Use torch:: functions more subtly in C++ API
Make AutoGPU::set_device more exception safe
Check status directly in DynamicCUDAHooksInterface
Rename AutoGPU to DeviceGuard
Removed set_requires_grad from python_variables.h and warn appropriately in Variable::set_requires_grad
remove python_default_init: self.type()
Add back original factory functions, but with deprecation warnings
Disable DeviceGuard for a couple functions in ATen
Remove print statement
Fix DeviceGuard construction from undefined tensor
Fixing CUDA device compiler issues
Moved as many methods as possible into header files
Dont generate python functions for deprecated factories
Remove merge conflict artefact
Fix tensor_options_cuda.cpp
Fix set_requires_grad not being checked
Fix tensor_new.h
TEMPORARILY put some methods in .cpp files to see if it solves issues on windows and mac
Fix bug in DeviceGuard.h
Missing includes
TEMPORARILY moving a few more methods into .cpp to see if it fixes windows
Fixing linker errors
* Fix up SummaryOps to use new factories
Undo device agnostic behavior of DeviceGuard
Use -1 instead of optional for default device index
Also move DeviceGuard methods into header
Fixes around device index after optional -> int32_t switch
Fix use of DeviceGuard in new_with_tensor_copy
Fix tensor_options.cpp
* Fix Type::copy(
* Remove test_non_float_params from ONNX tests
* Set requires_grad=False in ONNX tests that use ints
* Put layout/dtype/device on Tensor
* Post merge fixes
* Change behavior of DeviceGuard to match AutoGPU
* Fix C++ API integration tests
* Fix flip functions
* Factor python dependency out of interpreter
* Remove NO_PYTHON for the autograd engine
If there is no python bindings, then a default Engine is constructed
the first time it is requested.
If the python libraries are loaded, then they override the default
accessor and the default engine becomes a python Engine.
Note: it is possible for two engines to be generated if a non-python
one gets created before the python bindings are loaded. This case
is rare, and just results in additional threads being spawned.
* Fixing AlexNet test which is skipped in CI
* Add backward() to Tensor and Variable
* Add at:: in front of Tensor
* Trying to not move optional to appease windows?
* Move implementation into cpp file
* Undo some formatting changes
* Makes accumulate_grad functions high priority in backwards passes
* Delegating constructor and comments
* Sequence_nr ain't pretty no more
* Sequence_nr ain't pretty no more
* Autograd container for trading compute for memory
* add a unit test for checkpoint
* address comments
* address review comments
* adding some docs for the checkpoint api
* more comments
* more comments
* repro bug
* Fix a subtle bug/apply some review comments
* Update checkpoint.py
* Run everything in grad mode
* fix flake and chunk=1
* use imperative backward as per discussion
* remove Variable and also add models and test for models
* Add a simple thread local variable to check for autograd grad mode
* remove models and models test after debugging
* address review comments
* address more comments
* address more comments
This PR adds the possibility to build the C++ parts of autograd and jit, with no dependency on Python.
The goal is to allow taking a PyTorch IR representation (a tree s-expr) and running it with provided inputs.
Prerequisite: build PyTorch so that codegen runs once.
Instructions:
cd tools/cpp_build
bash build_all.sh
This will build libtorchjit and torchjit_test in tools/cpp_build/build/torchjit-build. The latter basically runs the code in test_jit.cpp for now.
While writing the PR, it turned out that a few of Python.h includes were redundant. They were removed here (PyTorch tests still pass on my machine, we'll see CI).
* Introduce Python-free builds of autograd and jit
* Remove NO_PYTHON ifdef in functions/special
* Improve Function interface
* Undo tracer changes
* Fix bug in VariableType.set_history
* Rename function_counter and sequence_number to sequence_nr
* Clarify Function documentation
* Replace swap_next_edges with next_edges() getter
* Bring back set_gradient_edge
* Simplify special.cpp
* add_gradient_edge -> create_gradient_edge
* Add mutable getters for pre/post hooks
* Use make_variable with Edge
* Remove remove_gradient_edge in favor of detach_
* Fix documentation and remove create_gradient_edge friend method
* Canonicalize some includes
Previously the side-effect free grad calculation was performed
using callbacks that could also override the decision to run a
function. However this had a few problems e.g. it forced us to iterate
over pretty much all functions in the graph and drop their buffers.
This patch improves the mechanism, by adding explicit support for this
kind of evaluation in execute(). It's safer, and the algorithm used to
decide which nodes have to be evaluated was replaced with a faster one.
This removes volatile from Variable. The functionality is mostly
replaced by a global (thread-local) flag, which is controlled by
torch.set_grad_enabled() and the context manager torch.no_grad().
In C++, the flag is exposed through GradMode::is_enabled() and GradMode::set_enabled()
Fixes#3627
* Add a JIT interpreter
The separate interpreter is used to graphs with a lower overhead than
converting them to autograd graphs. Some notes:
* does not support Handles/PythonOp/CppOp, these will be in a future commit
* jit_closure.cpp still exists and we fall back to it for now when
cannot handle something because of PythonOp/CppOp
* In order to support retain_graph=True, the interpreter can be cloned,
creating a copy that can be run with different arguments. This is
assumed to be the non-standard case so cloning is not particularly optimized.
No tensor _data_ is copied, but the at::Tensor list in the interpreter is.
If we hit problems, there is a lot we could do (such as register allocation)
to minimize the stuff that needs to be copied.
* Uses a pImpl pattern to keep implementation details out of its header file.
* Modifies the way getTensorOp works so that it reads/writes to already-existing
vectors, this prevents needing to realloc these buffers each time.
* Timings are here: https://gist.github.com/zdevito/5a20ac29fb1b9e449e693b67dc478127
This reduces overhead to about the same as running it in python.
It is about 10us faster to run the same thing using ATen directly.
* Code Mod
Interpreter -> InterpreterState
Function -> Code
Add other requested comments.
* RegList -> ListHandle<T>
Change the RegList functions to be safer by identifying the type of
each argument list, and checking that list insert does not try
to add to two different lists at once.
* Use exactly equal for interp tests
Otherwise, on many machines, the size of the OpenMP thread pool will
change between MKL and our OpenMP enabled functions. The constant thread
creation and destruction results in worse performance and leaks memory
on GCC 5.4
* A pile of misc doc fixes.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Handle @apaszke review comments.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
* Initial csrc documentation.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>