_Redo of #86586 with all BC breaking changes granularly placed into separate commits._
---
Per title. Deprecation happened on Feb 25, 2022 in c6f1bbc0ac, which made it into the 1.12 release. Since it is now 245 days later and the next release will be 1.14, the removals later in the stack comply with the [BC policy](https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#minimizing-the-disruption-of-bc-breaking-changes).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87969
Approved by: https://github.com/mruberry
# Owner(s): ["oncall: jit"]

import torch
from torch.cuda.amp import autocast
from typing import Optional, Tuple

import unittest
from test_jit import JitTestCase
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests
from torch.testing import FileCheck
from jit.test_models import MnistNet

TEST_BFLOAT16 = TEST_CUDA and torch.cuda.is_bf16_supported()


class TestAutocast(JitTestCase):
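    # The TorchScript autocast pass is enabled explicitly in setUp and the
    # previous setting is restored in tearDown, so these tests do not depend
    # on the process-wide default for the pass.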
    def setUp(self):
        # common input tensors
        if TEST_CUDA:
            self.a_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.b_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.c_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.d_fp16 = torch.rand((2, 2), dtype=torch.float16, device='cuda')
            self.a_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.b_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.c_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
            self.d_fp32 = torch.rand((2, 2), dtype=torch.float32, device='cuda')
        self.old_value = torch._C._jit_set_autocast_mode(True)
        super().setUp()

    def tearDown(self):
        torch._C._jit_set_autocast_mode(self.old_value)
        super().tearDown()

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                x = torch.mm(a, b)
                y = torch.sum(x)
            return x, y
        x, y = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(x.dtype, torch.float16)
        self.assertEqual(y.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA or not TEST_BFLOAT16, "No cuda bfloat16 support")
    def test_linear_bf16(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(dtype=torch.bfloat16):
                x = torch.mm(a, b)
                y = torch.sum(x)
            return x, y
        x, y = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(x.dtype, torch.bfloat16)
        self.assertEqual(y.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal_cpu(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                return torch.mm(a, b)
        result = fn(self.a_fp32.to('cpu'), self.b_fp32.to('cpu'))
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_minimal_off(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_runtime_autocast_state(self):
        @torch.jit.script
        def fn(a, b, use_amp: bool):
            with autocast(enabled=use_amp):
                return torch.mm(a, b)
        # runtime values for autocast enable argument are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32, True)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_runtime_autocast_state_expr(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True if a[0][0] > 0.5 else False):
                return torch.mm(a, b)
        # runtime values for autocast enable argument are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_explicit_casts(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                e = torch.mm(a.double(), b.double()).float()
                f = torch.mm(c, d).double()
                g = torch.mm(c.double(), f)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float32)
        self.assertEqual(f.dtype, torch.float64)
        self.assertEqual(g.dtype, torch.float64)

    # multiple uses of the same input value
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_duplicate_inputs(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                e = torch.mm(a, a)
                f = torch.mm(e, e)
            return e, f
        e, f = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_policy(self):
        @torch.jit.script
        def fn(a):
            with autocast(enabled=True):
                return torch.log(a)
        result = fn(self.a_fp16)
        self.assertEqual(result.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_policy_with_fp64(self):
        @torch.jit.script
        def fn(a):
            with autocast(enabled=True):
                return torch.log(a)
        # fp32 policy should not narrow fp64 to fp32!
        result = fn(self.a_fp32.double())
        self.assertEqual(result.dtype, torch.float64)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_promote_policy(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                e = torch.mm(a, b)
                f = torch.addcmul(e, c, d, value=0.1)
            return e, f
        e, f = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_promote_policy_fp64(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return torch.addcmul(a, a, b, value=0.1)
        result = fn(self.a_fp32.double(), self.b_fp32.double())
        self.assertEqual(result.dtype, torch.float64)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_set_opt_dtype_policy(self):
        @torch.jit.script
        def fn(a, b, c, d, dtype: Optional[int]):
            with autocast(enabled=True):
                x = torch.softmax(a, 0)
                y = torch.softmax(b, 0, None)
                z = torch.softmax(c, 0, torch.float64)
                w = torch.softmax(d, 0, dtype)
            return x, y, z, w
        x, y, z, w = fn(self.a_fp16, self.b_fp16, self.c_fp16, self.d_fp16, None)
        self.assertEqual(x.dtype, torch.float32)
        self.assertEqual(y.dtype, torch.float32)
        self.assertEqual(z.dtype, torch.float64)
        self.assertEqual(w.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_fp32_set_opt_dtype_policy_fp64(self):
        @torch.jit.script
        def fn(a, b, c, d, dtype: Optional[int]):
            with autocast(enabled=True):
                x = torch.softmax(a, 0)
                y = torch.softmax(b, 0, None)
                z = torch.softmax(c, 0, torch.float64)
                w = torch.softmax(d, 0, dtype)
            return x, y, z, w
        x, y, z, w = fn(self.a_fp32.double(), self.b_fp32.double(), self.c_fp32.double(), self.d_fp32.double(), None)
        self.assertEqual(x.dtype, torch.float64)
        self.assertEqual(y.dtype, torch.float64)
        self.assertEqual(z.dtype, torch.float64)
        self.assertEqual(w.dtype, torch.float64)

    @unittest.skipIf(True, "broken due to lack of type propagation")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_control_flow(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                if a[0][0] > 0.5:
                    e = torch.mm(a, b)
                    x = 1
                else:
                    e = torch.mm(c, d)
                    x = 2
                f = torch.mm(d, e) * x
            return e, f
        e, f = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)

    # this works fine in regular Python, but it creates a delicate
    # situation in TorchScript where the types are not consistent across
    # the then/else branches
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_divergent_types(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast():
                if a[0][0] > 0.5:
                    e = torch.mm(a, b)
                    f = torch.mm(a, b).float()
                else:
                    e = torch.mm(c, d).float()
                    f = torch.mm(a, b)
            return torch.mm(e.float(), f.float())
        result = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # another, more complex case of divergent types
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_divergent_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            autocast_on = autocast(enabled=True)
            autocast_off = autocast(enabled=False)
            if a[0][0] > 0.5:
                with autocast_on:
                    e = torch.mm(a, b)
            else:
                with autocast_off:
                    e = torch.mm(c, d)
            return torch.mm(e, e)
        fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_conditional_autocast(self):
        @torch.jit.script
        def fn(a, b):
            autocast_on = autocast(enabled=True)
            autocast_off = autocast(enabled=False)
            with autocast_on if a[0][0] > 0.5 else autocast_off:
                return torch.mm(a, b)
        # conditional autocast expressions are not supported
        with self.assertRaises(RuntimeError):
            fn(self.a_fp32, self.b_fp32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_nested_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast(enabled=False):
                e = torch.mm(a, b)
                with autocast(enabled=True):
                    f = torch.mm(e, c)
                    with autocast(enabled=False):
                        g = torch.mm(e, d)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float32)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float32)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_implicitly_nested_autocast(self):
        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False), autocast(enabled=True):
                return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_reused_autocast(self):
        @torch.jit.script
        def fn(a, b, c, d):
            autocast_instance = autocast(enabled=True)
            with autocast_instance:
                e = torch.mm(a, b)
                with autocast_instance:
                    e = torch.mm(c, d)
                f = torch.mm(d, e)
            g = torch.mm(e, f)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float16)

    # TODO: fix and enable this test?
    # (we could technically fix this, but is it really worth it?)
    @unittest.skipIf(True, "unsupported autocast syntax")
    def test_reused_autocast_expr(self):
        @torch.jit.script
        def fn(a, b, c, d):
            with autocast(enabled=True) as autocast_instance:
                e = torch.mm(a, b)
                with autocast_instance:
                    e = torch.mm(c, d)
                f = torch.mm(d, e)
            g = torch.mm(e, f)
            return e, f, g
        e, f, g = fn(self.a_fp32, self.b_fp32, self.c_fp32, self.d_fp32)
        self.assertEqual(e.dtype, torch.float16)
        self.assertEqual(f.dtype, torch.float16)
        self.assertEqual(g.dtype, torch.float16)

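    # The next three tests check that autocasting applies through (inlined)
    # helper calls: the innermost autocast context in effect around the
    # torch.mm call decides whether the result is reduced precision.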
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees(self):
        def helper(a, b):
            return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                tmp = helper(a, b)
                tmp = helper(tmp, tmp)
                tmp = helper(tmp, tmp)
                tmp = helper(tmp, tmp)
                return helper(tmp, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees_with_autocast_on(self):
        def helper(a, b):
            with autocast(enabled=True):
                return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=False):
                return helper(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_callees_with_autocast_off(self):
        def helper(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b)

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return helper(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # scripting inside eager autocast
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_eager_and_script(self):
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)
        for i in range(8):
            use_autocast = (i % 2 == 0)
            expected_dtype = torch.float16 if use_autocast else torch.float32
            with autocast(enabled=use_autocast):
                result = fn(self.a_fp32, self.b_fp32)
            self.assertEqual(result.dtype, expected_dtype)

    # traced inside scripting
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_and_tracing(self):
        def helper(a, b):
            return torch.mm(a, b)

        traced = torch.jit.trace(helper, (self.a_fp32, self.a_fp32))

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return traced(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # traced with autocast inside scripting
    @unittest.skipIf(True, "autocast(False) is ignored inside traced functions")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_and_tracing_with_autocast(self):
        def helper(a, b):
            with autocast(enabled=False):
                return torch.mm(a, b) * 2.0

        traced = torch.jit.trace(helper, (self.a_fp32, self.a_fp32))

        @torch.jit.script
        def fn(a, b):
            with autocast(enabled=True):
                return traced(a, b)

        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float32)

    # scripted called from traced
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_tracing_and_script(self):
        @torch.jit.script
        def fn(a, b):
            with autocast():
                return torch.mm(a, b)

        def traced(a, b):
            return fn(a, b)

        traced = torch.jit.trace(traced, (self.a_fp32, self.b_fp32))
        result = traced(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # scripted called from traced with autocast
    @unittest.skipIf(True, "scripted called from traced TorchScript is not yet working")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_tracing_with_autocast_and_script(self):
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)

        def traced(a, b):
            with autocast(enabled=True):
                return fn(a, b)

        traced = torch.jit.trace(traced, (self.a_fp32, self.b_fp32))
        result = traced(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_module(self):
        class TestModule(torch.nn.Module):
            def __init__(self, N, M):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand((N, M), dtype=torch.float32))
                self.linear = torch.nn.Linear(N, M).float()

            def forward(self, input):
                with autocast(enabled=True):
                    output = self.weight.mv(input)
                    output = self.linear(output)
                return output

        scripted_module = torch.jit.script(TestModule(2, 3)).cuda()
        input = torch.rand(3, dtype=torch.float32, device='cuda')
        result = scripted_module(input)
        self.assertEqual(result.dtype, torch.float16)

    @unittest.skipIf(True, "autocast decorators not supported")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_decorator(self):
        @torch.jit.script
        @autocast(enabled=True)
        def fn(a, b):
            return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

    # this is equivalent to running scripted functions inside autocast
    # (see also test_eager_and_script)
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_decorator_outside_jit(self):
        @autocast(enabled=True)
        @torch.jit.script
        def fn(a, b):
            return torch.mm(a, b)
        result = fn(self.a_fp32, self.b_fp32)
        self.assertEqual(result.dtype, torch.float16)

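    # Only the functional addmm call below is autocast to fp16; the out= and
    # in-place variants keep the dtype of their fp32 destination tensor.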
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_inplace(self):
        @torch.jit.script
        def fn(a, b, c):
            with autocast(enabled=True):
                x = torch.addmm(a, b, c)
                y = torch.addmm(a, b, c, out=a)
                z = a.addmm_(b, c)
            return x, y, z
        x, y, z = fn(self.a_fp32, self.b_fp32, self.c_fp32)
        self.assertEqual(x.dtype, torch.float16)
        self.assertEqual(y.dtype, torch.float32)
        self.assertEqual(z.dtype, torch.float32)

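    # Shared helper: scripts `func`, runs both the eager and the scripted
    # version on the same inputs, optionally checks that `cast_op` appears in
    # the optimized graph, and compares the output dtypes pairwise.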
    def _test_autocast(self, func, cast_op, *args):
        jit_func = torch.jit.script(func)
        o = func(*args)
        jit_o = jit_func(*args)
        if cast_op is not None:
            FileCheck().check(cast_op).run(jit_func.graph_for(*args))
        for o0, o1 in zip(o, jit_o):
            self.assertEqual(o0.dtype, o1.dtype)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_api(self):

        def t_autocast_cpu(x, y):
            with torch.autocast("cpu", dtype=torch.bfloat16):
                return torch.mm(x, y)

        def t_autocast_cuda(x, y):
            with torch.autocast("cuda", dtype=torch.half):
                return torch.mm(x, y)

        def t_cuda_amp_autocast(x, y):
            with torch.cuda.amp.autocast():
                return torch.mm(x, y)

        def t_cpu_amp_autocast(x, y):
            with torch.cpu.amp.autocast():
                return torch.mm(x, y)

        x = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        y = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t_autocast_cpu, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_autocast_cuda, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_cuda_amp_autocast, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_cpu_amp_autocast, "aten::_autocast_to_reduced_precision", x, y)

    @unittest.skipIf(True, "the dtype argument must currently be provided")
    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_api_not_supported(self):

        def t_autocast_cpu(x, y):
            # no dtype provided is not currently supported
            with torch.autocast("cpu"):
                return torch.mm(x, y)

        def t_autocast_cuda(x, y):
            # no dtype provided is not currently supported
            with torch.autocast("cuda"):
                return torch.mm(x, y)

        x = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        y = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t_autocast_cpu, "aten::_autocast_to_reduced_precision", x, y)
        self._test_autocast(t_autocast_cuda, "aten::_autocast_to_reduced_precision", x, y)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_mixed_dtypes(self):

        def t(cpu0, cpu1, cuda0, cuda1):
            with torch.autocast("cpu", torch.bfloat16):
                with torch.autocast("cuda", torch.float16):
                    cpu_o = torch.mm(cpu0, cpu1)
                    cuda_o = torch.mm(cuda0, cuda1)
                    return cpu_o, cuda_o

        jit_t = torch.jit.script(t)
        cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_executor_under_autocast(self):

        def t(cpu0, cpu1, cuda0, cuda1):
            cpu_o = torch.mm(cpu0, cpu1)
            cuda_o = torch.mm(cuda0, cuda1)
            return cpu_o, cuda_o

        jit_t = torch.jit.script(t)
        cpu0 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cpu1 = torch.randn(5, 5, device="cpu", dtype=torch.float32)
        cuda0 = torch.randn(5, 5, device="cuda", dtype=torch.float32)
        cuda1 = torch.randn(5, 5, device="cuda", dtype=torch.float32)

        with torch.autocast("cpu", torch.bfloat16):
            with torch.autocast("cuda", torch.float16):
                self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        with torch.autocast("cpu", torch.bfloat16):
            self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        with torch.autocast("cuda", torch.float16):
            self._test_autocast(t, "aten::_autocast_to_reduced_precision", cpu0, cpu1, cuda0, cuda1)

        # no cast op should be observed when executing outside autocast context
        self._test_autocast(t, None, cpu0, cpu1, cuda0, cuda1)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_autodiff(self):
        def t(t0, t1):
            o = torch.mm(t0, t1)
            return o.relu()

        jit_t = torch.jit.script(t)
        t0 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()
        t1 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()

        # run optimization
        for i in range(5):
            with torch.autocast("cuda", torch.float16):
                jit_o = jit_t(t0, t1)
            jit_o.sum().backward()

        t0.grad = None
        t1.grad = None
        ref_t0 = t0.detach().requires_grad_()
        ref_t1 = t1.detach().requires_grad_()

        with torch.autocast("cuda", torch.float16):
            o = t(ref_t0, ref_t1)
            jit_o = jit_t(t0, t1)
        jit_o.sum().backward()
        o.sum().backward()
        self.assertEqual(o, jit_o)
        self.assertEqual(t0.grad, ref_t0.grad)
        self.assertEqual(t1.grad, ref_t1.grad)
        self.assertEqual(o.dtype, jit_o.dtype)
        self.assertEqual(t0.grad.dtype, ref_t0.grad.dtype)
        self.assertEqual(t1.grad.dtype, ref_t1.grad.dtype)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_call_method_under_autocast(self):
        @torch.jit.interface
        class Iface(torch.nn.Module):
            def forward(self, x, y) -> torch.Tensor:
                pass

        class Impl(Iface):
            def forward(self, x, y):
                return torch.mm(x, y)

        class Thing1(torch.nn.Module):
            impl: Iface

            def forward(self, x, y):
                with torch.cuda.amp.autocast():
                    a = torch.mm(x, y)
                    b = self.impl.forward(a, x)
                return b

        scripted_impl = torch.jit.script(Impl())
        thing1 = Thing1()
        thing1.impl = scripted_impl
        scripted_thing1 = torch.jit.script(thing1)
        x = torch.rand([2, 2])
        y = torch.rand([2, 2])

        # make sure this doesn't throw an error
        with torch.cuda.amp.autocast():
            ans = scripted_thing1.forward(x, y)
        self.assertEqual(torch.mm(torch.mm(x, y), x), ans)

        # sanity check: this isn't supported currently when global autocasting
        # isn't enabled
        self.assertRaises(RuntimeError, lambda: scripted_thing1.forward(x, y))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_freeze_autocast_basic(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()

            def forward(self, x, y):
                with torch.cuda.amp.autocast():
                    return torch.mm(x, y)

        x = torch.rand((3, 4), dtype=torch.float).cuda()
        y = torch.rand((4, 5), dtype=torch.float).cuda()

        mod = TestModule().eval()

        # sanity check
        self._test_autocast(mod, "aten::_autocast_to_reduced_precision", x, y)

        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(frozen_mod.graph)

        # make sure that the runtime pass doesn't duplicate autocast nodes
        frozen_mod(x, y)
        optimized_graph = frozen_mod.graph_for(x, y)
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(optimized_graph)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_freeze_autocast_constants(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                self.x = torch.rand((3, 4), dtype=torch.float).cuda()

            def forward(self, y):
                with torch.cuda.amp.autocast():
                    return torch.mm(self.x, y)

        y = torch.rand((4, 5), dtype=torch.float).cuda()
        mod = TestModule().eval()

        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
        # freezing should pre-cast the constant self.x to remove one autocast call
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(frozen_mod.graph)

        # the runtime autocasting pass will re-insert the second autocast call,
        # but constant propagation will merge it with the constant that it's casting.
        frozen_mod(y)
        optimized_graph = frozen_mod.graph_for(y)
        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(optimized_graph)

    @unittest.skipIf(TEST_CUDA, "CPU-only test")
    def test_jit_autocast_softmax_cpu(self):
        def fn(x):
            with torch.cpu.amp.autocast():
                return torch.nn.functional.softmax(x, dim=0)

        fn_s = torch.jit.script(fn)
        x = torch.rand((2, 2), dtype=torch.bfloat16)
        fn_s(x)
        y = fn_s(x)

        self.assertTrue(y.dtype == torch.bfloat16)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_jit_autocast_softmax_gpu(self):
        def fn(x):
            with torch.cuda.amp.autocast():
                return torch.nn.functional.softmax(x, dim=0)

        fn_s = torch.jit.script(fn)
        x = torch.rand((2, 2), dtype=torch.half).cuda()
        fn_s(x)
        y = fn_s(x)

        self.assertTrue(y.dtype == torch.float)

    def test_ignore_amp(self):
        @torch.jit.script
        def foo(x):
            return torch.mm(x, x)

        inp = torch.rand([10, 10], dtype=torch.float)
        foo._set_ignore_amp(True)
        with torch.cpu.amp.autocast():
            foo(inp)
            foo(inp)

        g = torch.jit.last_executed_optimized_graph()
        FileCheck().check_not("_autocast_to_reduced").run(g)


class convbn(torch.nn.Module):
    def __init__(self, bias_enabled=True):
        super(convbn, self).__init__()
        self.conv = torch.nn.Conv2d(3, 64, 7, stride=2, bias=bias_enabled)
        self.bn = torch.nn.BatchNorm2d(64)

    def forward(self, x):
        return self.bn(self.conv(x))

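# These tests exercise torch.jit.trace under eager-mode autocast; setUp
# disables the TorchScript autocast pass, so any casts in the traced graphs
# come from the eager autocast dispatch recorded at trace time.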
class TestJitTraceAutocast(JitTestCase):
    def setUp(self):
        super(TestJitTraceAutocast, self).setUp()
        self.previous_default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.float32)
        self.models = [MnistNet(),
                       convbn(bias_enabled=True),
                       convbn(bias_enabled=False)]
        self.inputs = [torch.randn(5, 1, 28, 28, device='cpu'),
                       torch.randn(32, 3, 224, 224, device='cpu'),
                       torch.randn(32, 3, 224, 224, device='cpu')]
        self.previous_jit_autocast_pass = torch._C._jit_set_autocast_mode(False)

    def tearDown(self):
        torch._C._jit_set_autocast_mode(self.previous_jit_autocast_pass)
        torch.set_default_dtype(self.previous_default_dtype)
        super(TestJitTraceAutocast, self).tearDown()

    def test_generate_autocast_jit_trace_model(self):
        def test_generate_autocast_jit_trace_model(model, x):
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x)
            traced_model = torch.jit.freeze(traced_model)
        for i in range(len(self.models)):
            test_generate_autocast_jit_trace_model(self.models[i], self.inputs[i])

    def test_nchw_autocast_jit_trace_model(self):
        def test_nchw_autocast_jit_trace_model(model, x):
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x)
            traced_model = torch.jit.freeze(traced_model)
            with torch.no_grad():
                y = traced_model(x.clone())
            with torch.cpu.amp.autocast(), torch.no_grad():
                y2 = model(x.clone())
            torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
        for i in range(len(self.models)):
            test_nchw_autocast_jit_trace_model(self.models[i], self.inputs[i])

    def test_nhwc_autocast_jit_trace_model(self):
        def test_nhwc_autocast_jit_trace_model(model, x):
            model = model.to(memory_format=torch.channels_last)
            model.eval()
            with torch.cpu.amp.autocast(cache_enabled=False), torch.no_grad():
                traced_model = torch.jit.trace(model, x.to(memory_format=torch.channels_last))
            traced_model = torch.jit.freeze(traced_model)
            with torch.no_grad():
                y = traced_model(x.clone().to(memory_format=torch.channels_last))
            with torch.cpu.amp.autocast(), torch.no_grad():
                y2 = model(x.clone().to(memory_format=torch.channels_last))
            torch.testing.assert_close(y.double(), y2.double(), rtol=1e-03, atol=1e-03)
        for i in range(len(self.models)):
            if len(self.inputs[i].size()) == 5:
                # NHWC 3D case is not supported yet
                continue
            test_nhwc_autocast_jit_trace_model(self.models[i], self.inputs[i])

    def test_cat_promote(self):
        class TestModel(torch.nn.Module):
            def __init__(self):
                super(TestModel, self).__init__()

            def forward(self, a, b):
                return torch.cat([a, b], 0)

        with torch.jit.fuser("none"):
            # This test checks that cat performs dtype promotion under AMP when
            # its inputs have mixed dtypes. The fuser is set to "none" so that
            # TE fusion groups don't hide the promotion.
            for jit_freeze_or_not in [False, True]:
                test_model = TestModel().eval()
                with torch.cpu.amp.autocast(cache_enabled=False, dtype=torch.bfloat16), torch.no_grad():
                    a = torch.rand(24, 128, 128)
                    b = torch.rand(24, 128, 128, dtype=torch.bfloat16)
                    c = test_model(a, b)
                    traced = torch.jit.trace(test_model, (a, b))
                if jit_freeze_or_not:
                    traced = torch.jit.freeze(traced)
                for _ in range(3):
                    c2 = traced(a, b)
                self.assertEqual(c.dtype, torch.float32)
                self.assertEqual(c2.dtype, torch.float32)
                traced_graph = traced.graph_for(a, b)
                self.assertTrue(any(n.kind() == "aten::to" for n in traced_graph.nodes()))

    def test_script_autocast_cpu(self):
        def fn(x):
            if torch.is_autocast_cpu_enabled():
                return x.relu()
            else:
                return x.sin()

        fn_s = torch.jit.script(fn)

        x = torch.rand((4, 4)) - 0.5
        with torch.cpu.amp.autocast():
            self.assertEqual(fn_s(x), fn(x))

        with torch.cpu.amp.autocast(enabled=True):
            self.assertEqual(fn_s(x), fn(x))

        self.assertTrue(any(["is_autocast_cpu_enabled" in x.kind() for x in fn_s.graph.nodes()]))

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_script_autocast_cuda(self):
        def fn(x):
            if torch.is_autocast_enabled():
                return x.relu()
            else:
                return x.sin()

        fn_s = torch.jit.script(fn)

        x = torch.rand((4, 4)) - 0.5
        with torch.cpu.amp.autocast():
            self.assertEqual(fn_s(x), fn(x))

        with torch.cuda.amp.autocast(enabled=True):
            self.assertEqual(fn_s(x), fn(x))

        self.assertTrue(any(["is_autocast_enabled" in x.kind() for x in fn_s.graph.nodes()]))

    def test_scripted_aliasing(self):
        # torch.is_autocast_enabled should not be able to move inside of the autocast context.
        def fn(x):
            if torch.is_autocast_enabled():
                y = True
            else:
                y = False
            with torch.cuda.amp.autocast(enabled=True):
                z = x.relu()
            return y, z

        fn_s = torch.jit.script(fn)
        graph = fn_s.graph

        aliasdb = graph.alias_db()

        is_enabled_nodes = graph.findAllNodes("aten::is_autocast_enabled")
        enter_nodes = graph.findAllNodes("prim::Enter")

        self.assertEqual(len(is_enabled_nodes), 1)
        self.assertEqual(len(enter_nodes), 1)

        self.assertFalse(aliasdb.move_after_topologically_valid(is_enabled_nodes[0], enter_nodes[0]))

    def test_script_autocast_enable_and_check(self):
        def fn(x, y) -> Tuple[torch.Tensor, bool, torch.Tensor, bool, torch.Tensor, bool]:
            b1 = torch.is_autocast_cpu_enabled()
            v1 = torch.mm(x, y)
            with torch.cpu.amp.autocast(enabled=True):
                b2 = torch.is_autocast_cpu_enabled()
                v2 = torch.mm(x, y)
                with torch.cpu.amp.autocast(enabled=False):
                    b3 = torch.is_autocast_cpu_enabled()
                    v3 = torch.mm(x, y)
            return (v1, b1, v2, b2, v3, b3)

        # bx = is_autocast_cpu_enabled() result should be False iff (vx = mm(x, y)).dtype is float
        def check_fn_results(arr):
            [v1, b1, v2, b2, v3, b3] = arr
            self.assertTrue((v1.dtype == torch.float) != b1)
            self.assertTrue((v2.dtype == torch.float) != b2)
            self.assertTrue((v3.dtype == torch.float) != b3)

        x = torch.rand((2, 2), dtype=torch.float)
        y = torch.rand((2, 2), dtype=torch.float)

        fn_s = torch.jit.script(fn)

        with torch.cpu.amp.autocast(enabled=False):
            check_fn_results(fn(x, y))
            check_fn_results(fn_s(x, y))

        with torch.cpu.amp.autocast(enabled=True):
            check_fn_results(fn(x, y))
            check_fn_results(fn_s(x, y))


if __name__ == "__main__":
    run_tests()