diff --git a/test/test_autograd.py b/test/test_autograd.py
index 91da9e79b88..d5dc874812a 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -2509,9 +2509,7 @@ class TestAutograd(TestCase):
             root = root + torch.eye(dims[-1])
 
             gradcheck(func, [root, upper])
-            # TODO: gradgradcheck does not work correctly yet for complex
-            if not dtype.is_complex:
-                gradgradcheck(func, [root, upper])
+            gradgradcheck(func, [root, upper])
 
             root = torch.rand(*dims, dtype=dtype)
             root = torch.matmul(root, root.transpose(-1, -2).conj())
@@ -2684,9 +2682,9 @@ class TestAutograd(TestCase):
 
     @skipIfNoLapack
     def test_triangular_solve(self):
-        def _test_with_size(A_dims, B_dims):
-            A = torch.rand(*A_dims).requires_grad_()
-            b = torch.rand(*B_dims).requires_grad_()
+        def run_test(A_dims, B_dims, dtype):
+            A = torch.rand(*A_dims, dtype=dtype).requires_grad_()
+            b = torch.rand(*B_dims, dtype=dtype).requires_grad_()
 
             for upper, transpose, unitriangular in product((True, False), repeat=3):
                 def func(A, b):
@@ -2695,10 +2693,11 @@ class TestAutograd(TestCase):
                 gradcheck(func, [A, b])
                 gradgradcheck(func, [A, b])
 
-        _test_with_size((3, 3), (3, 4))
-        _test_with_size((3, 3), (3, 2))
-        _test_with_size((2, 3, 3), (2, 3, 4))
-        _test_with_size((2, 3, 3), (2, 3, 2))
+        for dtype in (torch.double, torch.cdouble):
+            run_test((3, 3), (3, 4), dtype)
+            run_test((3, 3), (3, 2), dtype)
+            run_test((2, 3, 3), (2, 3, 4), dtype)
+            run_test((2, 3, 3), (2, 3, 2), dtype)
 
     @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
     def test_fft_ifft_rfft_irfft(self):
@@ -4833,7 +4832,11 @@ complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone'
                 'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu',
                 'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_', 'eq_',
                 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh',
-                'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split'] + separate_complex_tests
+                'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split',
+                'matmul', 'bmm', 'mv', 'ger', 'diagonal', ] + separate_complex_tests
+
+# this list corresponds to cases that are not currently implemented
+skip_cuda_list = ['bmm_complex', 'matmul_4d_4d_complex']
 
 # TODO(@anjali411): add tests for 'sub', 'div
 # TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition - @anjali411
@@ -5019,6 +5022,11 @@ def add_test(
     for skip in skipTestIf:
         do_test = skip(do_test)
 
+    # TODO: remove this once tests from skip_cuda_list work
+    do_test = skipCUDAIf(
+        any(skip_test in test_name for skip_test in skip_cuda_list),
+        "not implemented for CUDA yet")(do_test)
+
     setattr(TestAutogradDeviceType, test_name, do_test)
 
 class TestAutogradComplex(TestCase):
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 9c6cd4c578d..eddc30a8604 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -275,8 +275,8 @@
   self: zeros_like(grad)
 
 - name: bmm(Tensor self, Tensor mat2) -> Tensor
-  self: grad.bmm(mat2.transpose(1, 2))
-  mat2: self.transpose(1, 2).bmm(grad)
+  self: grad.bmm(mat2.transpose(1, 2).conj())
+  mat2: self.transpose(1, 2).conj().bmm(grad)
 
 - name: _bmm(Tensor self, Tensor mat2, *, bool deterministic=False) -> Tensor
   self: at::_bmm(grad, mat2.transpose(1, 2), deterministic)
@@ -498,8 +498,8 @@
   self: not_implemented("geqrf")
 
 - name: ger(Tensor self, Tensor vec2) -> Tensor
-  self: grad.mv(vec2)
-  vec2: grad.t().mv(self)
+  self: grad.mv(vec2.conj())
+  vec2: grad.t().mv(self.conj())
 
 - name: indices(Tensor(a) self) -> Tensor(a)
   output_differentiability: [False]
@@ -749,8 +749,8 @@
   self: mul_tensor_backward(grad, at::scalar_to_tensor(other), self.scalar_type())
 
 - name: mv(Tensor self, Tensor vec) -> Tensor
-  self: grad.ger(vec)
-  vec: self.t().mv(grad)
+  self: grad.ger(vec.conj())
+  vec: self.conj().t().mv(grad)
 
 - name: mvlgamma(Tensor self, int p) -> Tensor
   self: mvlgamma_backward(grad, self, p)
diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py
index 72fac453413..3cbd585a09c 100644
--- a/tools/autograd/gen_variable_type.py
+++ b/tools/autograd/gen_variable_type.py
@@ -164,7 +164,8 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
     'cosh', '__rmul__', 'sgn', 'asin', 'acos', 'sub', 'div', 'cat', 'view_as_complex',
     'neg', 'complex', 'select', '_s_where', 'as_strided', 'slice', 'constant_pad_nd',
     'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward',
-    'dot', 'vdot', 'cholesky'
+    'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger',
+    'bmm', 'diagonal'
 }
 
 # Some operators invalidate the grad_accumulator. Let's reset it.
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 1314a98e956..f807e2b1dff 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -530,9 +530,9 @@ Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, const Tensor &
   at::IntArrayRef sizes = mat1.sizes();
   at::IntArrayRef strides = mat1.strides();
   if (strides[0] == 1 && strides[1] == sizes[0]) {
-    return maybe_multiply(mat2.mm(grad.t()).t(), alpha);
+    return maybe_multiply(mat2.conj().mm(grad.t()).t(), alpha);
   } else {
-    return maybe_multiply(grad.mm(mat2.t()), alpha);
+    return maybe_multiply(grad.mm(mat2.t().conj()), alpha);
   }
 }
 
@@ -550,9 +550,9 @@ Tensor mm_mat2_backward(const Tensor & grad, const Tensor & mat1, IntArrayRef si
       at::addmm_out(r, t, mat1.t(), grad, alpha, 1);
       return r;
     }
-    return maybe_multiply(grad.t().mm(mat1).t(), alpha);
+    return maybe_multiply(grad.t().mm(mat1.conj()).t(), alpha);
   } else {
-    return maybe_multiply(mat1.t().mm(grad), alpha);
+    return maybe_multiply(mat1.t().conj().mm(grad), alpha);
   }
 }
 
@@ -2123,9 +2123,9 @@ std::tuple<Tensor, Tensor> triangular_solve_backward(
   Tensor grad_b, grad_a;
   if (grad_x.defined() || grad_m.defined()) {
     if (grad_x.defined()) {
-      grad_b = std::get<0>(grad_x.triangular_solve(a, upper, !transpose, unitriangular));
+      grad_b = std::get<0>(grad_x.triangular_solve(a.conj(), upper, !transpose, unitriangular));
       if (output_mask[1]) {
-        grad_a = transpose ? -x.matmul(grad_b.transpose(-1, -2)) : -grad_b.matmul(x.transpose(-1, -2));
+        grad_a = transpose ? -x.conj().matmul(grad_b.transpose(-1, -2)) : -grad_b.matmul(x.transpose(-1, -2).conj());
         if (upper) {
           grad_a = grad_a.triu((int) unitriangular);
         } else {
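
Note on the formulas (not part of the patch): the .conj() calls added above implement the conjugate-transpose convention for complex backward formulas, i.e. for C = A.mm(B) the vector-Jacobian products are grad_A = grad_C @ B^H and grad_B = A^H @ grad_C. A minimal sketch of checking this against autograd with the public torch API (tensor names here are illustrative, not from the patch):

import torch

# Complex double inputs; mm records the graph for backward.
A = torch.randn(3, 4, dtype=torch.cdouble, requires_grad=True)
B = torch.randn(4, 5, dtype=torch.cdouble, requires_grad=True)
grad_C = torch.randn(3, 5, dtype=torch.cdouble)

A.mm(B).backward(grad_C)

# Expected gradients per the conjugated formulas in derivatives.yaml and
# mm_mat1_backward / mm_mat2_backward above: grad_C @ B^H and A^H @ grad_C.
assert torch.allclose(A.grad, grad_C.mm(B.t().conj()))
assert torch.allclose(B.grad, A.t().conj().mm(grad_C))

gradcheck/gradgradcheck with cdouble inputs, as used in test_triangular_solve above, exercise the same formulas numerically.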