mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[JIT][Autocast] document that scripted autocast context cannot disable eager-enabled autocast (#81747)
JIT autocast mode settings that are modified by a scripted autocast context are separate from the eager-mode autocast settings that affect the dispatcher. Pull Request resolved: https://github.com/pytorch/pytorch/pull/81747 Approved by: https://github.com/jjsjann123, https://github.com/cpuhrsch
This commit is contained in:
parent
a7c1f74426
commit
d3f27f9312
|
|
@ -17,6 +17,7 @@
|
||||||
- [Mixing eager mode and scripting autocast](#mixing-eager-mode-and-scripting-autocast)
|
- [Mixing eager mode and scripting autocast](#mixing-eager-mode-and-scripting-autocast)
|
||||||
- [Mixing tracing and scripting autocast (script calling traced)](#mixing-tracing-and-scripting-autocast-script-calling-traced)
|
- [Mixing tracing and scripting autocast (script calling traced)](#mixing-tracing-and-scripting-autocast-script-calling-traced)
|
||||||
- [Mixing tracing and scripting autocast (traced calling script)](#mixing-tracing-and-scripting-autocast-traced-calling-script)
|
- [Mixing tracing and scripting autocast (traced calling script)](#mixing-tracing-and-scripting-autocast-traced-calling-script)
|
||||||
|
- [Disabling eager autocast with scripted autocast](#disabling-eager-autocast-with-scripted-autocast)
|
||||||
- [References](#references)
|
- [References](#references)
|
||||||
|
|
||||||
<!-- /code_chunk_output -->
|
<!-- /code_chunk_output -->
|
||||||
|
|
@ -169,6 +170,25 @@ def traced(a, b):
|
||||||
torch.jit.trace(traced, (x, y))
|
torch.jit.trace(traced, (x, y))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Disabling eager autocast with scripted autocast
|
||||||
|
|
||||||
|
If eager-mode autocast is enabled and we try to disable autocasting from
|
||||||
|
within a scripted function, autocasting will still occur.
|
||||||
|
|
||||||
|
```python
|
||||||
|
@torch.jit.script
|
||||||
|
def fn(a, b):
|
||||||
|
with autocast(enabled=False):
|
||||||
|
return torch.mm(a, b)
|
||||||
|
|
||||||
|
x = torch.rand((2, 2), device='cuda', dtype=torch.float)
|
||||||
|
y = torch.rand((2, 2), device='cuda', dtype=torch.float)
|
||||||
|
|
||||||
|
# this will print half-precision dtype
|
||||||
|
with autocast(enabled=True):
|
||||||
|
print(fn(x, y).dtype)
|
||||||
|
```
|
||||||
|
|
||||||
## References
|
## References
|
||||||
|
|
||||||
- [torch.cuda.amp Package][1]
|
- [torch.cuda.amp Package][1]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user