Commit Graph

6 Commits

Author SHA1 Message Date
Michael Lazos
49df1de383 Cudagraphs support for compiled optimizers (#107504)
Marks all params/optimizer state as static addresses and adds a finalizer that cleans up the graph attributes when the optimizer goes out of scope.

Note: this does not mark grads as static, because doing so would increase memory usage significantly.

There are two cases:
1. The upstream graph is cudagraphed - this case will work fine OOTB
2. The upstream graph is not cudagraphed - in this case, many copies of the grads from the upstream graph into cudagraph-owned memory will be introduced, unless the user explicitly marks the grads as static. Doing so also requires that zero_grad() (either the module or the optimizer version) not deallocate the grads, i.e. set them to zero rather than None; a hedged sketch of this pattern follows below. There is a PR (https://github.com/pytorch/pytorch/pull/107853) in flight to throw an error if zero_grad attempts to set static grads to None.
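
As a rough illustration of case 2, the sketch below marks grads as static and keeps their storage allocated across steps. It assumes `torch._dynamo.mark_static_address` is available in this build and that `mode="reduce-overhead"` enables cudagraphs; treat it as a sketch of the pattern, not the exact mechanism used by this PR.

```python
import torch

model = torch.nn.Linear(8, 8).cuda()
opt = torch.optim.Adam(model.parameters())

# Grads must exist before their addresses can be pinned.
loss = model(torch.randn(4, 8, device="cuda")).sum()
loss.backward()

for p in model.parameters():
    if p.grad is not None:
        # Assumption: mark_static_address pins the grad's storage so the
        # cudagraphed optimizer can read it without extra copies.
        torch._dynamo.mark_static_address(p.grad)

@torch.compile(mode="reduce-overhead")  # cudagraphs for the optimizer step
def compiled_step():
    opt.step()

compiled_step()

# Keep the pinned grad storage alive: zero the grads instead of freeing them.
opt.zero_grad(set_to_none=False)
```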

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107504
Approved by: https://github.com/eellison
2023-08-31 20:47:18 +00:00
Michael Lazos
690ea933ca Enable more e2e foreach optimizer compilation tests (#105438)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105438
Approved by: https://github.com/jansel
2023-07-20 02:41:19 +00:00
Michael Lazos
a290cbf32b Enable fused foreach Adam compilation (#104121)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104121
Approved by: https://github.com/janeyx99
2023-07-05 23:40:03 +00:00
Michael Lazos
5a97c947c6 Fix optimizer grad mode state interaction with dynamo (#103952)
Graph break before restoring the grad mode to ensure dynamo respects `no_grad`. This isn't necessarily a bug, but it allows us to get good perf until AOT is updated.

https://github.com/pytorch/pytorch/issues/104053
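
For context, the optimizer step runs with grad disabled and restores the previous grad mode afterwards; the hedged sketch below mimics the effect of breaking the graph before that restore. The explicit `graph_break()` call is only illustrative of the behavior, not the actual change inside dynamo.

```python
import torch

def step_with_grad_mode_restore(optimizer_update):
    prev = torch.is_grad_enabled()
    try:
        torch.set_grad_enabled(False)
        optimizer_update()           # region dynamo should trace under no_grad
        torch._dynamo.graph_break()  # illustrative: break before the restore below
    finally:
        torch.set_grad_enabled(prev)
```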

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103952
Approved by: https://github.com/janeyx99
2023-06-23 02:07:08 +00:00
Michael Lazos
05e91a50d9 Manually generate guards for optimizer (#103121)
Manually generate guards for the optimizer rather than using the variable builder, which can be slow with lots of params.

This was the cause of the ~10s compile slowdown.

Re-disable `_init_group`. This is important because if, for any reason, a frame that calls `_init_group` is run in the Python interpreter, we will trace it, which we don't want. We only want to call it when it is accessed via the fast path implemented with the optimizer variable during symbolic interpretation.
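
The sketch below illustrates, on a toy optimizer rather than the real torch.optim code, how disabling a helper keeps dynamo from tracing it if the frame is ever hit in the plain Python interpreter; the `_init_group` body here is a hypothetical stand-in.

```python
import torch

class ToyOptimizer:
    def __init__(self, params):
        self.params = list(params)

    @torch._dynamo.disable
    def _init_group(self, params, grads, states):
        # Slow, Python-level bookkeeping we never want dynamo to trace;
        # the compiled optimizer reaches the equivalent logic only through
        # its fast path during symbolic interpretation.
        for p in params:
            grads.append(p.grad)
            states.append({"step": torch.zeros(())})
```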

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103121
Approved by: https://github.com/jansel
2023-06-08 21:45:19 +00:00
Michael Lazos
c46af25bb3 Initialize optimizer in dynamo to avoid graph break and tracing slowness (#102640)
On calls to `_init_group`, rather than tracing through it, extract the Python values from the arguments and call the initialization directly. This avoids having to trace a function that is very slow with large numbers of parameters, and also avoids graph breaking on it. This is sound here because, in the eager case, the state is only initialized once. Guards on the state and params are generated explicitly rather than via tracing the initialization.

Caveats:
`_init_group` also gathers various state tensors into lists (by mutating its list arguments) to pass to the functional optimizer implementation. These state tensors live on the optimizer itself, but we don't know exactly how the gathering is done or which tensors correspond to which attributes of the optimizer module (each optimizer has different state). To handle this, we keep weak references to all of the tensors collected into the lists in globals (similar to how parameter keys are stored for dictionaries). These tensors are guaranteed to stay alive as long as the optimizer object is alive, provided its internal state is not interfered with, and they are guarded with weakref guards.
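
A minimal sketch of that bookkeeping, using hypothetical names (`_optimizer_state_refs`, `record_gathered_state`, `state_guard`) rather than dynamo's actual guard machinery:

```python
import weakref

import torch

_optimizer_state_refs = []  # hypothetical globals-level store for guarding

def record_gathered_state(tensors):
    # Hold weak references so guards can check the same tensors without
    # keeping them alive past the optimizer's lifetime.
    _optimizer_state_refs.extend(weakref.ref(t) for t in tensors)

def state_guard():
    # Fails (forcing recompilation) if any gathered state tensor was freed.
    return all(ref() is not None for ref in _optimizer_state_refs)

# Usage: record the exp_avg tensors an Adam-like optimizer gathered.
exp_avgs = [torch.zeros(3), torch.zeros(3)]
record_gathered_state(exp_avgs)
assert state_guard()
```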

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102640
Approved by: https://github.com/jansel
2023-06-03 15:49:51 +00:00