Add determinant function on variable; Add backward on svd (#3816)

* determinant on variable

* svd bwd
Tongzhou Wang 2017-12-01 13:22:46 -05:00 committed by Soumith Chintala
parent 80c8635a7e
commit c681b03d37
20 changed files with 424 additions and 20 deletions

View File

@ -207,7 +207,7 @@ If you are working on the CUDA code, here are some useful CUDA debugging tips:
1. `CUDA_DEBUG=1` will enable CUDA debugging symbols (-g -G). This is particularly
helpful in debugging device code. However, it will slow down the build process,
so use wisely.
2. `cuda-gdb` and `cuda-memcheck` are your best CUDA debuging friends. Unlike`gdb`,
2. `cuda-gdb` and `cuda-memcheck` are your best CUDA debugging friends. Unlike `gdb`,
`cuda-gdb` can display actual values in a CUDA tensor (rather than all zeros).

View File

@ -3504,6 +3504,7 @@
- Double
backends:
- CPU
- CUDA
variants:
- method
- function

View File

@ -270,6 +270,52 @@ Tensor & unsqueeze_(Tensor& self, int64_t dim) {
return self.as_strided_(std::get<0>(g), std::get<1>(g));
}
// For backward, we save the svd.
// http://www.ics.forth.gr/cvrl/publications/conferences/2000_eccv_SVD_jacobian.pdf
// But instead of the gesvd SVD A = U(A) Sig(A) V(A)^T, which doesn't specify the
// signs of det(U) and det(V), we consider det(A) = \prod_i Sig_(A)_i, where
// 1. A = U_(A) Sig_(A) V(A)^T
// 2. Sig_(A) and U_(A) may differ from their gesvd counterparts in the sign of the
// first entry / first column, chosen so that U_(A) V(A)^T has +1 determinant
std::tuple<Tensor, Tensor, Tensor, Tensor> _det_with_svd(const Tensor& self) {
if (!at::isFloatingType(self.type().scalarType()) ||
self.dim() != 2 || self.size(0) != self.size(1)) {
std::ostringstream ss;
ss << "det(" << self.type() << "{" << self.sizes() << "}): expected a 2D"
<< "square tensor of floating types";
throw std::runtime_error(ss.str());
}
// check symmetric
bool symmetric = self.equal(self.transpose(0, 1));
auto svd = self.svd(true);
auto sigma = std::get<1>(svd);
auto u = std::get<0>(svd);
auto v = std::get<2>(svd);
auto det = sigma.prod();
if (!symmetric) {
auto qr = self.geqrf();
auto a = std::get<0>(qr);
auto tau = std::get<1>(qr);
// non-zero values in tau represent Householder reflectors, each of which has -1 det
int64_t num_reflectors = tau.nonzero().size(0);
auto qr_det = a.diag().prod();
if (num_reflectors % 2 == 1) {
qr_det = -qr_det;
}
if ((qr_det < 0).any() ^ (det < 0).any()) { // signs differ: flip so sigma.prod() matches the true det
u.narrow(1, 0, 1).mul_(-1);
sigma.narrow(0, 0, 1).mul_(-1);
}
det = qr_det; // QR is more stable than svd, so use it anyway
}
return std::make_tuple(det, u, sigma, v);
}
Tensor det(const Tensor& self) {
return std::get<0>(self._det_with_svd());
}
Tensor stack(TensorList tensors, int64_t dim) {
if (tensors.size() == 0) {
throw std::runtime_error("stack expects a non-empty TensorList");

View File

@ -76,6 +76,10 @@
- func: unsqueeze_(Tensor self, int64_t dim) -> Tensor
- func: _det_with_svd(Tensor self) -> (Tensor, Tensor, Tensor, Tensor)
- func: det(Tensor self) -> Tensor
- func: stack(TensorList tensors, int64_t dim=0) -> Tensor
variants: function

View File

@ -584,6 +584,51 @@ THC_API void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *b, TH
#endif
}
THC_API void THCTensor_(geqrf)(THCState *state, THCTensor *ra_, THCTensor *rtau_, THCTensor *a_)
{
#ifdef USE_MAGMA
THArgCheck(a_->nDimension == 2, 2, "A should be 2 dimensional");
THCTensor *a = THCTensor_(newColumnMajor)(state, ra_, a_);
int64_t m = a->size[0];
int64_t n = a->size[1];
int64_t k = (m < n ? m : n);
#ifdef MAGMA_V2
#if defined(THC_REAL_IS_FLOAT)
int64_t nb = magma_get_sgeqrf_nb(m, n);
#else
int64_t nb = magma_get_dgeqrf_nb(m, n);
#endif
#else
#if defined(THC_REAL_IS_FLOAT)
int64_t nb = magma_get_sgeqrf_nb(m);
#else
int64_t nb = magma_get_dgeqrf_nb(m);
#endif
#endif
real *rtau_data = th_magma_malloc_pinned<real>(k);
real *a_data = THCTensor_(data)(state, a);
int info;
#if defined(THC_REAL_IS_FLOAT)
magma_sgeqrf2_gpu(m, n, a_data, m, rtau_data, &info);
#else
magma_dgeqrf2_gpu(m, n, a_data, m, rtau_data, &info);
#endif
if (info != 0)
THError("MAGMA geqrf2 : Argument %d : illegal value.", -info);
THCTensor_(freeCopyTo)(state, a, ra_);
THCTensor_(copyArray1d)(state, rtau_, rtau_data, k);
magma_free_pinned(rtau_data);
#else
THError(NoMagma(geqrf));
#endif
}
THC_API void THCTensor_(qr)(THCState *state, THCTensor *rq_, THCTensor *rr_, THCTensor *a_)
{
#ifdef USE_MAGMA
@ -614,6 +659,11 @@ THC_API void THCTensor_(qr)(THCState *state, THCTensor *rq_, THCTensor *rr_, THC
real *work_data = THCTensor_(data)(state, work);
int info;
// We need to call two different versions of ?geqrf:
// ?geqrf_gpu allows fast computation of Q via ?orgqr_gpu, but doesn't give
// R properly. Note that the MAGMA documentation for this method is wrong.
// http://icl.cs.utk.edu/magma/forum/viewtopic.php?f=2&t=1015&p=2800&hilit=geqrf_gpu#p2800
// ?geqrf2_gpu gives correct R, but doesn't allow computation of Q via ?orgqr_gpu
#if defined(THC_REAL_IS_FLOAT)
magma_sgeqrf2_gpu(m, n, a_data, m, tau_data, &info);
#else
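As a usage sketch of the CPU counterparts (torch.geqrf / torch.orgqr, not part of this diff), the packed geqrf output can be turned back into Q and R as below; each nonzero entry of tau corresponds to a Householder reflector with determinant -1, which is exactly what _det_with_svd counts:

import torch

a = torch.randn(5, 5)
packed, tau = torch.geqrf(a)      # R in the upper triangle, reflectors stored below it
r = packed.triu()                 # upper-triangular factor R
q = torch.orgqr(packed, tau)      # rebuild Q from the Householder reflectors
print((q.mm(r) - a).abs().max())  # ~0 up to floating point error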

View File

@ -15,9 +15,9 @@ THC_API void THCTensor_(getri)(THCState *state, THCTensor *ra_, THCTensor *a);
THC_API void THCTensor_(potri)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo);
THC_API void THCTensor_(potrf)(THCState *state, THCTensor *ra_, THCTensor *a, const char *uplo);
THC_API void THCTensor_(potrs)(THCState *state, THCTensor *rb_, THCTensor *a, THCTensor *b, const char *uplo);
THC_API void THCTensor_(geqrf)(THCState *state, THCTensor *ra_, THCTensor *rtau_, THCTensor *a_);
THC_API void THCTensor_(qr)(THCState *state, THCTensor *rq_, THCTensor *rr_, THCTensor *a);
#endif // defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
#endif

View File

@ -188,6 +188,7 @@ BLAS and LAPACK Operations
.. autofunction:: ger
.. autofunction:: gesv
.. autofunction:: inverse
.. autofunction:: det
.. autofunction:: matmul
.. autofunction:: mm
.. autofunction:: mv

View File

@ -1898,6 +1898,30 @@ def _make_cov(S):
return torch.mm(L, L.t())
def random_square_matrix_of_rank(l, rank):
assert rank <= l
A = torch.randn(l, l)
u, s, v = A.svd()
for i in range(l):
if i >= rank:
s[i] = 0
elif s[i] == 0:
s[i] = 1
return u.mm(torch.diag(s)).mm(v.transpose(0, 1))
def random_symmetric_matrix(l):
A = torch.randn(l, l)
return A.mm(A.transpose(0, 1))
def random_fullrank_matrix_distinct_singular_value(l):
A = torch.randn(l, l)
u, _, v = A.svd()
s = torch.arange(1, l + 1).mul_(1.0 / (l + 1))
return u.mm(torch.diag(s)).mm(v.transpose(0, 1))
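A quick sanity sketch for these helpers (hypothetical; assumes they are importable and NumPy is available for the rank check):

import numpy as np

m = random_square_matrix_of_rank(6, 3)
print(np.linalg.matrix_rank(m.numpy()))            # expected: 3
f = random_fullrank_matrix_distinct_singular_value(6)
print(np.linalg.svd(f.numpy(), compute_uv=False))  # six distinct values in (0, 1)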
class dont_convert(tuple):
pass
@ -1906,7 +1930,6 @@ L = 20
M = 10
S = 5
# (name, size, args...)
method_tests = [
('add', (S, S, S), ((S, S, S),)),
@ -2166,6 +2189,13 @@ method_tests = [
('index_copy', (S, S), (0, index_perm_variable(2, S), (2, S)), 'dim', [0]),
('index_fill', (S, S), (0, index_variable(2, S), 2), 'dim', [0]),
('inverse', (S, S), (), '', (), [skipIfNoLapack]),
('det', (S, S), (), '', (), [skipIfNoLapack]),
('det', lambda: random_symmetric_matrix(S), (), 'symmetric', (), [skipIfNoLapack]),
('det', lambda: random_square_matrix_of_rank(S, S - 2), (), 'dim2_null', (), [skipIfNoLapack]),
('det', lambda: random_square_matrix_of_rank(S, 1), (), 'rank1', (), [skipIfNoLapack]),
('det', lambda: random_square_matrix_of_rank(S, 2), (), 'rank2', (), [skipIfNoLapack]),
('det', lambda: random_fullrank_matrix_distinct_singular_value(S), (), 'distinct_postive_s', (), [skipIfNoLapack]),
('svd', lambda: random_fullrank_matrix_distinct_singular_value(S), (), '', (), [skipIfNoLapack]),
('gesv', (S, S), ((S, S),), '', (), [skipIfNoLapack]),
('potrf', _make_cov(S), (True,), '', (), [skipIfNoLapack]),
('eq', (S, S, S), ((S, S, S),)),
@ -2303,6 +2333,8 @@ def create_input(call_args, requires_grad=True, non_contiguous=False):
return Variable(maybe_non_contig(arg), requires_grad=requires_grad)
elif isinstance(arg, Variable) and non_contiguous:
return Variable(maybe_non_contig(arg.data), requires_grad=arg.requires_grad)
elif callable(arg):
return map_arg(arg())
else:
return arg
return tuple(map_arg(arg) for arg in call_args)
@ -2339,6 +2371,19 @@ EXCLUDE_FUNCTIONAL = {
EXCLUDE_GRADCHECK = {
'potrf'
}
EXCLUDE_GRADGRADCHECK = {
'svd'
}
EXCLUDE_GRADGRADCHECK_BY_TEST_NAME = {
# Some of the following det tests pass only because a random matrix has full
# rank with high probability; we can't rely on that, so gradgrad is only
# tested on test_det_distinct_postive_s.
'test_det',
'test_det_symmetric',
'test_det_dim2_null',
'test_det_rank1',
'test_det_rank2'
}
def exclude_tensor_method(name, test_name):
@ -2359,6 +2404,7 @@ def exclude_tensor_method(name, test_name):
'resize_as',
'scatter',
'scatter_add',
'det',
}
if test_name in exclude_all_tensor_method_by_test_name:
return True
@ -2390,9 +2436,11 @@ def gradgradcheck_method_precision_override(test_name):
return override
def run_grad_and_gradgrad_checks(test_case, test_name, apply_method, output_variable, input_variables):
def run_grad_and_gradgrad_checks(test_case, name, test_name, apply_method, output_variable,
input_variables, run_gradgradcheck=True):
test_case.assertTrue(gradcheck(apply_method, input_variables, eps=1e-6, atol=PRECISION))
if name in EXCLUDE_GRADGRADCHECK or test_name in EXCLUDE_GRADGRADCHECK_BY_TEST_NAME:
return
grad_y = generate_gradoutput(output_variable, non_contiguous=True)
gradgradcheck_precision_override = gradgradcheck_method_precision_override(test_name)
if gradgradcheck_precision_override is not None:
@ -2400,7 +2448,7 @@ def run_grad_and_gradgrad_checks(test_case, test_name, apply_method, output_vari
rtol = gradgradcheck_precision_override['rtol']
test_case.assertTrue(gradgradcheck(apply_method, input_variables, grad_y, atol=atol, rtol=rtol))
else:
test_case.assertTrue(gradgradcheck(apply_method, input_variables, grad_y,))
test_case.assertTrue(gradgradcheck(apply_method, input_variables, grad_y))
def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
@ -2413,7 +2461,7 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
test_case.assertEqual(unpack_variables(output_variable), output_tensor)
if run_grad_checks:
run_grad_and_gradgrad_checks(test_case, test_name, apply_fn,
run_grad_and_gradgrad_checks(test_case, name, test_name, apply_fn,
output_variable, f_args_variable)
self_variable = f_args_variable[0]
@ -2457,7 +2505,7 @@ for test in method_tests:
# TODO: check that both have changed after adding all inplace ops
if not is_inplace and name not in EXCLUDE_GRADCHECK:
run_grad_and_gradgrad_checks(self, test_name,
run_grad_and_gradgrad_checks(self, name, test_name,
lambda *inputs: getattr(inputs[0], name)(*inputs[1:]),
output_variable, (self_variable,) + args_variable)

View File

@ -316,7 +316,8 @@ tests = [
('qr', small_2d_lapack_fat, lambda t: [], 'fat', float_types),
('qr', large_2d_lapack, lambda t: [], 'big', float_types),
('inverse', new_t(20, 20), lambda t: [], None, float_types),
('geqrf', new_t(20, 20), lambda t: [], None, float_types),
# TODO: add det here once Variable and Tensor are the same thing
]
# TODO: random functions, cat, gather, scatter, index*, masked*,
@ -938,6 +939,10 @@ class TestCuda(TestCase):
def _select_broadcastable_dims(dims_full=None):
return TestTorch._select_broadcastable_dims(dims_full)
@unittest.skipIf(not HAS_MAGMA, "no MAGMA library detected")
def test_det(self):
TestTorch._test_det(self, lambda t: t.cuda())
def test_broadcast(self):
TestTorch._test_broadcast(self, lambda t: t.cuda())

View File

@ -2471,6 +2471,97 @@ class TestTorch(TestCase):
self.assertFalse(MII.is_contiguous(), 'MII is contiguous')
self.assertEqual(MII, MI, 0, 'inverse value in-place')
@staticmethod
def _test_det(self, conv_fn):
def reference_det(M):
# naive row reduction
M = M.clone()
l = M.size(0)
multiplier = 1
for i in range(l):
if M[i, 0] != 0:
if i != 0:
M[0], M[i] = M[i].clone(), M[0].clone()  # clone to swap; raw views would alias after the first assignment
multiplier = -1
break
else:
return 0
for i in range(1, l):
row = M[i]
for j in range(i):
row -= row[j] / M[j, j] * M[j]
M[i] = row
return M.diag().prod() * multiplier
# TODO: remove Variable wrapper once Variable and Tensor are the same
Variable = torch.autograd.Variable
eye_det = Variable(conv_fn(torch.eye(5))).det()
self.assertEqual(eye_det, eye_det.clone().fill_(1), 1e-8, 'determinant of identity')
def test(M):
M = conv_fn(M)
var_M = Variable(M)
M_det = var_M.det().data
self.assertEqual(M_det, M_det.clone().fill_(reference_det(M)), 1e-8, 'determinant')
self.assertEqual(M_det, var_M.inverse().det().data.pow_(-1), 1e-8, 'determinant of inverse')
self.assertEqual(M_det, var_M.transpose(0, 1).det().data, 1e-8, 'determinant after transpose')
for x in [0, 2, 4]:
for scale in [-2, -0.1, 0, 10]:
target = M_det * scale
# dim 0
M_clone = M.clone()
M_clone[:, x] *= scale
det = Variable(M_clone).det().data
self.assertEqual(target, det, 1e-8, 'determinant after scaling a column')
# dim 1
M_clone = M.clone()
M_clone[x, :] *= scale
det = Variable(M_clone).det().data
self.assertEqual(target, det, 1e-8, 'determinant after scaling a row')
for x1, x2 in [(0, 3), (4, 1), (3, 2)]:
assert x1 != x2, 'x1 and x2 need to be different for this test'
target = M_det.clone().zero_()
# dim 0
M_clone = M.clone()
M_clone[:, x2] = M_clone[:, x1]
det = Variable(M_clone).det().data
self.assertEqual(target, det, 1e-8, 'determinant when two columns are equal')
# dim 1
M_clone = M.clone()
M_clone[x2, :] = M_clone[x1, :]
det = Variable(M_clone).det().data
self.assertEqual(target, det, 1e-8, 'determinant when two rows are equal')
for scale1, scale2 in [(0.3, -1), (0, 2), (10, 0.1)]:
target = -M_det * scale1 * scale2
# dim 0
M_clone = M.clone()
t = M_clone[:, x1] * scale1
M_clone[:, x1] += M_clone[:, x2] * scale2
M_clone[:, x2] = t
det = Variable(M_clone).det().data
self.assertEqual(target, det, 1e-8, 'determinant after scaling and swapping columns')
# dim 1
M_clone = M.clone()
t = M_clone[x1, :] * scale1
M_clone[x1, :] += M_clone[x2, :] * scale2
M_clone[x2, :] = t
det = Variable(M_clone).det().data
self.assertEqual(target, det, 1e-8, 'determinant after scaling and swapping rows')
test(torch.randn(5, 5))
r = torch.randn(5, 5)
test(r.mm(r.transpose(0, 1))) # symmetric
test(torch.randn(5, 5, 5)[:, 2, :]) # non-contiguous
@skipIfNoLapack
def test_det(self):
self._test_det(self, lambda x: x)
@unittest.skip("Not implemented yet")
def test_conv2(self):
x = torch.rand(math.floor(torch.uniform(50, 100)), math.floor(torch.uniform(50, 100)))

View File

@ -16,16 +16,19 @@
# we are going to left-multiply. When the forward returns multiple
# outputs, 'grad' always refers to the first output; you can refer
# to other outputs using 'grads'
# - Any of the input arguments, tensor or non-tensor
# - 'output', representing the result of evaluating the forward
# expression
# - Any of the input arguments, tensor or non-tensor, including
# argument names that only appear in Declarations.cwrap, e.g. 'output'.
# - 'result', representing the result of evaluating the forward
# expression for ATen native function declarations. If the forward
# expression outputs a tuple, use 'resultX' instead to access the
# X-th entry
# - 'grad_input_mask', a std::array<bool, n> (where n is the number
# of differentiable inputs), specifying which inputs actually
# require gradient. (This is only available when multiple
# derivatives are being computed by a single formula.)
#
# If you need a complex expression, e.g., with local variables,
# write a _backward function in tools/autograd/templates/Function.cpp
# write a _backward function in tools/autograd/templates/Functions.cpp
# and invoke it from here. By the way, go read
# https://github.com/zdevito/ATen/issues/163; this describes an
# important hazard that occurs when porting backwards from Python to C++
@ -165,6 +168,9 @@
- name: data_ptr # fallthrough
- name: _det_with_svd(Tensor self)
self: _det_with_svd_backward(grads, self, result0, result1, result2, result3)
- name: diag(Tensor self, int64_t diagonal)
self: grad.diag(diagonal)
@ -443,7 +449,8 @@
# TODO: complicated
# - name: prod(Tensor self, int64_t dim, bool keepdim)
# - name: prod(Tensor self)
- name: prod(Tensor self)
self: not_implemented("prod")
- name: pstrf(Tensor self, bool upper, Scalar tol)
self: not_implemented("pstrf")
@ -546,7 +553,7 @@
self: sum_backward(grad, self.sizes(), dim, keepdim)
- name: svd(Tensor self, bool some)
self: not_implemented("svd")
self: svd_backward(grads, self, some, res1, res2, res3)
- name: symeig(Tensor self, bool eigenvectors, bool upper)
self: not_implemented("symeig")

View File

@ -54,7 +54,9 @@ UNPACK_SELF = "auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;"
# to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
SUPPORTED_RETURN_TYPES = {
'Tensor', 'std::tuple<Tensor,Tensor>',
'std::tuple<Tensor,Tensor,Tensor>', 'std::vector<Tensor>',
'std::tuple<Tensor,Tensor,Tensor>',
'std::tuple<Tensor,Tensor,Tensor,Tensor>',
'std::vector<Tensor>',
'Scalar', 'bool', 'int64_t', 'void*'
}

View File

@ -1,5 +1,6 @@
#include "Functions.h"
#include <ATen/WrapDimUtils.h>
#include <iostream>
// define constants like M_PI and C keywords for MSVC
#ifdef _MSC_VER
@ -502,6 +503,99 @@ std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
}
}
// https://j-towns.github.io/papers/svd-derivative.pdf
Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
bool some, const Tensor& raw_u, const Tensor& sigma, const Tensor& raw_v) {
auto m = self.size(0);
auto n = self.size(1);
auto k = sigma.size(0);
Tensor u, v;
if (!some) {
// ignore the free subspace
u = raw_u.narrow(1, 0, k);
v = raw_v.narrow(1, 0, k);
} else {
u = raw_u;
v = raw_v;
}
auto gu = grads[0];
auto gsigma = grads[1];
auto gv = grads[2];
auto im = self.type().eye(m);
auto in = self.type().eye(n);
auto ut = u.t();
auto vt = v.t();
auto sigma_mat = sigma.diag();
auto sigma_mat_inv = sigma.pow(-1).diag();
auto sigma_expanded_sq = sigma.pow(2).expand_as(sigma_mat);
auto F = (sigma_expanded_sq - sigma_expanded_sq.t()).pow(-1);
auto& long_type = sigma.type().toScalarType(at::kLong);
auto diag_indices = long_type.arange(0, F.numel(), k + 1);
F.view({-1}).index_fill_(0, diag_indices, 0);
Tensor u_term, sigma_term, v_term;
if (gu.defined()) {
u_term = u.mm(F.mul(ut.mm(gu) - gu.t().mm(u))).mm(sigma_mat);
if (m > k) {
u_term = u_term + (im - u.mm(ut)).mm(gu).mm(sigma_mat_inv);
}
u_term = u_term.mm(vt);
} else {
u_term = self.type().zeros({1}).expand_as(self);
}
if (gsigma.defined()) {
sigma_term = u.mm(gsigma.diag()).mm(vt);
} else {
sigma_term = self.type().zeros({1}).expand_as(self);
}
if (gv.defined()) {
auto gvt = gv.t();
v_term = sigma_mat.mm(F.mul(vt.mm(gv) - gvt.mm(v))).mm(vt);
if (n > k) {
v_term = v_term + sigma_mat_inv.mm(gvt.mm(in - v.mm(vt)));
}
v_term = u.mm(v_term);
} else {
v_term = self.type().zeros({1}).expand_as(self);
}
return u_term + sigma_term + v_term;
}
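In the notation of the linked note (bars denote incoming gradients, k = min(m, n), and F_ij = 1/(sigma_j^2 - sigma_i^2) off the diagonal, 0 on it), the expression assembled above is, as a sketch:

\bar{A} = U\big[F \circ (U^\top \bar{U} - \bar{U}^\top U)\big]\,\Sigma\,V^\top
        + (I_m - U U^\top)\,\bar{U}\,\Sigma^{-1} V^\top
        + U\,\mathrm{diag}(\bar{\sigma})\,V^\top
        + U\,\Sigma\,\big[F \circ (V^\top \bar{V} - \bar{V}^\top V)\big]\,V^\top
        + U\,\Sigma^{-1} \bar{V}^\top (I_n - V V^\top)

where the second and fifth terms are only added when m > k and n > k, respectively (the `m > k` / `n > k` branches above).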
// Formula:
// d det / d A_ij = \sum_k (\prod_{l neq k} Sigma_l) U_ik V_jk
// that is, if det != 0,
// d det / d A = U * diag(det / Sigma) * V^T
Tensor _det_with_svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
const Tensor& det, const Tensor& u, const Tensor& sigma, const Tensor& v) {
std::vector<torch::autograd::Variable> svd_grads(grads.begin() + 1, grads.end());
auto svd_term = svd_backward(svd_grads, self, true, u, sigma, v);
auto det_grad = grads[0];
auto size = self.size(0);
auto null_dim = size - sigma.nonzero().size(0);
if (null_dim >= 2) {
// \prod_{l neq k} Sigma_l is zero everywhere
return svd_term;
}
if (null_dim == 1) {
// only last sigma is 0
// \prod_{l neq k} Sigma_l is zero at all but last dim
// at last dim, it is:
auto scale = sigma.narrow(0, 0, size - 1).prod();
auto last_u = u.narrow(1, size - 1, 1);
auto last_v = v.narrow(1, size - 1, 1);
return svd_term + last_u.mm(last_v.transpose(0, 1)).mul_(scale.mul_(det_grad));
}
// no zero singular values
return svd_term + u.mm(sigma.pow(-1).mul_(det.mul(det_grad)).diag()).mm(v.transpose(0, 1));
}
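A minimal numerical check of this backward, in the spirit of the test changes elsewhere in this commit (a sketch; assumes the Variable-era gradcheck API used in test_autograd.py):

import torch
from torch.autograd import Variable
from torch.autograd.gradcheck import gradcheck

a = Variable(torch.randn(5, 5).double(), requires_grad=True)
# det's backward goes through _det_with_svd_backward / svd_backward above
assert gradcheck(lambda x: x.det(), (a,), eps=1e-6, atol=1e-4)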
}
${autograd_function_definitions}

View File

@ -193,6 +193,15 @@ VariableType::as_variable(std::tuple<Tensor, Tensor, Tensor> tensors) const {
make_variable(std::move(std::get<2>(tensors))));
}
std::tuple<Variable, Variable, Variable, Variable>
VariableType::as_variable(std::tuple<Tensor, Tensor, Tensor, Tensor> tensors) const {
return std::make_tuple<>(
make_variable(std::move(std::get<0>(tensors))),
make_variable(std::move(std::get<1>(tensors))),
make_variable(std::move(std::get<2>(tensors))),
make_variable(std::move(std::get<3>(tensors))));
}
std::vector<Variable> VariableType::as_variable(TensorList tl) const {
std::vector<Variable> variables;
for (auto& t : tl) {

View File

@ -54,6 +54,7 @@ private:
Variable as_variable(Tensor tensor) const;
std::tuple<Variable, Variable> as_variable(std::tuple<Tensor, Tensor> tensor) const;
std::tuple<Variable, Variable, Variable> as_variable(std::tuple<Tensor, Tensor, Tensor> tensor) const;
std::tuple<Variable, Variable, Variable, Variable> as_variable(std::tuple<Tensor, Tensor, Tensor, Tensor> tensor) const;
std::vector<Variable> as_variable(TensorList tensor) const;
Variable maybe_wrap(Tensor data, const Variable & self, bool inplace) const;

View File

@ -90,6 +90,12 @@ void pack_list(list_of_retainable & outputs, std::tuple<Tensor, Tensor, Tensor>
outputs.push_back(toRetainableSteal(std::move(std::get<1>(v))));
outputs.push_back(toRetainableSteal(std::move(std::get<2>(v))));
}
void pack_list(list_of_retainable & outputs, std::tuple<Tensor, Tensor, Tensor, Tensor> v) {
outputs.push_back(toRetainableSteal(std::move(std::get<0>(v))));
outputs.push_back(toRetainableSteal(std::move(std::get<1>(v))));
outputs.push_back(toRetainableSteal(std::move(std::get<2>(v))));
outputs.push_back(toRetainableSteal(std::move(std::get<3>(v))));
}
// A list of functions taking TensorList arguments (where we can't use
// the number of inputs to choose an overload).

View File

@ -4276,11 +4276,12 @@ svd(input, some=True, out=None) -> (Tensor, Tensor, Tensor)
`U, S, V = torch.svd(A)` returns the singular value decomposition of a
real matrix `A` of size `(n x m)` such that :math:`A = U S V^T`.
`U` is of shape `n x n`
`U` is of shape `n x min(n, m)`
`S` is of shape `n x m`
`S` is a diagonal square matrix of shape `min(n, m) x min(n, m)`, represented as
a vector of shape `(min(n, m),)` containing its diagonal entries.
`V` is of shape `m x m`.
`V` is of shape `m x min(n, m)`.
:attr:`some` controls the number of singular values to be computed:
if `some=True`, only some are computed; if `some=False`, all are computed.
@ -4288,6 +4289,16 @@ If `some=True`, it computes some and `some=False` computes all.
.. note:: Irrespective of the original strides, the returned matrix `U`
will be transposed, i.e. with strides `(1, n)` instead of `(n, 1)`.
.. note:: Extra care needs to be taken when backpropagating through `U` and
`V`. Such an operation is only stable when :attr:`input` is full rank
with all distinct singular values. Otherwise, `NaN` can appear as the
gradients are not properly defined. Also, when :attr:`some` = `False`,
the gradients on `U[:, min(n, m):]` and `V[:, min(n, m):]` will be
ignored, as those vectors can be arbitrary bases of the corresponding
subspaces.
.. note:: Double backward through :meth:`~torch.svd` is not supported currently.
Args:
input (Tensor): the input 2D Tensor
some (bool, optional): controls the number of singular values to be computed
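A short sketch of backpropagating through svd after this change (hedged; gradients w.r.t. `U` and `V` are only well behaved under the conditions in the note above):

import torch
from torch.autograd import Variable

a = Variable(torch.randn(4, 4).double(), requires_grad=True)
u, s, v = a.svd()
s.sum().backward()   # for distinct nonzero singular values, d(sum of sigma)/dA = U V^T
print((a.grad - u.mm(v.t())).abs().max())  # ~0 for a full-rank input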

View File

@ -33,6 +33,16 @@ inline PyObject* wrap(std::tuple<at::Tensor, at::Tensor, at::Tensor> tensors) {
return r.release();
}
inline PyObject* wrap(std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> tensors) {
auto r = THPObjectPtr{PyTuple_New(4)};
if (!r) throw python_error();
PyTuple_SET_ITEM(r.get(), 0, wrap(std::move(std::get<0>(tensors))));
PyTuple_SET_ITEM(r.get(), 1, wrap(std::move(std::get<1>(tensors))));
PyTuple_SET_ITEM(r.get(), 2, wrap(std::move(std::get<2>(tensors))));
PyTuple_SET_ITEM(r.get(), 3, wrap(std::move(std::get<3>(tensors))));
return r.release();
}
inline PyObject* wrap(at::TensorList tl) {
auto r = THPObjectPtr{PyTuple_New(tl.size())};
if (!r) throw python_error();

View File

@ -2555,6 +2555,7 @@ static const char *R = &__R;
- Double
backends:
- CPU
- CUDA
variants:
- method
- function

View File

@ -4,7 +4,7 @@ from operator import mul
from functools import reduce
__all__ = [
'split', 'chunk', 'stack', 'unbind', 'btriunpack', 'matmul',
'split', 'chunk', 'stack', 'unbind', 'btriunpack', 'matmul', 'det',
]
@ -245,3 +245,20 @@ def matmul(tensor1, tensor2, out=None):
raise ValueError("both arguments to __matmul__ need to be at least 1D, "
"but they are {}D and {}D".format(dim_tensor1, dim_tensor2))
def det(var):
"""Calculates determinant of a 2D square Variable.
.. note::
Backward through `det` internally uses SVD results. So double backward
through `det` will need to backward through :meth:`~Tensor.svd`. This
can be unstable in certain cases. Please see :meth:`~torch.svd` for
details.
Arguments:
var (Variable): The input 2D square Variable.
"""
if torch.is_tensor(var):
raise ValueError("det is currently only supported on Variable")
return var.det()
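A short usage sketch of the new function (relying only on what this commit adds; the torch.det spelling assumes the torch.functional export above):

import torch
from torch.autograd import Variable

x = Variable(torch.randn(3, 3), requires_grad=True)
d = x.det()             # also reachable as torch.det(x)
d.backward()
# mathematically, d det(x) / dx = det(x) * inverse(x)^T
print((x.grad.data - d.data[0] * x.data.inverse().t()).abs().max())  # ~0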