autodiff fix for autocast_to_xxx (#67648)

Summary:
Fixes an autocast + autodiff issue where backward through a scripted function run under autocast fails with `RuntimeError: grad_inputs.size() == node->inputs().size()INTERNAL ASSERT FAILED at "../torch/csrc/jit/runtime/autodiff.cpp":426, please report a bug to PyTorch.`

The symbolic-script backward formulas for `_autocast_to_full_precision` and `_autocast_to_reduced_precision` returned only the gradient for `self`, while autodiff expects one gradient (or `None`) per forward input; returning `None` for the non-tensor arguments makes the sizes match.
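
For reference, a minimal sketch of the pattern that used to trip the assert, distilled from the test added below (assumes a CUDA build; the names are illustrative, not part of the change itself):

```python
import torch

def fn(a, b):
    return torch.mm(a, b).relu()

jit_fn = torch.jit.script(fn)

a = torch.randn(5, 5, device="cuda", requires_grad=True)
b = torch.randn(5, 5, device="cuda", requires_grad=True)

# Warm up the profiling executor so a differentiable (autodiff) graph containing
# the autocast cast ops gets built, then run backward through it.
for _ in range(5):
    with torch.autocast("cuda", torch.float16):
        out = jit_fn(a, b)
    out.sum().backward()   # raised the INTERNAL ASSERT before this fix
    a.grad = None
    b.grad = None
```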

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67648

Reviewed By: cpuhrsch

Differential Revision: D32083227

Pulled By: davidberard98

fbshipit-source-id: edf526cff4ec21874ae35ec730d13c250073e10c
Author: jiej
Date: 2021-11-05 10:44:08 -07:00
Committed by: Facebook GitHub Bot
Commit: ee7412dd29 (parent 9269080b47)
2 changed files with 35 additions and 2 deletions

@@ -590,5 +590,38 @@ class TestAutocast(JitTestCase):
         # no cast op should be observed when executing outside autocast context
         self._test_autocast(t, None, cpu0, cpu1, cuda0, cuda1)
+
+    @unittest.skipIf(not TEST_CUDA, "No cuda")
+    def test_autocast_autodiff(self):
+        def t(t0, t1):
+            o = torch.mm(t0, t1)
+            return o.relu()
+
+        jit_t = torch.jit.script(t)
+        t0 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()
+        t1 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()
+
+        # run optimization
+        for i in range(5):
+            with torch.autocast("cuda", torch.float16):
+                jit_o = jit_t(t0, t1)
+            jit_o.sum().backward()
+
+        t0.grad = None
+        t1.grad = None
+        ref_t0 = t0.detach().requires_grad_()
+        ref_t1 = t1.detach().requires_grad_()
+        with torch.autocast("cuda", torch.float16):
+            o = t(ref_t0, ref_t1)
+            jit_o = jit_t(t0, t1)
+        jit_o.sum().backward()
+        o.sum().backward()
+
+        self.assertEqual(o, jit_o)
+        self.assertEqual(t0.grad, ref_t0.grad)
+        self.assertEqual(t1.grad, ref_t1.grad)
+        self.assertEqual(o.dtype, jit_o.dtype)
+        self.assertEqual(t0.grad.dtype, ref_t0.grad.dtype)
+        self.assertEqual(t1.grad.dtype, ref_t1.grad.dtype)
 
 if __name__ == '__main__':
     run_tests()

@@ -482,7 +482,7 @@ const std::vector<std::string> functions = {
         def _autocast_to_full_precision(self, cuda_enabled : bool, cpu_enabled : bool):
             self_dtype = self.dtype
             def backward(grad_output):
-                return grad_output.to(self_dtype)
+                return grad_output.to(self_dtype), None, None
 
             return torch._autocast_to_full_precision(self, cuda_enabled, cpu_enabled), backward
 
@@ -493,7 +493,7 @@ const std::vector<std::string> functions = {
                                            cpu_dtype : int):
             self_dtype = self.dtype
             def backward(grad_output):
-                return grad_output.to(self_dtype)
+                return grad_output.to(self_dtype), None, None, None, None
 
             return torch._autocast_to_reduced_precision(self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype), backward
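
The reason both backward formulas need the extra `None`s: autodiff checks that a node's backward produces exactly one gradient per forward input (that is the `grad_inputs.size() == node->inputs().size()` assert), and non-tensor flag/dtype arguments simply get `None`. A rough eager-mode analogy of that convention, sketched with `torch.autograd.Function` (this is not the TorchScript symbolic script itself, just an illustration):

```python
import torch

class CastToFullPrecision(torch.autograd.Function):
    """Illustrative stand-in for _autocast_to_full_precision(self, cuda_enabled, cpu_enabled)."""

    @staticmethod
    def forward(ctx, self_tensor, cuda_enabled, cpu_enabled):
        ctx.self_dtype = self_tensor.dtype
        return self_tensor.float()

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input: a gradient for the tensor,
        # None for each of the two bool flags -- mirroring the fix above.
        return grad_output.to(ctx.self_dtype), None, None
```

`_autocast_to_reduced_precision` takes five inputs (`self`, `cuda_enabled`, `cpu_enabled`, `cuda_dtype`, `cpu_dtype`), so its backward correspondingly returns the gradient plus four `None`s.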