Support gpu triangle solve (#6648)
* add cuda trtrs
* remove queue
* add test trtrs
This commit is contained in:
parent b34ae77be8
commit c345212c86
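For context: trtrs solves a triangular system A x = b; before this commit the function ran only on CPU via LAPACK. A minimal usage sketch of what the change enables (not part of the diff; assumes a CUDA build linked against MAGMA):

import torch

# Sketch only: torch.trtrs(b, A) solves A @ x = b, treating A as
# upper-triangular by default; on CUDA tensors it now dispatches to MAGMA.
A = torch.triu(torch.randn(5, 5)).cuda()  # upper-triangular coefficient matrix
b = torch.randn(5, 3).cuda()              # three right-hand sides
x, _ = torch.trtrs(b, A)                  # returns (solution, A)
print((A.mm(x) - b).abs().max())          # residual should be near zero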
@@ -3236,6 +3236,7 @@
     - Double
   backends:
     - CPU
+    - CUDA
   variants:
     - method
     - function
@@ -98,6 +98,43 @@ THC_API void THCTensor_(gesv)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_)
 #endif
 }
 
+THC_API void THCTensor_(trtrs)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_,
+                               const char *uplo, const char *trans, const char *diag)
+{
+#ifdef USE_MAGMA
+  THArgCheck(a_->nDimension == 2, 1, "A should be 2 dimensional");
+  THArgCheck(b_->nDimension == 2, 2, "b should be 2 dimensional");
+  THArgCheck(a_->size[0] == a_->size[1], 1, "A should be square");
+  THArgCheck(b_->size[0] == a_->size[0], 2, "A,b size incompatible");
+
+  magma_side_t sz = MagmaLeft;
+  magma_uplo_t ul = uplo[0] == 'U' ? MagmaUpper : MagmaLower;
+  magma_trans_t ts = trans[0] == 'N' ? MagmaNoTrans : MagmaTrans;
+  magma_diag_t dg = diag[0] == 'U' ? MagmaUnit : MagmaNonUnit;
+
+  real alpha = 1;
+
+  int64_t n = a_->size[0];
+  int64_t nrhs = b_->size[1];
+
+  THCTensor *a = THCTensor_(newColumnMajor)(state, ra_, a_);
+  THCTensor *b = THCTensor_(newColumnMajor)(state, rb_, b_);
+  real *a_data = THCTensor_(data)(state, a);
+  real *b_data = THCTensor_(data)(state, b);
+
+#if defined(THC_REAL_IS_FLOAT)
+  magma_strsm(sz, ul, ts, dg, n, nrhs, alpha, a_data, n, b_data, n);
+#else
+  magma_dtrsm(sz, ul, ts, dg, n, nrhs, alpha, a_data, n, b_data, n);
+#endif
+
+  THCTensor_(freeCopyTo)(state, a, ra_);
+  THCTensor_(freeCopyTo)(state, b, rb_);
+#else
+  THError(NoMagma(trtrs));
+#endif
+}
+
 THC_API void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_)
 {
 #ifdef USE_MAGMA
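The MAGMA path above forwards to magma_strsm/magma_dtrsm, which (with side = MagmaLeft) overwrite B in place with alpha * inv(op(A)) @ B, where A is treated as triangular and op(A) is A or its transpose. A rough Python reference for that semantics, for orientation only (the helper below is illustrative, not part of the change):

import torch

# Illustrative reference for the trsm call above: returns
# alpha * inv(op(A)) @ B with A taken as triangular.
def trsm_reference(A, B, upper=True, trans=False, unit=False, alpha=1.0):
    T = torch.triu(A) if upper else torch.tril(A)
    if unit:                         # MagmaUnit: diagonal assumed to be all ones
        T = T - T.diag().diag() + torch.eye(T.size(0))
    if trans:                        # MagmaTrans
        T = T.t()
    return alpha * torch.inverse(T).mm(B)  # fine for small sanity checks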
@@ -6,6 +6,8 @@
 
 // MAGMA (i.e. CUDA implementation of LAPACK functions)
 THC_API void THCTensor_(gesv)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_);
+THC_API void THCTensor_(trtrs)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_,
+                               const char *uplo, const char *trans, const char *diag);
 THC_API void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_);
 THC_API void THCTensor_(syev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, const char *jobz, const char *uplo);
 THC_API void THCTensor_(geev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, const char *jobvr);
@@ -1558,6 +1558,10 @@ class TestCuda(TestCase):
     def test_diagflat(self):
         TestTorch._test_diagflat(self, dtype=torch.float32, device='cuda')
 
+    @unittest.skipIf(not HAS_MAGMA, "no MAGMA library detected")
+    def test_trtrs(self):
+        TestTorch._test_trtrs(self, lambda t: t.cuda())
+
     @unittest.skipIf(torch.cuda.device_count() < 2, "only one GPU detected")
     def test_get_set_rng_state_all(self):
         states = torch.cuda.get_rng_state_all()
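The HAS_MAGMA guard skips the new test on CUDA builds without MAGMA. A plausible way such a flag is computed (a sketch, not copied from test_cuda.py):

import torch

# Sketch of a HAS_MAGMA flag like the one used above; torch.cuda.has_magma
# reports whether the current build links against MAGMA.
HAS_MAGMA = torch.cuda.is_available() and torch.cuda.has_magma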
@@ -3095,8 +3095,8 @@ class TestTorch(TestCase):
         res2 = torch.ormqr(m, tau, mat2, False, True)
         self.assertEqual(res1, res2)
 
-    @skipIfNoLapack
-    def test_trtrs(self):
+    @staticmethod
+    def _test_trtrs(self, cast):
         a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                           (-6.05, -3.30, 5.36, -4.44, 1.08),
                           (-0.45, 2.58, -2.70, 0.27, 9.04),
@@ -3106,6 +3106,9 @@ class TestTorch(TestCase):
                           (-1.56, 4.00, -8.67, 1.75, 2.86),
                           (9.81, -4.09, -4.57, -8.61, 8.99))).t()
 
+        a = cast(a)
+        b = cast(b)
+
         U = torch.triu(a)
         L = torch.tril(a)
 
@@ -3143,14 +3146,18 @@ class TestTorch(TestCase):
 
         # test reuse
         res1 = torch.trtrs(b, a)[0]
-        ta = torch.Tensor()
-        tb = torch.Tensor()
+        ta = cast(torch.Tensor())
+        tb = cast(torch.Tensor())
         torch.trtrs(b, a, out=(tb, ta))
         self.assertEqual(res1, tb, 0)
         tb.zero_()
         torch.trtrs(b, a, out=(tb, ta))
         self.assertEqual(res1, tb, 0)
 
+    @skipIfNoLapack
+    def test_trtrs(self):
+        self._test_trtrs(self, lambda t: t)
+
     @skipIfNoLapack
     def test_gels(self):
         def _test_underdetermined(a, b, expectedNorm):
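The test_trtrs to _test_trtrs(self, cast) refactor above is what lets test_cuda.py reuse the same test body with a .cuda() cast. A condensed sketch of the pattern (class and test names here are illustrative, not from the diff):

import unittest
import torch

# The body is written once against a `cast` callable; each suite
# supplies its own device cast.
class TestTorchLike(unittest.TestCase):
    @staticmethod
    def _test_add(self, cast):
        a, b = cast(torch.ones(3)), cast(torch.ones(3))
        self.assertEqual(float((a + b).sum()), 6.0)

    def test_add(self):                          # CPU path
        self._test_add(self, lambda t: t)

class TestCudaLike(unittest.TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "no CUDA detected")
    def test_add(self):                          # GPU path reuses the body
        TestTorchLike._test_add(self, lambda t: t.cuda())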