diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 3666b77e425..bdeaffd9050 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -3236,6 +3236,7 @@
     - Double
   backends:
     - CPU
+    - CUDA
   variants:
     - method
     - function
diff --git a/aten/src/THC/generic/THCTensorMathMagma.cu b/aten/src/THC/generic/THCTensorMathMagma.cu
index 1eb39857dee..374b877ceaf 100644
--- a/aten/src/THC/generic/THCTensorMathMagma.cu
+++ b/aten/src/THC/generic/THCTensorMathMagma.cu
@@ -98,6 +98,43 @@ THC_API void THCTensor_(gesv)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_)
 #endif
 }
 
+THC_API void THCTensor_(trtrs)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_,
+                               const char *uplo, const char *trans, const char *diag)
+{
+#ifdef USE_MAGMA
+  THArgCheck(a_->nDimension == 2, 1, "A should be 2 dimensional");
+  THArgCheck(b_->nDimension == 2, 2, "b should be 2 dimensional");
+  THArgCheck(a_->size[0] == a_->size[1], 1, "A should be square");
+  THArgCheck(b_->size[0] == a_->size[0], 2, "A,b size incompatible");
+
+  magma_side_t sz = MagmaLeft;
+  magma_uplo_t ul = uplo[0] == 'U' ? MagmaUpper : MagmaLower;
+  magma_trans_t ts = trans[0] == 'N' ? MagmaNoTrans : MagmaTrans;
+  magma_diag_t dg = diag[0] == 'U' ? MagmaUnit : MagmaNonUnit;
+
+  real alpha = 1;
+
+  int64_t n = a_->size[0];
+  int64_t nrhs = b_->size[1];
+
+  THCTensor *a = THCTensor_(newColumnMajor)(state, ra_, a_);
+  THCTensor *b = THCTensor_(newColumnMajor)(state, rb_, b_);
+  real *a_data = THCTensor_(data)(state, a);
+  real *b_data = THCTensor_(data)(state, b);
+
+#if defined(THC_REAL_IS_FLOAT)
+  magma_strsm(sz, ul, ts, dg, n, nrhs, alpha, a_data, n, b_data, n);
+#else
+  magma_dtrsm(sz, ul, ts, dg, n, nrhs, alpha, a_data, n, b_data, n);
+#endif
+
+  THCTensor_(freeCopyTo)(state, a, ra_);
+  THCTensor_(freeCopyTo)(state, b, rb_);
+#else
+  THError(NoMagma(trtrs));
+#endif
+}
+
 THC_API void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_)
 {
 #ifdef USE_MAGMA
diff --git a/aten/src/THC/generic/THCTensorMathMagma.h b/aten/src/THC/generic/THCTensorMathMagma.h
index 2aee308e0c3..1462af4ddaa 100644
--- a/aten/src/THC/generic/THCTensorMathMagma.h
+++ b/aten/src/THC/generic/THCTensorMathMagma.h
@@ -6,6 +6,8 @@
 
 // MAGMA (i.e. CUDA implementation of LAPACK functions)
 THC_API void THCTensor_(gesv)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_);
+THC_API void THCTensor_(trtrs)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_,
+                               const char *uplo, const char *trans, const char *diag);
 THC_API void THCTensor_(gels)(THCState *state, THCTensor *rb_, THCTensor *ra_, THCTensor *b_, THCTensor *a_);
 THC_API void THCTensor_(syev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, const char *jobz, const char *uplo);
 THC_API void THCTensor_(geev)(THCState *state, THCTensor *re_, THCTensor *rv_, THCTensor *a_, const char *jobvr);
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 1c424d4156c..95d4b1d080e 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1558,6 +1558,10 @@ class TestCuda(TestCase):
     def test_diagflat(self):
         TestTorch._test_diagflat(self, dtype=torch.float32, device='cuda')
 
+    @unittest.skipIf(not HAS_MAGMA, "no MAGMA library detected")
+    def test_trtrs(self):
+        TestTorch._test_trtrs(self, lambda t: t.cuda())
+
     @unittest.skipIf(torch.cuda.device_count() < 2, "only one GPU detected")
     def test_get_set_rng_state_all(self):
         states = torch.cuda.get_rng_state_all()
diff --git a/test/test_torch.py b/test/test_torch.py
index 1db29d9b2e8..01a156633bd 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -3095,8 +3095,8 @@ class TestTorch(TestCase):
         res2 = torch.ormqr(m, tau, mat2, False, True)
         self.assertEqual(res1, res2)
 
-    @skipIfNoLapack
-    def test_trtrs(self):
+    @staticmethod
+    def _test_trtrs(self, cast):
         a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
                           (-6.05, -3.30, 5.36, -4.44, 1.08),
                           (-0.45, 2.58, -2.70, 0.27, 9.04),
@@ -3106,6 +3106,9 @@ class TestTorch(TestCase):
                           (-1.56, 4.00, -8.67, 1.75, 2.86),
                           (9.81, -4.09, -4.57, -8.61, 8.99))).t()
 
+        a = cast(a)
+        b = cast(b)
+
         U = torch.triu(a)
         L = torch.tril(a)
 
@@ -3143,14 +3146,18 @@ class TestTorch(TestCase):
 
         # test reuse
         res1 = torch.trtrs(b, a)[0]
-        ta = torch.Tensor()
-        tb = torch.Tensor()
+        ta = cast(torch.Tensor())
+        tb = cast(torch.Tensor())
         torch.trtrs(b, a, out=(tb, ta))
         self.assertEqual(res1, tb, 0)
         tb.zero_()
         torch.trtrs(b, a, out=(tb, ta))
         self.assertEqual(res1, tb, 0)
 
+    @skipIfNoLapack
+    def test_trtrs(self):
+        self._test_trtrs(self, lambda t: t)
+
     @skipIfNoLapack
     def test_gels(self):
         def _test_underdetermined(a, b, expectedNorm):
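
For reference, a minimal sketch of exercising the new MAGMA-backed path from Python, mirroring what `_test_trtrs` does with `cast = lambda t: t.cuda()`. The random matrices and the diagonal shift below are illustrative assumptions, not part of this diff, and the CUDA call requires a MAGMA-enabled build (otherwise the `THError(NoMagma(trtrs))` branch added above fires):

```python
import torch

# Illustrative only: cross-check the new CUDA trtrs path against the
# existing CPU LAPACK path. A well-conditioned upper-triangular system
# is constructed so both solves are numerically stable. Default float32
# tensors exercise the magma_strsm branch.
a = torch.triu(torch.randn(5, 5)) + 5 * torch.eye(5)
b = torch.randn(5, 3)

x_cpu = torch.trtrs(b, a)[0]                # CPU LAPACK path; returns (solution, A)
x_gpu = torch.trtrs(b.cuda(), a.cuda())[0]  # new MAGMA path added by this diff

print((x_cpu - x_gpu.cpu()).abs().max())    # backends should agree, expect ~0
print((a.mm(x_cpu) - b).abs().max())        # residual of A @ X = B, expect ~0
```

One design note worth flagging: the CUDA implementation maps the `uplo`/`trans`/`diag` characters onto MAGMA enums and dispatches to the BLAS-3 routine `magma_strsm`/`magma_dtrsm` rather than a LAPACK-style `trtrs`, so unlike the CPU path it performs no singularity (info) check on the diagonal of A.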