pytorch/tools/code_analyzer
Michael Carilli bbc3cc6718 [CUDA graphs] [BC-breaking] Makes torch.cuda.amp.GradScaler scale updates in-place for better composability with graph capture (#55562)
Summary:
I'd like the following pattern (a natural composition of Amp with full fwd+bwd capture) to work:
```python
# Create "static_input" with dummy data, run warmup iterations,
# call optimizer.zero_grad(set_to_none=True), then
g = torch.cuda._Graph()
s = torch.cuda.Stream()  # side stream; capture must not run on the default stream
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    optimizer.zero_grad(set_to_none=True)
    g.capture_begin()
    with autocast():
        out = model(static_input)
        loss = loss_fn(out)
    scaler.scale(loss).backward()
    g.capture_end()
torch.cuda.current_stream().wait_stream(s)

# Training loop:
for b in data:
    # optimizer.zero_grad() deliberately omitted, replay()'s baked-in backward will refill statically held .grads
    static_input.copy_(b)
    g.replay()
    scaler.step(optimizer)
    scaler.update()
```

Right now `GradScaler` can't work with this pattern because `update()` creates the scale tensor for the next iteration out of place: capture bakes the address of the current scale tensor into the graph, so after `update()` swapped in a new tensor, replays would keep reading the stale one. This PR changes `update()` to act in place on a long-lived scale tensor whose storage stays static across iterations.
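For illustration only, here is a minimal host-side sketch of the grow/backoff logic acting in place. The helper name `update_scale_` is hypothetical; the defaults are `GradScaler`'s documented defaults (2.0 / 0.5 / 2000):

```python
import torch

def update_scale_(scale: torch.Tensor,
                  growth_tracker: torch.Tensor,
                  found_inf: torch.Tensor,
                  growth_factor: float = 2.0,
                  backoff_factor: float = 0.5,
                  growth_interval: int = 2000) -> None:
    """Hypothetical sketch: every mutation is in place, so `scale` keeps
    the same storage (and thus the address a captured graph baked in)."""
    if found_inf.item():
        # Overflow this iteration: back the scale off, reset the clean streak.
        scale.mul_(backoff_factor)
        growth_tracker.zero_()
    else:
        # Clean iteration: after growth_interval consecutive clean steps, grow.
        growth_tracker.add_(1)
        if growth_tracker.item() == growth_interval:
            scale.mul_(growth_factor)
            growth_tracker.zero_()
```

The `.item()` calls above are for readability only; the real op makes this decision on the device so `update()` never forces a host sync.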

I'm not sure how this change affects XLA (see https://github.com/pytorch/pytorch/pull/48570), so we shouldn't merge without approval from ailzhang and yaochengji.

Tagged bc-breaking because it changes the amp update utility function in native_functions.yaml; the function was never meant to be user-facing, though.
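For context, later PyTorch releases expose the in-place variant as `torch._amp_update_scale_`. A hedged demo, reconstructed from that later API rather than from this diff (treat the signature and dtypes as assumptions; the op is CUDA-only):

```python
import torch

# Private API; signature as it appears in later PyTorch releases.
scale = torch.full((1,), 65536.0, device="cuda")
growth_tracker = torch.zeros((1,), dtype=torch.int32, device="cuda")
found_inf = torch.zeros((1,), device="cuda")  # 0.0 => no overflow this step

before = scale.data_ptr()
torch._amp_update_scale_(scale, growth_tracker, found_inf,
                         2.0, 0.5, 2000)
# Storage unchanged: this is what makes the update safe under graph replay.
assert scale.data_ptr() == before
```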

Pull Request resolved: https://github.com/pytorch/pytorch/pull/55562

Reviewed By: zou3519

Differential Revision: D28046159

Pulled By: ngimel

fbshipit-source-id: 02018c221609974546c562f691e20ab6ac611910
2021-04-30 13:03:05 -07:00
analyzer.cpp
build.sh [PyTorch] [BUCK] Replace pt_deps.bzl with a YAML operator dependency file which is generated by the code analyser (#46057) 2020-10-20 02:00:36 -07:00
CMakeLists.txt [pytorch] remove unused flags from code analyzer & move format support to python (#37393) 2020-04-28 17:16:55 -07:00
default_op_deps.yaml [CUDA graphs] [BC-breaking] Makes torch.cuda.amp.GradScaler scale updates in-place for better composability with graph capture (#55562) 2021-04-30 13:03:05 -07:00
gen_op_registration_allowlist.py Make gen_op_registration flake8 compliant (#47604) 2020-11-09 08:31:07 -08:00
op_deps_pass.cpp Reimplement per-operator selective build (#39401) 2020-08-20 19:10:02 -07:00
op_deps_processor.py [pytorch] clean up unused util srcs under tools/autograd (#50611) 2021-01-18 23:54:02 -08:00
run_analyzer.sh Remove .impl_UNBOXED() and functionalities associated with it (#49220) 2021-01-06 14:22:46 -08:00