diff --git a/aten/src/ATen/native/mps/operations/Inverse.mm b/aten/src/ATen/native/mps/operations/Inverse.mm
deleted file mode 100644
index 5574df89afe..00000000000
--- a/aten/src/ATen/native/mps/operations/Inverse.mm
+++ /dev/null
@@ -1,61 +0,0 @@
-#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
-#include
-#include
-
-#ifndef AT_PER_OPERATOR_HEADERS
-#include
-#include
-#else
-#include
-#include
-#endif
-
-namespace at::native {
-
-TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) {
-  TORCH_CHECK(result.is_mps(), "Output tensor is not MPS");
-  TORCH_CHECK(!A.is_complex(), "linalg_inv: not supported for complex types yet!");
-  if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) {
-    TORCH_WARN_ONCE(
-        "torch.linalg_inv_ex.inverse is supported by MPS on MacOS 13+, please upgrade. Falling back to CPU.");
-    auto cpu_info = at::empty({0}, kInt, std::nullopt, kCPU, std::nullopt, std::nullopt);
-    auto cpu_result = result.to("cpu");
-    at::linalg_inv_ex_out(cpu_result, cpu_info, A.to("cpu"));
-    info.copy_(cpu_info);
-    result.copy_(cpu_result);
-    return;
-  }
-
-  using namespace mps;
-  using CachedGraph = MPSUnaryCachedGraph;
-
-  MPSStream* stream = getCurrentMPSStream();
-  info.zero_();
-
-  if (A.numel() == 0) {
-    return;
-  }
-
-  if (!result.is_contiguous()) {
-    result.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
-  }
-
-  @autoreleasepool {
-    string key = "inv_out_mps" + getTensorsStringKey({A});
-    auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) {
-      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, A);
-      MPSGraphTensor* outputTensor = [mpsGraph inverseOfTensor:inputTensor name:nil];
-
-      newCachedGraph->inputTensor_ = inputTensor;
-      newCachedGraph->outputTensor_ = outputTensor;
-    });
-
-    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, A);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
-
-    auto feeds = dictionaryFromPlaceholders(inputPlaceholder);
-    runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
-  }
-}
-
-} // namespace at::native
diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
index 22aee2307f6..1a9e841cfbc 100644
--- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
+++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -2,6 +2,7 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include
+#include
 #include
 #include  // For MTLLanguageVersion_3_1
@@ -22,6 +23,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -261,14 +263,14 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
   }
 }
 
-static void linalg_solve_out_mps_impl(const at::Tensor& A,
-                                      const at::Tensor& B,
+static void linalg_solve_out_mps_impl(const Tensor& A,
+                                      const Tensor& B,
                                       bool left,
                                       bool check_errors,
-                                      const at::Tensor& result,
-                                      const at::Tensor& LU,
-                                      const at::Tensor& pivots,
-                                      const at::Tensor& info) {
+                                      const Tensor& result,
+                                      const Tensor& LU,
+                                      const Tensor& pivots,
+                                      const Tensor& info) {
   using namespace mps;
   TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()),
@@ -436,6 +438,32 @@ static void linalg_solve_out_mps_impl(const at::Tensor& A,
   }
 }
 
+static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) {
+  using namespace mps;
+  TORCH_CHECK(result.is_mps(), "Output tensor is not MPS");
+  TORCH_CHECK(!A.is_complex(), "linalg_inv: not supported for complex types yet!");
+  using CachedGraph = MPSUnaryCachedGraph;
+
+  MPSStream* stream = getCurrentMPSStream();
+  info.zero_();
+
+  if (A.numel() == 0) {
+    return;
+  }
+
+  if (!result.is_contiguous()) {
+    result.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
+  }
+  auto A_sizes = A.sizes();
+  int ndim = A.dim();
+
+  Tensor LU = empty_like(A);
+  Tensor identity = zeros_like(A);
+  Tensor pivots = empty({A_sizes.begin(), A_sizes.end() - 1}, A.options().dtype(kInt));
+  (ndim == 2 ? identity.diagonal() : identity.diagonal(0, -2, -1)).fill_(1);
+  linalg_solve_out_mps_impl(A, identity, true, check_errors, result, LU, pivots, info);
+}
+
 static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) {
   using namespace mps;
   static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
@@ -1427,4 +1455,8 @@ TORCH_IMPL_FUNC(linalg_lu_factor_ex_out_mps)
 (const Tensor& A, bool pivot, bool check_errors, const Tensor& LU, const Tensor& pivots, const Tensor& info) {
   mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, check_errors);
 }
+
+TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) {
+  mps::linalg_inv_ex_out_mps_impl(A, check_errors, result, info);
+}
 } // namespace at::native
diff --git a/test/test_mps.py b/test/test_mps.py
index a950ae28b76..61903bd3900 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -7815,18 +7815,19 @@ class TestMPS(TestCaseMPS):
 
     # Test inverse
     def test_inverse(self):
-        def helper(n):
+        def helper(n, atol=1e-5, rtol=1e-6):
            cpu_input = torch.randn(n, n, device='cpu')
            mps_input = cpu_input.to('mps')

            cpu_result = torch.linalg.inv(cpu_input)
            mps_result = torch.linalg.inv(mps_input)
-            self.assertEqual(cpu_result, mps_result)
+            self.assertEqual(cpu_result, mps_result, atol=atol, rtol=rtol)

        helper(2)
        helper(6)
        helper(3)
        helper(8)
+        helper(1025, atol=1e-4)

    # Test tril
    def test_tril(self):
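
Note on the approach (not part of the diff): the new implementation no longer calls MPSGraph's inverseOfTensor:. It builds an identity matrix and delegates to the LU-based linalg_solve_out_mps_impl, i.e. it obtains inv(A) as the solution X of A @ X = I. The Python sketch below illustrates the same identity-solve formulation from the user side and compares against CPU the way the updated test_inverse does; the 1025x1025 case and relaxed tolerances mirror the new test, while the function name and the solve-based cross-check are illustrative additions, not code from the PR.

# Minimal sketch: torch.linalg.inv on MPS should agree with the CPU result and
# with an explicit identity solve, which is the formulation the new
# linalg_inv_ex_out_mps_impl delegates to (assumed here; see the diff above).
import torch

def check_inverse_via_solve(n: int, atol: float = 1e-5, rtol: float = 1e-6) -> None:
    if not torch.backends.mps.is_available():
        return  # nothing to check without an MPS device

    cpu_input = torch.randn(n, n)
    mps_input = cpu_input.to("mps")

    # Reference inverse on CPU vs. the MPS kernel exercised by this PR.
    cpu_inv = torch.linalg.inv(cpu_input)
    mps_inv = torch.linalg.inv(mps_input)
    torch.testing.assert_close(mps_inv.cpu(), cpu_inv, atol=atol, rtol=rtol)

    # The same result expressed as an identity solve, A @ X = I.
    identity = torch.eye(n, device="mps")
    solved = torch.linalg.solve(mps_input, identity)
    torch.testing.assert_close(solved.cpu(), cpu_inv, atol=atol, rtol=rtol)

check_inverse_via_solve(8)
check_inverse_via_solve(1025, atol=1e-4)  # size and tolerance taken from the updated test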