mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Updated derivatives for complex mm, mv, ger, bmm, triangular_solve (#45737)
Summary: This PR updates derivatives for a few functions so that `gradgradcheck` for `torch.cholesky` is passed ([ref](https://github.com/pytorch/pytorch/pull/45267#discussion_r494439967)). Some tests (that call to `bmm_cuda`) fail with with `RuntimeError: _th_bmm_out not supported on CUDAType for ComplexDouble` until PR https://github.com/pytorch/pytorch/issues/42553 is merged. Ref. https://github.com/pytorch/pytorch/issues/33152 Pull Request resolved: https://github.com/pytorch/pytorch/pull/45737 Reviewed By: bdhirsh Differential Revision: D24279917 Pulled By: anjali411 fbshipit-source-id: 7b696d2cfc2ef714332c2e3e5d207e257be67744
This commit is contained in:
parent
7f458e16ba
commit
528158af47
|
|
@ -2509,9 +2509,7 @@ class TestAutograd(TestCase):
|
|||
root = root + torch.eye(dims[-1])
|
||||
|
||||
gradcheck(func, [root, upper])
|
||||
# TODO: gradgradcheck does not work correctly yet for complex
|
||||
if not dtype.is_complex:
|
||||
gradgradcheck(func, [root, upper])
|
||||
gradgradcheck(func, [root, upper])
|
||||
|
||||
root = torch.rand(*dims, dtype=dtype)
|
||||
root = torch.matmul(root, root.transpose(-1, -2).conj())
|
||||
|
|
@ -2684,9 +2682,9 @@ class TestAutograd(TestCase):
|
|||
|
||||
@skipIfNoLapack
|
||||
def test_triangular_solve(self):
|
||||
def _test_with_size(A_dims, B_dims):
|
||||
A = torch.rand(*A_dims).requires_grad_()
|
||||
b = torch.rand(*B_dims).requires_grad_()
|
||||
def run_test(A_dims, B_dims, dtype):
|
||||
A = torch.rand(*A_dims, dtype=dtype).requires_grad_()
|
||||
b = torch.rand(*B_dims, dtype=dtype).requires_grad_()
|
||||
|
||||
for upper, transpose, unitriangular in product((True, False), repeat=3):
|
||||
def func(A, b):
|
||||
|
|
@ -2695,10 +2693,11 @@ class TestAutograd(TestCase):
|
|||
gradcheck(func, [A, b])
|
||||
gradgradcheck(func, [A, b])
|
||||
|
||||
_test_with_size((3, 3), (3, 4))
|
||||
_test_with_size((3, 3), (3, 2))
|
||||
_test_with_size((2, 3, 3), (2, 3, 4))
|
||||
_test_with_size((2, 3, 3), (2, 3, 2))
|
||||
for dtype in (torch.double, torch.cdouble):
|
||||
run_test((3, 3), (3, 4), dtype)
|
||||
run_test((3, 3), (3, 2), dtype)
|
||||
run_test((2, 3, 3), (2, 3, 4), dtype)
|
||||
run_test((2, 3, 3), (2, 3, 2), dtype)
|
||||
|
||||
@unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
|
||||
def test_fft_ifft_rfft_irfft(self):
|
||||
|
|
@ -4833,7 +4832,11 @@ complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone'
|
|||
'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu',
|
||||
'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_',
|
||||
'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh',
|
||||
'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split'] + separate_complex_tests
|
||||
'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split',
|
||||
'matmul', 'bmm', 'mv', 'ger', 'diagonal', ] + separate_complex_tests
|
||||
|
||||
# this list corresponds to cases that are not currently implemented
|
||||
skip_cuda_list = ['bmm_complex', 'matmul_4d_4d_complex']
|
||||
|
||||
# TODO(@anjali411): add tests for 'sub', 'div
|
||||
# TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition - @anjali411
|
||||
|
|
@ -5019,6 +5022,11 @@ def add_test(
|
|||
for skip in skipTestIf:
|
||||
do_test = skip(do_test)
|
||||
|
||||
# TODO: remove this once tests from skip_cuda_list work
|
||||
do_test = skipCUDAIf(
|
||||
any(skip_test in test_name for skip_test in skip_cuda_list),
|
||||
"not implemented for CUDA yet")(do_test)
|
||||
|
||||
setattr(TestAutogradDeviceType, test_name, do_test)
|
||||
|
||||
class TestAutogradComplex(TestCase):
|
||||
|
|
|
|||
|
|
@ -275,8 +275,8 @@
|
|||
self: zeros_like(grad)
|
||||
|
||||
- name: bmm(Tensor self, Tensor mat2) -> Tensor
|
||||
self: grad.bmm(mat2.transpose(1, 2))
|
||||
mat2: self.transpose(1, 2).bmm(grad)
|
||||
self: grad.bmm(mat2.transpose(1, 2).conj())
|
||||
mat2: self.transpose(1, 2).conj().bmm(grad)
|
||||
|
||||
- name: _bmm(Tensor self, Tensor mat2, *, bool deterministic=False) -> Tensor
|
||||
self: at::_bmm(grad, mat2.transpose(1, 2), deterministic)
|
||||
|
|
@ -498,8 +498,8 @@
|
|||
self: not_implemented("geqrf")
|
||||
|
||||
- name: ger(Tensor self, Tensor vec2) -> Tensor
|
||||
self: grad.mv(vec2)
|
||||
vec2: grad.t().mv(self)
|
||||
self: grad.mv(vec2.conj())
|
||||
vec2: grad.t().mv(self.conj())
|
||||
|
||||
- name: indices(Tensor(a) self) -> Tensor(a)
|
||||
output_differentiability: [False]
|
||||
|
|
@ -749,8 +749,8 @@
|
|||
self: mul_tensor_backward(grad, at::scalar_to_tensor(other), self.scalar_type())
|
||||
|
||||
- name: mv(Tensor self, Tensor vec) -> Tensor
|
||||
self: grad.ger(vec)
|
||||
vec: self.t().mv(grad)
|
||||
self: grad.ger(vec.conj())
|
||||
vec: self.conj().t().mv(grad)
|
||||
|
||||
- name: mvlgamma(Tensor self, int p) -> Tensor
|
||||
self: mvlgamma_backward(grad, self, p)
|
||||
|
|
|
|||
|
|
@ -164,7 +164,8 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
|
|||
'cosh', '__rmul__', 'sgn', 'asin', 'acos', 'sub', 'div', 'cat', 'view_as_complex',
|
||||
'neg', 'complex', 'select', '_s_where', 'as_strided', 'slice', 'constant_pad_nd',
|
||||
'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward',
|
||||
'dot', 'vdot', 'cholesky'
|
||||
'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger',
|
||||
'bmm', 'diagonal'
|
||||
}
|
||||
|
||||
# Some operators invalidate the grad_accumulator. Let's reset it.
|
||||
|
|
|
|||
|
|
@ -530,9 +530,9 @@ Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, const Tensor &
|
|||
at::IntArrayRef sizes = mat1.sizes();
|
||||
at::IntArrayRef strides = mat1.strides();
|
||||
if (strides[0] == 1 && strides[1] == sizes[0]) {
|
||||
return maybe_multiply(mat2.mm(grad.t()).t(), alpha);
|
||||
return maybe_multiply(mat2.conj().mm(grad.t()).t(), alpha);
|
||||
} else {
|
||||
return maybe_multiply(grad.mm(mat2.t()), alpha);
|
||||
return maybe_multiply(grad.mm(mat2.t().conj()), alpha);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -550,9 +550,9 @@ Tensor mm_mat2_backward(const Tensor & grad, const Tensor & mat1, IntArrayRef si
|
|||
at::addmm_out(r, t, mat1.t(), grad, alpha, 1);
|
||||
return r;
|
||||
}
|
||||
return maybe_multiply(grad.t().mm(mat1).t(), alpha);
|
||||
return maybe_multiply(grad.t().mm(mat1.conj()).t(), alpha);
|
||||
} else {
|
||||
return maybe_multiply(mat1.t().mm(grad), alpha);
|
||||
return maybe_multiply(mat1.t().conj().mm(grad), alpha);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2123,9 +2123,9 @@ std::tuple<Tensor, Tensor> triangular_solve_backward(
|
|||
Tensor grad_b, grad_a;
|
||||
if (grad_x.defined() || grad_m.defined()) {
|
||||
if (grad_x.defined()) {
|
||||
grad_b = std::get<0>(grad_x.triangular_solve(a, upper, !transpose, unitriangular));
|
||||
grad_b = std::get<0>(grad_x.triangular_solve(a.conj(), upper, !transpose, unitriangular));
|
||||
if (output_mask[1]) {
|
||||
grad_a = transpose ? -x.matmul(grad_b.transpose(-1, -2)) : -grad_b.matmul(x.transpose(-1, -2));
|
||||
grad_a = transpose ? -x.conj().matmul(grad_b.transpose(-1, -2)) : -grad_b.matmul(x.transpose(-1, -2).conj());
|
||||
if (upper) {
|
||||
grad_a = grad_a.triu((int) unitriangular);
|
||||
} else {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user