diff --git a/torch/csrc/jit/JIT-AUTOCAST.md b/torch/csrc/jit/JIT-AUTOCAST.md index 00e66b77b14..c8e7ffac5d4 100644 --- a/torch/csrc/jit/JIT-AUTOCAST.md +++ b/torch/csrc/jit/JIT-AUTOCAST.md @@ -17,6 +17,7 @@ - [Mixing eager mode and scripting autocast](#mixing-eager-mode-and-scripting-autocast) - [Mixing tracing and scripting autocast (script calling traced)](#mixing-tracing-and-scripting-autocast-script-calling-traced) - [Mixing tracing and scripting autocast (traced calling script)](#mixing-tracing-and-scripting-autocast-traced-calling-script) + - [Disabling eager autocast with scripted autocast](#disabling-eager-autocast-with-scripted-autocast) - [References](#references) @@ -169,6 +170,25 @@ def traced(a, b): torch.jit.trace(traced, (x, y)) ``` +#### Disabling eager autocast with scripted autocast + +If eager-mode autocast is enabled and we try to disable autocasting from +within a scripted function, autocasting will still occur. + +```python +@torch.jit.script +def fn(a, b): + with autocast(enabled=False): + return torch.mm(a, b) + +x = torch.rand((2, 2), device='cuda', dtype=torch.float) +y = torch.rand((2, 2), device='cuda', dtype=torch.float) + +# this will print half-precision dtype +with autocast(enabled=True): + print(fn(x, y).dtype) +``` + ## References - [torch.cuda.amp Package][1]