pytorch/test/test_jit_autocast.py

# Owner(s): ["oncall: jit"]
import torch
from torch.cuda.amp import autocast
from typing import Optional, Tuple
import unittest
from test_jit import JitTestCase
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests
from torch.testing import FileCheck
from jit.test_models import MnistNet
TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()
class TestAutocast(JitTestCase):
def setUp(self):
# common input tensors
if TEST_CUDA:
self.a_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
self.b_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
self.c_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
self.d_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
self.a_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
self.b_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
self.c_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
self.d_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
self.old_value = torch._C._jit_set_autocast_mode(True)
super().setUp()
def tearDown(self):
torch._C._jit_set_autocast_mode(self.old_value)
super().tearDown()
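    # mm should run in fp16 under autocast, while the reduction (sum) stays fp32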
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_minimal(self):
@torch.jit.script
def fn(a, b):
with autocast():
x = torch.mm(a, b)
y = torch.sum(x)
return x, y
x, y = fn(self.a_fp32, self.b_fp32)
self.assertEqual(x.dtype, torch.float16)
self.assertEqual(y.dtype, torch.float32)
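    # same as test_minimal, but with the autocast dtype set to bfloat16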
@unittest.skipIf(not TEST_CUDA or not TEST_BFLOAT16, "No cuda bfloat16 support")
def test_linear_bf16(self):
@torch.jit.script
def fn(a, b):
with autocast(dtype=torch.bfloat16):
x = torch.mm(a, b)
y = torch.sum(x)
return x, y
x, y = fn(self.a_fp32, self.b_fp32)
self.assertEqual(x.dtype, torch.bfloat16)
self.assertEqual(y.dtype, torch.float32)
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_minimal_cpu(self):
@torch.jit.script
def fn(a, b):
with autocast():
return torch.mm(a, b)
result = fn(self.a_fp32.to('cpu'), self.b_fp32.to('cpu'))
self.assertEqual(result.dtype, torch.float32)
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_minimal_off(self):
@torch.jit.script
def fn(a, b):
with autocast(enabled=False):
return torch.mm(a, b)
result = fn(self.a_fp32, self.b_fp32)
self.assertEqual(result.dtype, torch.float32)
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_runtime_autocast_state(self):
@torch.jit.script
def fn(a, b, use_amp: bool):
with autocast(enabled=use_amp):
return torch.mm(a, b)
# runtime values for autocast enable argument are not supported
with self.assertRaises(RuntimeError):
fn(self.a_fp32, self.b_fp32, True)
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_runtime_autocast_state_expr(self):
@torch.jit.script
def fn(a, b):
with autocast(enabled=True if a[0][0] > 0.5 else False):
return torch.mm(a, b)
# runtime values for autocast enable argument are not supported
with self.assertRaises(RuntimeError):
fn(self.a_fp32, self.b_fp32)
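    # explicit dtype casts inside the autocast region take precedence over autocasting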
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_explicit_casts(self):
@torch.jit.script
def fn(a, b, c, d):
with autocast():
e = torch.mm(a.double(), b.double()).float()
f = torch.mm(c, d).double()
g = torch.mm(c.double(), f)
return e, f, g
e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
self.assertEqual(e.dtype, torch.float32)
self.assertEqual(f.dtype, torch.float64)
self.assertEqual(g.dtype, torch.float64)
# multiple uses of the same input value
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_duplicate_inputs(self):
@torch.jit.script
def fn(a, b):
with autocast():
e = torch.mm(a, a)
f = torch.mm(e, e)
return e, f
e, f = fn(self.a_fp32, self.b_fp32)
self.assertEqual(e.dtype, torch.float16)
self.assertEqual(f.dtype, torch.float16)
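    # ops with an fp32 cast policy (e.g. torch.log) produce fp32 even for fp16 inputs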
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_fp32_policy(self):
@torch.jit.script
def fn(a):
with autocast(enabled=True):
return torch.log(a)
result = fn(self.a_fp16)
self.assertEqual(result.dtype, torch.float32)
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_fp32_policy_with_fp64(self):
@torch.jit.script
def fn(a):
with autocast(enabled=True):
return torch.log(a)
# fp32 policy should not narrow fp64 to fp32!
result = fn(self.a_fp32.double())
self.assertEqual(result.dtype, torch.float64)
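    # ops with a promote policy (e.g. torch.addcmul) compute in the widest input dtype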
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_promote_policy(self):
@torch.jit.script
def fn(a, b, c, d):
with autocast():
e = torch.mm(a, b)
f = torch.addcmul(e, c, d, value=0.1)
return e, f
e, f = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
self.assertEqual(e.dtype, torch.float16)
self.assertEqual(f.dtype, torch.float32)
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_promote_policy_fp64(self):
@torch.jit.script
def fn(a, b):
with autocast(enabled=True):
return torch.addcmul(a, a, b, value=0.1)
result = fn(self.a_fp32.double(), self.b_fp32.double())
self.assertEqual(result.dtype, torch.float64)
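    # fp32_set_opt_dtype policy: softmax runs in fp32 when no dtype (or a literal None)
    # is given; an explicit or runtime-supplied dtype argument suppresses the cast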
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_fp32_set_opt_dtype_policy(self):
@torch.jit.script
def fn(a, b, c, d, dtype: Optional[int]):
with autocast(enabled=True):
x = torch.softmax(a, 0)
y = torch.softmax(b, 0, None)
z = torch.softmax(c, 0, torch.float64)
w = torch.softmax(d, 0, dtype)
return x, y, z, w
x, y, z, w = fn(self.a_fp16, self.b_fp16, self.c_fp16, self.d_fp16, None)
self.assertEqual(x.dtype, torch.float32)
self.assertEqual(y.dtype, torch.float32)
self.assertEqual(z.dtype, torch.float64)
self.assertEqual(w.dtype, torch.float16)
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_fp32_set_opt_dtype_policy_fp64(self):
@torch.jit.script
def fn(a, b, c, d, dtype: Optional[int]):
with autocast(enabled=True):
x = torch.softmax(a, 0)
y = torch.softmax(b, 0, None)
z = torch.softmax(c, 0, torch.float64)
w = torch.softmax(d, 0, dtype)
return x, y, z, w
x, y, z, w = fn(self.a_fp32.double(), self.b_fp32.double(), self.c_fp32.double(), self.d_fp32.double(), None)
self.assertEqual(x.dtype, torch.float64)
self.assertEqual(y.dtype, torch.float64)
self.assertEqual(z.dtype, torch.float64)
self.assertEqual(w.dtype, torch.float64)
@unittest.skipIf(True, "broken due to lack of type propagation")
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_control_flow(self):
@torch.jit.script
def fn(a, b, c, d):
with autocast():
if a[0][0] > 0.5:
e = torch.mm(a, b)
x = 1
else:
e = torch.mm(c, d)
x = 2
f = torch.mm(d, e) * x
return e, f
e, f = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
self.assertEqual(e.dtype, torch.float16)
self.assertEqual(f.dtype, torch.float16)
    # this works fine in regular Python, but it creates a delicate
# situation in TorchScript where the types are not consistent across
# the then/else branches
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_divergent_types(self):
@torch.jit.script
def fn(a, b, c, d):
with autocast():
if a[0][0] > 0.5:
e = torch.mm(a, b)
f = torch.mm(a, b).float()
else:
e = torch.mm(c, d).float()
f = torch.mm(a, b)
return torch.mm(e.float(), f.float())
result = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
self.assertEqual(result.dtype, torch.float32)
# another, more complex case of divergent types
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_divergent_autocast(self):
@torch.jit.script
def fn(a, b, c, d):
autocast_on = autocast(enabled=True)
autocast_off = autocast(enabled=False)
if a[0][0] > 0.5:
with autocast_on:
e = torch.mm(a, b)
else:
with autocast_off:
e = torch.mm(c, d)
return torch.mm(e, e)
fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_conditional_autocast(self):
@torch.jit.script
def fn(a, b):
autocast_on = autocast(enabled=True)
autocast_off = autocast(enabled=False)
with autocast_on if a[0][0] > 0.5 else autocast_off:
return torch.mm(a, b)
# conditional autocast expressions are not supported
with self.assertRaises(RuntimeError):
fn(self.a_fp32, self.b_fp32)
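    # inner autocast regions override the enclosing one and restore it on exit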
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_nested_autocast(self):
@torch.jit.script
def fn(a, b, c, d):
with autocast(enabled=False):
e = torch.mm(a, b)
with autocast(enabled=True):
f = torch.mm(e, c)
with autocast(enabled=False):
g = torch.mm(e, d)
return e, f, g
e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
self.assertEqual(e.dtype, torch.float32)
self.assertEqual(f.dtype, torch.float16)
self.assertEqual(g.dtype, torch.float32)
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_implicitly_nested_autocast(self):
@torch.jit.script
def fn(a, b):
with autocast(enabled=False), autocast(enabled=True):
return torch.mm(a, b)
result = fn(self.a_fp32, self.b_fp32)
self.assertEqual(result.dtype, torch.float16)
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_reused_autocast(self):
@torch.jit.script
def fn(a, b, c, d):
autocast_instance = autocast(enabled=True)
with autocast_instance:
e = torch.mm(a, b)
with autocast_instance:
e = torch.mm(c, d)
f = torch.mm(d, e)
g = torch.mm(e, f)
return e, f, g
e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
self.assertEqual(e.dtype, torch.float16)
self.assertEqual(f.dtype, torch.float16)
self.assertEqual(g.dtype, torch.float16)
# TODO: fix and enable this test?
# (we could technically fix this, but is it really worth it?)
@unittest.skipIf(True, "unsuported autocast syntax")
def test_reused_autocast_expr(self):
@torch.jit.script
def fn(a, b, c, d):
with autocast(enabled=True) as autocast_instance:
e = torch.mm(a, b)
with autocast_instance:
e = torch.mm(c, d)
f = torch.mm(d, e)
g = torch.mm(e, f)
return e, f, g
e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
self.assertEqual(e.dtype, torch.float16)
self.assertEqual(f.dtype, torch.float16)
self.assertEqual(g.dtype, torch.float16)
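    # the autocast state propagates into non-scripted helpers called from scripted code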
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_callees(self):
def helper(a, b):
return torch.mm(a, b)
@torch.jit.script
def fn(a, b):
with autocast(enabled=True):
tmp = helper(a, b)
tmp = helper(tmp, tmp)
tmp = helper(tmp, tmp)
tmp = helper(tmp, tmp)
return helper(tmp, b)
result = fn(self.a_fp32, self.b_fp32)
self.assertEqual(result.dtype, torch.float16)
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_callees_with_autocast_on(self):
def helper(a, b):
with autocast(enabled=True):
return torch.mm(a, b)
@torch.jit.script
def fn(a, b):
with autocast(enabled=False):
return helper(a, b)
result = fn(self.a_fp32, self.b_fp32)
self.assertEqual(result.dtype, torch.float16)
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_callees_with_autocast_off(self):
def helper(a, b):
with autocast(enabled=False):
return torch.mm(a, b)
@torch.jit.script
def fn(a, b):
with autocast(enabled=True):
return helper(a, b)
result = fn(self.a_fp32, self.b_fp32)
self.assertEqual(result.dtype, torch.float32)
# scripting inside eager autocast
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_eager_and_script(self):
@torch.jit.script
def fn(a, b):
return torch.mm(a, b)
for i in range(8):
use_autocast = (i % 2 == 0)
expected_dtype = torch.float16 if use_autocast else torch.float32
with autocast(enabled=use_autocast):
result = fn(self.a_fp32, self.b_fp32)
self.assertEqual(result.dtype, expected_dtype)
# traced inside scripting
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_script_and_tracing(self):
def helper(a, b):
return torch.mm(a, b)
traced = torch.jit.trace(helper, (self.a_fp32, self.a_fp32))
@torch.jit.script
def fn(a, b):
with autocast(enabled=True):
return traced(a, b)
result = fn(self.a_fp32, self.b_fp32)
self.assertEqual(result.dtype, torch.float16)
# traced with autocast inside scripting
@unittest.skipIf(True, "autocast(False) is ignored inside traced functions")
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_script_and_tracing_with_autocast(self):
def helper(a, b):
with autocast(enabled=False):
return torch.mm(a, b) * 2.0
traced = torch.jit.trace(helper, (self.a_fp32, self.a_fp32))
@torch.jit.script
def fn(a, b):
with autocast(enabled=True):
return traced(a, b)
result = fn(self.a_fp32, self.b_fp32)
self.assertEqual(result.dtype, torch.float32)
# scripted called from traced
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_tracing_and_script(self):
@torch.jit.script
def fn(a, b):
with autocast():
return torch.mm(a, b)
def traced(a, b):
return fn(a, b)
traced = torch.jit.trace(traced, (self.a_fp32, self.b_fp32))
result = traced(self.a_fp32, self.b_fp32)
self.assertEqual(result.dtype, torch.float16)
# scripted called from traced with autocast
@unittest.skipIf(True, "scripted called from traced TorchScript is not yet working")
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_tracing_with_autocast_and_script(self):
@torch.jit.script
def fn(a, b):
return torch.mm(a, b)
def traced(a, b):
with autocast(enabled=True):
return fn(a, b)
traced = torch.jit.trace(traced, (self.a_fp32, self.b_fp32))
result = traced(self.a_fp32, self.b_fp32)
self.assertEqual(result.dtype, torch.float16)
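    # autocast used inside the forward() of a scripted nn.Module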
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_script_module(self):
class TestModule(torch.nn.Module):
def __init__(self, N, M):
super().__init__()
self.weight = torch.nn.Parameter(torch.rand((N, M), dtype=torch.float32))
self.linear = torch.nn.Linear(N, M).float()
def forward(self, input):
with autocast(enabled=True):
output = self.weight.mv(input)
output = self.linear(output)
return output
scripted_module = torch.jit.script(TestModule(2, 3)).cuda()
input = torch.rand(3, dtype=torch.float32, device='cuda')
result = scripted_module(input)
self.assertEqual(result.dtype, torch.float16)
@unittest.skipIf(True, "autocast decorators not supported")
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_autocast_decorator(self):
@torch.jit.script
@autocast(enabled=True)
def fn(a, b):
return torch.mm(a, b)
result = fn(self.a_fp32, self.b_fp32)
self.assertEqual(result.dtype, torch.float16)
    # this is equivalent to running scripted functions inside autocast
# (see also test_eager_and_script)
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_autocast_decorator_outside_jit(self):
@autocast(enabled=True)
@torch.jit.script
def fn(a, b):
return torch.mm(a, b)
result = fn(self.a_fp32, self.b_fp32)
self.assertEqual(result.dtype, torch.float16)
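    # out= and in-place variants keep the destination dtype; only the functional
    # addmm is autocast to fp16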
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_inplace(self):
@torch.jit.script
def fn(a, b, c):
with autocast(enabled=True):
x = torch.addmm(a, b, c)
y = torch.addmm(a, b, c, out=a)
z = a.addmm_(b, c)
return x, y, z
x, y, z = fn(self.a_fp32, self.b_fp32, self.c_fp32)
self.assertEqual(x.dtype, torch.float16)
self.assertEqual(y.dtype, torch.float32)
self.assertEqual(z.dtype, torch.float32)
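    # helper: scripts `func`, optionally checks that `cast_op` shows up in the
    # optimized graph, and compares output dtypes between eager and scripted runs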
def _test_autocast(self, func, cast_op, *args):
jit_func = torch.jit.script(func)
o = func(*args)
jit_o = jit_func(*args)
if cast_op is not None:
FileCheck().check(cast_op).run(jit_func.graph_for(*args))
for o0, o1 in zip(o, jit_o):
self.assertEqual(o0.dtype, o1.dtype)
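    # torch.autocast, torch.cuda.amp.autocast and torch.cpu.amp.autocast should all
    # insert aten::_autocast_to_reduced_precision into the scripted graph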
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_autocast_api(self):
def t_autocast_cpu(x, y):
with torch.autocast("cpu", dtype=torch.bfloat16):
return torch.mm(x, y)
def t_autocast_cuda(x, y):
with torch.autocast("cuda", dtype=torch.half):
return torch.mm(x, y)
def t_cuda_amp_autocast(x, y):
with torch.cuda.amp.autocast():
return torch.mm(x, y)
def t_cpu_amp_autocast(x, y):
with torch.cpu.amp.autocast():
return torch.mm(x, y)
x = torch.randn(5, 5, device="cuda", dtype=torch.float32)
y = torch.randn(5, 5, device="cuda", dtype=torch.float32)
self._test_autocast(t_autocast_cpu, "aten::_autocast_to_reduced_precision", x, y)
self._test_autocast(t_autocast_cuda, "aten::_autocast_to_reduced_precision", x, y)
self._test_autocast(t_cuda_amp_autocast, "aten::_autocast_to_reduced_precision", x, y)
self._test_autocast(t_cpu_amp_autocast, "aten::_autocast_to_reduced_precision", x, y)
@unittest.skipIf(True, "we need to provide dtype argument at this moment")
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_autocast_api_not_supported(self):
def t_autocast_cpu(x, y):
# no dtype provided is not currently supported
with torch.autocast("cpu"):
return torch.mm(x, y)
def t_autocast_cuda(x, y):
# no dtype provided is not currently supported
with torch.autocast("cuda"):
return torch.mm(x, y)
x = torch.randn(5, 5, device="cuda", dtype=torch.float32)
y = torch.randn(5, 5, device="cuda", dtype=torch.float32)
self._test_autocast(t_autocast_cpu, "aten::_autocast_to_reduced_precision", x, y)
self._test_autocast(t_autocast_cuda, "aten::_autocast_to_reduced_precision", x, y)
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_autocast_mixed_dtypes(self):
def t(cpu0, cpu1, cuda0, cuda1):
with torch.autocast("cpu", torch.bfloat16):
with torch.autocast("cuda", torch.float16):
cpu_o = torch.mm(cpu0, cpu1)
cuda_o = torch.mm(cuda0, cuda1)
return cpu_o, cuda_o
jit_t = torch.jit.script(t)
cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)
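    # a scripted function without an autocast block should still get reduced-precision
    # casts when executed under an eager autocast context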
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_jit_executor_under_autocast(self):
def t(cpu0, cpu1, cuda0, cuda1):
cpu_o = torch.mm(cpu0, cpu1)
cuda_o = torch.mm(cuda0, cuda1)
return cpu_o, cuda_o
jit_t = torch.jit.script(t)
cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
with torch.autocast("cpu", torch.bfloat16):
with torch.autocast("cuda", torch.float16):
self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)
with torch.autocast("cpu", torch.bfloat16):
self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)
with torch.autocast("cuda", torch.float16):
self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)
# no cast op should be observed when executing outside autocast context
self._test_autocast(t, None, cpu0, cpu1, cuda0, cuda1)
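    # forward/backward results and dtypes of the scripted function should match eager
    # mode when both run under autocast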
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_autocast_autodiff(self):
def t(t0, t1):
o = torch.mm(t0, t1)
return o.relu()
jit_t = torch.jit.script(t)
t0 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()
t1 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()
# run optimization
for i in range(5):
with torch.autocast("cuda", torch.float16):
jit_o = jit_t(t0, t1)
jit_o.sum().backward()
t0.grad = None
t1.grad = None
ref_t0 = t0.detach().requires_grad_()
ref_t1 = t1.detach().requires_grad_()
with torch.autocast("cuda", torch.float16):
o = t(ref_t0, ref_t1)
jit_o = jit_t(t0, t1)
jit_o.sum().backward()
o.sum().backward()
self.assertEqual(o, jit_o)
self.assertEqual(t0.grad, ref_t0.grad)
self.assertEqual(t1.grad, ref_t1.grad)
self.assertEqual(o.dtype, jit_o.dtype)
self.assertEqual(t0.grad.dtype, ref_t0.grad.dtype)
self.assertEqual(t1.grad.dtype, ref_t1.grad.dtype)
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_jit_call_method_under_autocast(self):
@torch.jit.interface
class Iface(torch.nn.Module):
def forward(self, x, y) -> torch.Tensor:
pass
class Impl(Iface):
def forward(self, x, y):
return torch.mm(x, y)
class Thing1(torch.nn.Module):
impl: Iface
def forward(self, x, y):
with torch.cuda.amp.autocast():
a = torch.mm(x, y)
b = self.impl.forward(a, x)
return b
scripted_impl = torch.jit.script(Impl())
thing1 = Thing1()
thing1.impl = scripted_impl
scripted_thing1 = torch.jit.script(thing1)
x = torch.rand([2, 2])
y = torch.rand([2, 2])
# make sure this doesn't throw an error
with torch.cuda.amp.autocast():
ans = scripted_thing1.forward(x, y)
self.assertEqual(torch.mm(torch.mm(x, y), x), ans)
# sanity check: this isn't supported currently when global autocasting
# isn't enabled
self.assertRaises(RuntimeError, lambda: scripted_thing1.forward(x, y))
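    # freezing keeps the two input casts, and re-running the frozen module must not
    # duplicate the autocast nodes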
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_jit_freeze_autocast_basic(self):
class TestModule(torch.nn.Module):
def __init__(self):
super(TestModule, self).__init__()
def forward(self, x, y):
with torch.cuda.amp.autocast():
return torch.mm(x, y)
x = torch.rand((3, 4), dtype=torch.float).cuda()
y = torch.rand((4, 5), dtype=torch.float).cuda()
mod = TestModule().eval()
# sanity check
self._test_autocast(mod, "aten::_autocast_to_reduced_precision", x, y)
frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(frozen_mod.graph)
# make sure that the runtime pass doesn't duplicate autocast nodes
frozen_mod(x, y)
optimized_graph = frozen_mod.graph_for(x, y)
FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(optimized_graph)
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_jit_freeze_autocast_constants(self):
class TestModule(torch.nn.Module):
def __init__(self):
super(TestModule, self).__init__()
self.x = torch.rand((3, 4), dtype=torch.float).cuda()
def forward(self, y):
with torch.cuda.amp.autocast():
return torch.mm(self.x, y)
y = torch.rand((4, 5), dtype=torch.float).cuda()
mod = TestModule().eval()
frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
# freezing should pre-cast the constant self.x to remove one autocast call
FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(frozen_mod.graph)
# the runtime autocasting pass will re-insert the second autocast call,
# but constant propagation will merge it with the constant that it's casting.
frozen_mod(y)
optimized_graph = frozen_mod.graph_for(y)
FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(optimized_graph)
@unittest.skipIf(TEST_CUDA, "CPU-only test")
def test_jit_autocast_softmax_cpu(self):
def fn(x):
with torch.cpu.amp.autocast():
return torch.nn.functional.softmax(x, dim=0)
fn_s = torch.jit.script(fn)
x = torch.rand((2, 2), dtype=torch.bfloat16)
fn_s(x)
y = fn_s(x)
self.assertTrue(y.dtype == torch.bfloat16)
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_jit_autocast_softmax_gpu(self):
def fn(x):
with torch.cuda.amp.autocast():
return torch.nn.functional.softmax(x, dim=0)
fn_s = torch.jit.script(fn)
x = torch.rand((2, 2), dtype=torch.half).cuda()
fn_s(x)
y = fn_s(x)
self.assertTrue(y.dtype == torch.float)
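    # _set_ignore_amp(True) should keep autocast casts out of the optimized graph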
def test_ignore_amp(self):
@torch.jit.script
def foo(x):
return torch.mm(x, x)
inp = torch.rand([10, 10], dtype=torch.float)
foo._set_ignore_amp(True)
with torch.cpu.amp.autocast():
foo(inp)
foo(inp)
g = torch.jit.last_executed_optimized_graph()
FileCheck().check_not("_autocast_to_reduced").run(g)
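# small Conv2d + BatchNorm2d module used by the tracing tests below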
class convbn(torch.nn.Module):
def __init__(self, bias_enabled=True):
super(convbn, self).__init__()
self.conv = torch.nn.Conv2d(3, 64, 7, stride=2, bias=bias_enabled)
self.bn = torch.nn.BatchNorm2d(64)
def forward(self, x):
return self.bn(self.conv(x))
class TestJitTraceAutocast(JitTestCase):
def setUp(self):
super(TestJitTraceAutocast, self).setUp()
self.previous_default_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float32)
self.models = [MnistNet(),
convbn(bias_enabled=True),
convbn(bias_enabled=False)]
self.inputs = [torch.randn(5, 1, 28, 28, device='cpu'),
torch.randn(32, 3, 224, 224, device='cpu'),
torch.randn(32, 3, 224, 224, device='cpu')]
self.previous_jit_autocast_pass = torch._C._jit_set_autocast_mode(False)
def tearDown(self):
torch._C._jit_set_autocast_mode(self.previous_jit_autocast_pass)
torch.set_default_dtype(self.previous_default_dtype)
super(TestJitTraceAutocast, self).tearDown()
def test_generate_autocast_jit_trace_model(self):
def test_generate_autocast_jit_trace_model(model, x):
model.eval()
with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
traced_model = torch.jit.trace(model, x)
traced_model = torch.jit.freeze(traced_model)
        for i in range(len(self.models)):
test_generate_autocast_jit_trace_model(self.models[i], self.inputs[i])
def test_nchw_autocast_jit_trace_model(self):
def test_nchw_autocast_jit_trace_model(model, x):
model.eval()
with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
traced_model = torch.jit.trace(model, x)
traced_model = torch.jit.freeze(traced_model)
with torch.no_grad():
y = traced_model(x.clone())
with torch.cpu.amp.autocast(), torch.no_grad():
y2 = model(x.clone())
torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
        for i in range(len(self.models)):
test_nchw_autocast_jit_trace_model(self.models[i], self.inputs[i])
def test_nhwc_autocast_jit_trace_model(self):
def test_nhwc_autocast_jit_trace_model(model, x):
model = model.to(memory_format=torch.channels_last)
model.eval()
with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
traced_model = torch.jit.freeze(traced_model)
with torch.no_grad():
y = traced_model(x.clone().to(memory_format=torch.channels_last))
with torch.cpu.amp.autocast(), torch.no_grad():
y2 = model(x.clone().to(memory_format=torch.channels_last))
torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
        for i in range(len(self.models)):
            if self.inputs[i].dim() == 5:
                # the NHWC 3D case is not supported yet
continue
test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])
def test_cat_promote(self):
class TestModel(torch.nn.Module):
def __init__(self):
super(TestModel, self).__init__()
def forward(self, a, b):
return torch.cat([a, b], 0)
with torch.jit.fuser("none"):
            # This test checks that cat promotes mixed-dtype inputs under AMP.
            # The fuser is disabled to keep TE fusion groups out of the graph.
for jit_freeze_or_not in [False, True]:
test_model = TestModel().eval()
with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
a = torch.rand(24, 128, 128)
b = torch.rand(24, 128, 128, dtype=torch.bfloat16)
c = test_model(a, b)
traced = torch.jit.trace(test_model, (a, b))
if jit_freeze_or_not:
traced = torch.jit.freeze(traced)
for _ in range(3):
c2 = traced(a, b)
                self.assertEqual(c.dtype, torch.float32)
                self.assertEqual(c2.dtype, torch.float32)
traced_graph = traced.graph_for(a, b)
self.assertTrue(any(n.kind() == "aten::to" for n in traced_graph.nodes()))
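    # torch.is_autocast_cpu_enabled() is scriptable and reflects the eager autocast
    # state at run time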
def test_script_autocast_cpu(self):
def fn(x):
if torch.is_autocast_cpu_enabled():
return x.relu()
else:
return x.sin()
fn_s = torch.jit.script(fn)
x = torch.rand((4, 4)) - 0.5
with torch.cpu.amp.autocast():
self.assertEqual(fn_s(x), fn(x))
with torch.cpu.amp.autocast(enabled=True):
self.assertEqual(fn_s(x), fn(x))
self.assertTrue(any(["is_autocast_cpu_enabled" in x.kind() for x in fn_s.graph.nodes()]))
@unittest.skipIf(not TEST_CUDA, "No cuda")
def test_script_autocast_cuda(self):
def fn(x):
if torch.is_autocast_enabled():
return x.relu()
else:
return x.sin()
fn_s = torch.jit.script(fn)
x = torch.rand((4, 4)) - 0.5
with torch.cpu.amp.autocast():
self.assertEqual(fn_s(x), fn(x))
with torch.cuda.amp.autocast(enabled=True):
self.assertEqual(fn_s(x), fn(x))
self.assertTrue(any(["is_autocast_enabled" in x.kind() for x in fn_s.graph.nodes()]))
def test_scripted_aliasing(self):
# torch.is_autocast_enabled should not be able to move inside of the autocast context.
def fn(x):
if torch.is_autocast_enabled():
y = True
else:
y = False
with torch.cuda.amp.autocast(enabled=True):
z = x.relu()
return y, z
fn_s = torch.jit.script(fn)
graph = fn_s.graph
aliasdb = graph.alias_db()
is_enabled_nodes = graph.findAllNodes("aten::is_autocast_enabled")
enter_nodes = graph.findAllNodes("prim::Enter")
self.assertEqual(len(is_enabled_nodes), 1)
self.assertEqual(len(enter_nodes), 1)
self.assertFalse(aliasdb.move_after_topologically_valid(is_enabled_nodes[0], enter_nodes[0]))
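    # eager and scripted execution should agree on both the is_autocast_cpu_enabled()
    # flags and the resulting mm dtypes at every nesting level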
def test_script_autocast_enable_and_check(self):
def fn(x, y) -> Tuple[torch.Tensor, bool, torch.Tensor, bool, torch.Tensor, bool]:
b1 = torch.is_autocast_cpu_enabled()
v1 = torch.mm(x, y)
with torch.cpu.amp.autocast(enabled=True):
b2 = torch.is_autocast_cpu_enabled()
v2 = torch.mm(x, y)
with torch.cpu.amp.autocast(enabled=False):
b3 = torch.is_autocast_cpu_enabled()
v3 = torch.mm(x, y)
return (v1, b1, v2, b2, v3, b3)
# bx = is_autocast_cpu_enabled() result should be False iff (vx = mm(x, y)).dtype is float
def check_fn_results(arr):
[v1, b1, v2, b2, v3, b3] = arr
self.assertTrue((v1.dtype == torch.float) != b1)
self.assertTrue((v2.dtype == torch.float) != b2)
self.assertTrue((v3.dtype == torch.float) != b3)
x = torch.rand((2, 2), dtype=torch.float)
y = torch.rand((2, 2), dtype=torch.float)
fn_s = torch.jit.script(fn)
with torch.cpu.amp.autocast(enabled=False):
check_fn_results(fn(x, y))
check_fn_results(fn_s(x, y))
with torch.cpu.amp.autocast(enabled=True):
check_fn_results(fn(x, y))
check_fn_results(fn_s(x, y))
if __name__ == "__main__":
run_tests()