[JIT][Autocast] document that scripted autocast context cannot disable eager-enabled autocast (#81747)

JIT autocast mode settings that are modified by a scripted autocast
context are separate from the eager-mode autocast settings that affect
the dispatcher.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81747
Approved by: https://github.com/jjsjann123, https://github.com/cpuhrsch
This commit is contained in:
David Berard 2022-07-19 17:46:10 -07:00 committed by PyTorch MergeBot
parent a7c1f74426
commit d3f27f9312


@@ -17,6 +17,7 @@
- [Mixing eager mode and scripting autocast](#mixing-eager-mode-and-scripting-autocast)
- [Mixing tracing and scripting autocast (script calling traced)](#mixing-tracing-and-scripting-autocast-script-calling-traced)
- [Mixing tracing and scripting autocast (traced calling script)](#mixing-tracing-and-scripting-autocast-traced-calling-script)
- [Disabling eager autocast with scripted autocast](#disabling-eager-autocast-with-scripted-autocast)
- [References](#references)
<!-- /code_chunk_output -->
@@ -169,6 +170,25 @@ def traced(a, b):
torch.jit.trace(traced, (x, y))
```
#### Disabling eager autocast with scripted autocast
If eager-mode autocast is enabled and we try to disable autocasting from
within a scripted function, autocasting will still occur: the scripted
autocast context does not affect the dispatcher-level autocast state.
```python
import torch
from torch.cuda.amp import autocast

@torch.jit.script
def fn(a, b):
    with autocast(enabled=False):
        return torch.mm(a, b)

x = torch.rand((2, 2), device='cuda', dtype=torch.float)
y = torch.rand((2, 2), device='cuda', dtype=torch.float)

# this will print a half-precision dtype (torch.float16),
# even though the scripted context tried to disable autocast
with autocast(enabled=True):
    print(fn(x, y).dtype)
```
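Because the scripted context cannot override the eager-mode setting, one way to keep a region in full precision is to disable autocast at the eager level, around the call into the scripted function, where it does change the dispatcher state. This is a minimal sketch of that pattern, not part of the change above; the CUDA-availability guard is only there so the snippet can be imported anywhere:

```python
import torch
from torch.cuda.amp import autocast

@torch.jit.script
def fn(a, b):
    # No autocast context inside the script: rely on the eager-mode
    # (dispatcher-level) setting at the call site instead.
    return torch.mm(a, b)

if torch.cuda.is_available():
    x = torch.rand((2, 2), device='cuda', dtype=torch.float)
    y = torch.rand((2, 2), device='cuda', dtype=torch.float)

    with autocast(enabled=True):
        # Disabling autocast in eager mode changes the dispatcher
        # state, which scripted code does observe, so torch.mm runs
        # in full precision here.
        with autocast(enabled=False):
            print(fn(x, y).dtype)
```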
## References
- [torch.cuda.amp Package][1]