Updated derivatives for complex mm, mv, ger, bmm, triangular_solve (#45737)

Summary:
This PR updates the derivatives of a few functions so that `gradgradcheck` for `torch.cholesky` passes ([ref](https://github.com/pytorch/pytorch/pull/45267#discussion_r494439967)).

Some tests (those that call `bmm_cuda`) fail with `RuntimeError: _th_bmm_out not supported on CUDAType for ComplexDouble`
until PR https://github.com/pytorch/pytorch/issues/42553 is merged.

Ref. https://github.com/pytorch/pytorch/issues/33152
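
As an illustration (not part of this commit), the kind of check this change enables looks roughly like the following; it assumes a build that already contains these derivative updates:

import torch
from torch.autograd import gradcheck, gradgradcheck

def func(root):
    # build a Hermitian positive-definite matrix from an arbitrary square input
    x = root @ root.transpose(-1, -2).conj() + torch.eye(root.shape[-1], dtype=root.dtype)
    return torch.cholesky(x)

root = torch.rand(3, 3, dtype=torch.cdouble, requires_grad=True)
gradcheck(func, [root])       # first derivatives
gradgradcheck(func, [root])   # second derivatives; exercises the updated matrix-op backward formulas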

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45737

Reviewed By: bdhirsh

Differential Revision: D24279917

Pulled By: anjali411

fbshipit-source-id: 7b696d2cfc2ef714332c2e3e5d207e257be67744
commit 528158af47 (parent 7f458e16ba)
Author: Ivan Yashchuk
Date:   2020-10-15 11:25:37 -07:00 (committed by Facebook GitHub Bot)

4 changed files with 33 additions and 24 deletions


@@ -2509,8 +2509,6 @@ class TestAutograd(TestCase):
             root = root + torch.eye(dims[-1])
             gradcheck(func, [root, upper])
-            # TODO: gradgradcheck does not work correctly yet for complex
-            if not dtype.is_complex:
-                gradgradcheck(func, [root, upper])
+            gradgradcheck(func, [root, upper])
             root = torch.rand(*dims, dtype=dtype)
@@ -2684,9 +2682,9 @@ class TestAutograd(TestCase):
     @skipIfNoLapack
     def test_triangular_solve(self):
-        def _test_with_size(A_dims, B_dims):
-            A = torch.rand(*A_dims).requires_grad_()
-            b = torch.rand(*B_dims).requires_grad_()
+        def run_test(A_dims, B_dims, dtype):
+            A = torch.rand(*A_dims, dtype=dtype).requires_grad_()
+            b = torch.rand(*B_dims, dtype=dtype).requires_grad_()
             for upper, transpose, unitriangular in product((True, False), repeat=3):
                 def func(A, b):
@@ -2695,10 +2693,11 @@ class TestAutograd(TestCase):
                 gradcheck(func, [A, b])
                 gradgradcheck(func, [A, b])
-        _test_with_size((3, 3), (3, 4))
-        _test_with_size((3, 3), (3, 2))
-        _test_with_size((2, 3, 3), (2, 3, 4))
-        _test_with_size((2, 3, 3), (2, 3, 2))
+        for dtype in (torch.double, torch.cdouble):
+            run_test((3, 3), (3, 4), dtype)
+            run_test((3, 3), (3, 2), dtype)
+            run_test((2, 3, 3), (2, 3, 4), dtype)
+            run_test((2, 3, 3), (2, 3, 2), dtype)

     @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
     def test_fft_ifft_rfft_irfft(self):
@@ -4833,7 +4832,11 @@ complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone'
                 'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu',
                 'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_',
                 'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh',
-                'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split'] + separate_complex_tests
+                'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split',
+                'matmul', 'bmm', 'mv', 'ger', 'diagonal', ] + separate_complex_tests
+
+# this list corresponds to cases that are not currently implemented
+skip_cuda_list = ['bmm_complex', 'matmul_4d_4d_complex']
 # TODO(@anjali411): add tests for 'sub', 'div
 # TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition - @anjali411
@@ -5019,6 +5022,11 @@ def add_test(
     for skip in skipTestIf:
         do_test = skip(do_test)
+    # TODO: remove this once tests from skip_cuda_list work
+    do_test = skipCUDAIf(
+        any(skip_test in test_name for skip_test in skip_cuda_list),
+        "not implemented for CUDA yet")(do_test)
     setattr(TestAutogradDeviceType, test_name, do_test)

 class TestAutogradComplex(TestCase):
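
With these list entries, the generated autograd tests now cover matmul, bmm, mv, ger, and diagonal with complex inputs, while the CUDA variants of complex bmm and 4d matmul stay skipped until `bmm_cuda` gains complex support. A standalone sketch of what those generated tests check (illustrative only, assuming a CPU build with this patch):

import torch
from torch.autograd import gradcheck

a = torch.randn(2, 3, 4, dtype=torch.cdouble, requires_grad=True)
b = torch.randn(2, 4, 5, dtype=torch.cdouble, requires_grad=True)
gradcheck(torch.bmm, (a, b))  # numerically verifies the conjugated bmm backward formulas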


@@ -275,8 +275,8 @@
   self: zeros_like(grad)

 - name: bmm(Tensor self, Tensor mat2) -> Tensor
-  self: grad.bmm(mat2.transpose(1, 2))
-  mat2: self.transpose(1, 2).bmm(grad)
+  self: grad.bmm(mat2.transpose(1, 2).conj())
+  mat2: self.transpose(1, 2).conj().bmm(grad)

 - name: _bmm(Tensor self, Tensor mat2, *, bool deterministic=False) -> Tensor
   self: at::_bmm(grad, mat2.transpose(1, 2), deterministic)
@@ -498,8 +498,8 @@
   self: not_implemented("geqrf")

 - name: ger(Tensor self, Tensor vec2) -> Tensor
-  self: grad.mv(vec2)
-  vec2: grad.t().mv(self)
+  self: grad.mv(vec2.conj())
+  vec2: grad.t().mv(self.conj())

 - name: indices(Tensor(a) self) -> Tensor(a)
   output_differentiability: [False]
@@ -749,8 +749,8 @@
   self: mul_tensor_backward(grad, at::scalar_to_tensor(other), self.scalar_type())

 - name: mv(Tensor self, Tensor vec) -> Tensor
-  self: grad.ger(vec)
-  vec: self.t().mv(grad)
+  self: grad.ger(vec.conj())
+  vec: self.conj().t().mv(grad)

 - name: mvlgamma(Tensor self, int p) -> Tensor
   self: mvlgamma_backward(grad, self, p)
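
The pattern in these entries follows PyTorch's complex autograd convention: the backward of a (batched) matrix product multiplies the incoming gradient by the conjugate transpose of the other operand, so each formula gains a .conj() next to its transpose. A quick numerical sanity check of the bmm entry (illustrative only, assuming a build with this change):

import torch

self_ = torch.randn(2, 3, 4, dtype=torch.cdouble, requires_grad=True)
mat2 = torch.randn(2, 4, 5, dtype=torch.cdouble, requires_grad=True)
grad_out = torch.randn(2, 3, 5, dtype=torch.cdouble)

# autograd result vs. the closed-form bmm entry above
g_self, g_mat2 = torch.autograd.grad(self_.bmm(mat2), (self_, mat2), grad_outputs=grad_out)
print(torch.allclose(g_self, grad_out.bmm(mat2.transpose(1, 2).conj())))   # True
print(torch.allclose(g_mat2, self_.transpose(1, 2).conj().bmm(grad_out)))  # True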


@@ -164,7 +164,8 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
     'cosh', '__rmul__', 'sgn', 'asin', 'acos', 'sub', 'div', 'cat', 'view_as_complex',
     'neg', 'complex', 'select', '_s_where', 'as_strided', 'slice', 'constant_pad_nd',
     'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward',
-    'dot', 'vdot', 'cholesky'
+    'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger',
+    'bmm', 'diagonal'
 }

 # Some operators invalidate the grad_accumulator. Let's reset it.


@@ -530,9 +530,9 @@ Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, const Tensor &
   at::IntArrayRef sizes = mat1.sizes();
   at::IntArrayRef strides = mat1.strides();
   if (strides[0] == 1 && strides[1] == sizes[0]) {
-    return maybe_multiply(mat2.mm(grad.t()).t(), alpha);
+    return maybe_multiply(mat2.conj().mm(grad.t()).t(), alpha);
   } else {
-    return maybe_multiply(grad.mm(mat2.t()), alpha);
+    return maybe_multiply(grad.mm(mat2.t().conj()), alpha);
   }
 }
@@ -550,9 +550,9 @@ Tensor mm_mat2_backward(const Tensor & grad, const Tensor & mat1, IntArrayRef si
       at::addmm_out(r, t, mat1.t(), grad, alpha, 1);
       return r;
     }
-    return maybe_multiply(grad.t().mm(mat1).t(), alpha);
+    return maybe_multiply(grad.t().mm(mat1.conj()).t(), alpha);
   } else {
-    return maybe_multiply(mat1.t().mm(grad), alpha);
+    return maybe_multiply(mat1.t().conj().mm(grad), alpha);
   }
 }
@@ -2123,9 +2123,9 @@ std::tuple<Tensor, Tensor> triangular_solve_backward(
   Tensor grad_b, grad_a;
   if (grad_x.defined() || grad_m.defined()) {
     if (grad_x.defined()) {
-      grad_b = std::get<0>(grad_x.triangular_solve(a, upper, !transpose, unitriangular));
+      grad_b = std::get<0>(grad_x.triangular_solve(a.conj(), upper, !transpose, unitriangular));
       if (output_mask[1]) {
-        grad_a = transpose ? -x.matmul(grad_b.transpose(-1, -2)) : -grad_b.matmul(x.transpose(-1, -2));
+        grad_a = transpose ? -x.conj().matmul(grad_b.transpose(-1, -2)) : -grad_b.matmul(x.transpose(-1, -2).conj());
         if (upper) {
           grad_a = grad_a.triu((int) unitriangular);
         } else {
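
For completeness, a minimal standalone version of the updated triangular_solve test (illustrative; it mirrors run_test above with torch.cdouble and exercises this backward):

import torch
from torch.autograd import gradcheck, gradgradcheck

A = torch.rand(3, 3, dtype=torch.cdouble).requires_grad_()
b = torch.rand(3, 2, dtype=torch.cdouble).requires_grad_()

def func(A, b):
    # solve A x = b with A treated as upper triangular; [0] is the solution
    return torch.triangular_solve(b, A, upper=True)[0]

gradcheck(func, [A, b])
gradgradcheck(func, [A, b])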