autodiff fix for autocast_to_xxx (#67648)
Summary: Fixes an autocast + autodiff issue where scripted functions failed with `RuntimeError: grad_inputs.size() == node->inputs().size()INTERNAL ASSERT FAILED at "../torch/csrc/jit/runtime/autodiff.cpp":426, please report a bug to PyTorch.`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67648
Reviewed By: cpuhrsch
Differential Revision: D32083227
Pulled By: davidberard98
fbshipit-source-id: edf526cff4ec21874ae35ec730d13c250073e10c
This commit is contained in:
parent
9269080b47
commit
ee7412dd29
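Background (a reading of the diff below, not text from the original PR): in TorchScript's symbolic gradients, the backward closure must return one value per forward input, and the autodiff pass asserts exactly that (`grad_inputs.size() == node->inputs().size()`). The autocast cast helpers take extra non-tensor arguments (enable flags and target dtypes), so a backward that returns only the cast gradient trips the assert once a differentiable graph is built. Eager-mode autograd has the same one-output-per-input contract; a minimal sketch with a custom `torch.autograd.Function` (illustrative names, not part of the patch):

import torch

class CastToFullPrecision(torch.autograd.Function):
    # Illustrative stand-in that mirrors the contract fixed below.
    @staticmethod
    def forward(ctx, self_tensor, cuda_enabled, cpu_enabled):
        ctx.self_dtype = self_tensor.dtype
        return self_tensor.float()

    @staticmethod
    def backward(ctx, grad_output):
        # Three forward inputs -> three backward outputs; the non-tensor
        # flags get None, which is the shape of the ", None, None" change below.
        return grad_output.to(ctx.self_dtype), None, None

x = torch.randn(4, 4, dtype=torch.float16, requires_grad=True)
y = CastToFullPrecision.apply(x, True, False)
y.sum().backward()
assert x.grad.dtype == torch.float16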
@@ -590,5 +590,38 @@ class TestAutocast(JitTestCase):
        # no cast op should be observed when executing outside autocast context
        self._test_autocast(t, None, cpu0, cpu1, cuda0, cuda1)

    @unittest.skipIf(not TEST_CUDA, "No cuda")
    def test_autocast_autodiff(self):
        def t(t0, t1):
            o = torch.mm(t0, t1)
            return o.relu()

        jit_t = torch.jit.script(t)
        t0 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()
        t1 = torch.randn(5, 5, device="cuda", dtype=torch.float32).requires_grad_()

        # run optimization
        for i in range(5):
            with torch.autocast("cuda", torch.float16):
                jit_o = jit_t(t0, t1)
            jit_o.sum().backward()

        t0.grad = None
        t1.grad = None
        ref_t0 = t0.detach().requires_grad_()
        ref_t1 = t1.detach().requires_grad_()

        with torch.autocast("cuda", torch.float16):
            o = t(ref_t0, ref_t1)
            jit_o = jit_t(t0, t1)
        jit_o.sum().backward()
        o.sum().backward()
        self.assertEqual(o, jit_o)
        self.assertEqual(t0.grad, ref_t0.grad)
        self.assertEqual(t1.grad, ref_t1.grad)
        self.assertEqual(o.dtype, jit_o.dtype)
        self.assertEqual(t0.grad.dtype, ref_t0.grad.dtype)
        self.assertEqual(t1.grad.dtype, ref_t1.grad.dtype)

if __name__ == '__main__':
    run_tests()
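Not part of the patch, but a hedged way to confirm the code path the test exercises: after the warm-up loop, the profiling executor should have built a differentiable graph containing the autocast casts, which is where the assert used to fire. The introspection call below exists in recent PyTorch builds; its output format varies:

import torch

# Dump the most recently executed optimized graph; with the profiling executor
# it typically shows a prim::DifferentiableGraph wrapping the autocast casts
# and the aten::mm from the test above.
print(torch.jit.last_executed_optimized_graph())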
@@ -482,7 +482,7 @@ const std::vector<std::string> functions = {
        def _autocast_to_full_precision(self, cuda_enabled : bool, cpu_enabled : bool):
            self_dtype = self.dtype
            def backward(grad_output):
-                return grad_output.to(self_dtype)
+                return grad_output.to(self_dtype), None, None

            return torch._autocast_to_full_precision(self, cuda_enabled, cpu_enabled), backward
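The unchanged first tuple element shows the actual gradient rule: the gradient flowing back through the cast is converted back to the input's original dtype. Eager mode behaves the same way, which is a quick sanity check of the intent (plain `.float()` used here as a stand-in for the internal `_autocast_to_full_precision` op; not part of the patch):

import torch

# A float16 leaf cast up to float32; the backward of the cast converts the
# incoming gradient back to float16, matching grad_output.to(self_dtype).
x = torch.randn(3, 3, dtype=torch.float16, requires_grad=True)
y = x.float()
y.sum().backward()
assert x.grad.dtype == torch.float16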
@@ -493,7 +493,7 @@ const std::vector<std::string> functions = {
                                            cpu_dtype : int):
            self_dtype = self.dtype
            def backward(grad_output):
-                return grad_output.to(self_dtype)
+                return grad_output.to(self_dtype), None, None, None, None

            return torch._autocast_to_reduced_precision(self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype), backward
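`_autocast_to_reduced_precision` takes five inputs (the tensor, two enable flags, and two target dtypes), hence the five return values. Under autocast this is the cast placed in front of ops such as `torch.mm`; the observable eager-mode behavior that the new test's dtype assertions also pin down looks roughly like this (a sketch, CUDA assumed; not part of the patch):

import torch

# float32 leaves; autocast runs mm in float16, so the output is float16,
# while backward casts gradients back so the leaves keep float32 grads.
a = torch.randn(4, 4, device="cuda", requires_grad=True)
b = torch.randn(4, 4, device="cuda", requires_grad=True)
with torch.autocast("cuda", torch.float16):
    out = torch.mm(a, b)
out.sum().backward()
assert out.dtype == torch.float16
assert a.grad.dtype == torch.float32 and b.grad.dtype == torch.float32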