Support GPU triangular solve (#6648)

* add cuda trtrs

* remove queue

* add test trtrs
Authored by Du Phan on 2018-04-17 21:33:39 +09:00, committed by Adam Paszke
parent b34ae77be8
commit c345212c86
5 changed files with 55 additions and 4 deletions
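
In effect, this change routes torch.trtrs to MAGMA whenever its arguments are CUDA tensors. A minimal usage sketch of what the commit enables, assuming the trtrs signature of this release line, torch.trtrs(b, A, upper=True, transpose=False, unitriangular=False) returning (solution, A), and a MAGMA-enabled CUDA build:

import torch

# Hypothetical usage; requires a CUDA build with MAGMA support.
U = torch.randn(5, 5).cuda().triu()   # upper-triangular coefficient matrix
b = torch.randn(5, 3).cuda()          # one right-hand side per column

x, _ = torch.trtrs(b, U)              # solves U x = b on the GPU
print(torch.mm(U, x).sub(b).abs().max())   # residual should be ~0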


@@ -3236,6 +3236,7 @@
    - Double
  backends:
    - CPU
    - CUDA
  variants:
    - method
    - function
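
This cwrap entry (shown in part; the name and types keys sit outside the hunk) registers the op for the CPU and, with this commit, the CUDA backend, and generates it both as a free function and as a tensor method. A hedged sketch of the two call forms that produces, with illustrative values:

import torch

# Both variants should agree; assumes a MAGMA-enabled CUDA build.
U = torch.randn(4, 4).cuda().triu()
b = torch.randn(4, 2).cuda()

x1, _ = torch.trtrs(b, U)   # "function" variant
x2, _ = b.trtrs(U)          # "method" variant
assert torch.equal(x1, x2)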


@@ -98,6 +98,43 @@ THC_API void THCTensor_(gesv)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_)
#endif
}

THC_API void THCTensor_(trtrs)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_,
                               const char *uplo, const char *trans, const char *diag)
{
#ifdef USE_MAGMA
  THArgCheck(a_->nDimension == 2, 1, "A should be 2 dimensional");
  THArgCheck(b_->nDimension == 2, 2, "b should be 2 dimensional");
  THArgCheck(a_->size[0] == a_->size[1], 1, "A should be square");
  THArgCheck(b_->size[0] == a_->size[0], 2, "A,b size incompatible");

  magma_side_t sz = MagmaLeft;
  magma_uplo_t ul = uplo[0] == 'U' ? MagmaUpper : MagmaLower;
  magma_trans_t ts = trans[0] == 'N' ? MagmaNoTrans : MagmaTrans;
  magma_diag_t dg = diag[0] == 'U' ? MagmaUnit : MagmaNonUnit;

  real alpha = 1;
  int64_t n = a_->size[0];
  int64_t nrhs = b_->size[1];

  THCTensor *a = THCTensor_(newColumnMajor)(state, ra_, a_);
  THCTensor *b = THCTensor_(newColumnMajor)(state, rb_, b_);
  real *a_data = THCTensor_(data)(state, a);
  real *b_data = THCTensor_(data)(state, b);

#if defined(THC_REAL_IS_FLOAT)
  magma_strsm(sz, ul, ts, dg, n, nrhs, alpha, a_data, n, b_data, n);
#else
  magma_dtrsm(sz, ul, ts, dg, n, nrhs, alpha, a_data, n, b_data, n);
#endif

  THCTensor_(freeCopyTo)(state, a, ra_);
  THCTensor_(freeCopyTo)(state, b, rb_);
#else
  THError(NoMagma(trtrs));
#endif
}
THC_API void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_)
{
#ifdef USE_MAGMA

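The trtrs kernel above lowers to MAGMA's trsm, which overwrites B in place with the solution of op(A) X = alpha B on column-major data; that is why both operands go through THCTensor_(newColumnMajor) and alpha is pinned to 1. The v1-style magma_strsm/magma_dtrsm wrappers manage their own queue, which presumably explains the "remove queue" note in the commit message. A rough CPU-side restatement of the flag mapping, for orientation only (the helper name is ours, not part of the commit):

import torch

def trtrs_reference(b, a, uplo='U', trans='N', diag='N'):
    # CPU sketch mirroring the uplo/trans/diag handling of the MAGMA call.
    A = a.triu() if uplo == 'U' else a.tril()          # MagmaUpper / MagmaLower
    if diag == 'U':                                    # MagmaUnit: treat diagonal as 1
        A = A - A.diag().diag() + torch.eye(A.size(0))
    if trans == 'T':                                   # MagmaTrans: solve A^T x = b
        A = A.t()
    return torch.gesv(b, A)[0]                         # plain solve; alpha == 1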

@@ -6,6 +6,8 @@
// MAGMA (i.e. CUDA implementation of LAPACK functions)
THC_API void THCTensor_(gesv)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_);
THC_API void THCTensor_(trtrs)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_,
                               const char *uplo, const char *trans, const char *diag);
THC_API void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_);
THC_API void THCTensor_(syev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, const char *jobz, const char *uplo);
THC_API void THCTensor_(geev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, const char *jobvr);


@@ -1558,6 +1558,10 @@ class TestCuda(TestCase):
    def test_diagflat(self):
        TestTorch._test_diagflat(self, dtype=torch.float32, device='cuda')

    @unittest.skipIf(not HAS_MAGMA, "no MAGMA library detected")
    def test_trtrs(self):
        TestTorch._test_trtrs(self, lambda t: t.cuda())

    @unittest.skipIf(torch.cuda.device_count() < 2, "only one GPU detected")
    def test_get_set_rng_state_all(self):
        states = torch.cuda.get_rng_state_all()


@@ -3095,8 +3095,8 @@ class TestTorch(TestCase):
        res2 = torch.ormqr(m, tau, mat2, False, True)
        self.assertEqual(res1, res2)

    @skipIfNoLapack
    def test_trtrs(self):
    @staticmethod
    def _test_trtrs(self, cast):
        a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                          (-6.05, -3.30, 5.36, -4.44, 1.08),
                          (-0.45, 2.58, -2.70, 0.27, 9.04),
@@ -3106,6 +3106,9 @@
                          (-1.56, 4.00, -8.67, 1.75, 2.86),
                          (9.81, -4.09, -4.57, -8.61, 8.99))).t()

        a = cast(a)
        b = cast(b)

        U = torch.triu(a)
        L = torch.tril(a)
@@ -3143,14 +3146,18 @@
        # test reuse
        res1 = torch.trtrs(b, a)[0]
        ta = torch.Tensor()
        tb = torch.Tensor()
        ta = cast(torch.Tensor())
        tb = cast(torch.Tensor())
        torch.trtrs(b, a, out=(tb, ta))
        self.assertEqual(res1, tb, 0)
        tb.zero_()
        torch.trtrs(b, a, out=(tb, ta))
        self.assertEqual(res1, tb, 0)

    @skipIfNoLapack
    def test_trtrs(self):
        self._test_trtrs(self, lambda t: t)

    @skipIfNoLapack
    def test_gels(self):
        def _test_underdetermined(a, b, expectedNorm):
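
The test refactor above follows the suite's existing pattern: the body moves into a @staticmethod _test_* helper parameterized by a cast callable, so TestTorch runs it with the identity cast while TestCuda reuses the same body with lambda t: t.cuda(). A stripped-down sketch of that pattern (test name and body are hypothetical, not from the diff):

import unittest
import torch

class TestTorch(unittest.TestCase):
    @staticmethod
    def _test_add(self, cast):
        # `cast` moves the test tensors onto the device under test.
        a, b = cast(torch.ones(3)), cast(torch.ones(3))
        self.assertEqual(float((a + b).sum()), 6.0)

    def test_add(self):
        self._test_add(self, lambda t: t)              # CPU: identity cast

class TestCuda(unittest.TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "no CUDA")
    def test_add(self):
        TestTorch._test_add(self, lambda t: t.cuda())  # same body on the GPU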