[MPS] fix inverse bug for N>1024 (#146754)

Fixes #138200

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146754
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
Isalia20 2025-04-05 21:49:21 +00:00 committed by PyTorch MergeBot
parent 60a45eb862
commit cfea55dbec
3 changed files with 41 additions and 69 deletions

View File

@ -1,61 +0,0 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/linalg_inv_ex.h>
#include <ATen/ops/linalg_inv_ex_native.h>
#endif
namespace at::native {
TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) {
TORCH_CHECK(result.is_mps(), "Output tensor is not MPS");
TORCH_CHECK(!A.is_complex(), "linalg_inv: not supported for complex types yet!");
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) {
TORCH_WARN_ONCE(
"torch.linalg_inv_ex.inverse is supported by MPS on MacOS 13+, please upgrade. Falling back to CPU.");
auto cpu_info = at::empty({0}, kInt, std::nullopt, kCPU, std::nullopt, std::nullopt);
auto cpu_result = result.to("cpu");
at::linalg_inv_ex_out(cpu_result, cpu_info, A.to("cpu"));
info.copy_(cpu_info);
result.copy_(cpu_result);
return;
}
using namespace mps;
using CachedGraph = MPSUnaryCachedGraph;
MPSStream* stream = getCurrentMPSStream();
info.zero_();
if (A.numel() == 0) {
return;
}
if (!result.is_contiguous()) {
result.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
}
@autoreleasepool {
string key = "inv_out_mps" + getTensorsStringKey({A});
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, A);
MPSGraphTensor* outputTensor = [mpsGraph inverseOfTensor:inputTensor name:nil];
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
});
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, A);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
auto feeds = dictionaryFromPlaceholders(inputPlaceholder);
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
}
} // namespace at::native

View File

@ -2,6 +2,7 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/mps/MPSProfiler.h> #include <ATen/mps/MPSProfiler.h>
#include <ATen/native/LinearAlgebra.h>
#include <ATen/native/LinearAlgebraUtils.h> #include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/Resize.h> #include <ATen/native/Resize.h>
// For MTLLanguageVersion_3_1 // For MTLLanguageVersion_3_1
@ -22,6 +23,7 @@
#include <ATen/ops/cholesky_native.h> #include <ATen/ops/cholesky_native.h>
#include <ATen/ops/linalg_cholesky_ex_native.h> #include <ATen/ops/linalg_cholesky_ex_native.h>
#include <ATen/ops/linalg_cholesky_native.h> #include <ATen/ops/linalg_cholesky_native.h>
#include <ATen/ops/linalg_inv_ex_native.h>
#include <ATen/ops/linalg_lu_factor_ex_native.h> #include <ATen/ops/linalg_lu_factor_ex_native.h>
#include <ATen/ops/linalg_lu_factor_native.h> #include <ATen/ops/linalg_lu_factor_native.h>
#include <ATen/ops/linalg_solve_triangular_native.h> #include <ATen/ops/linalg_solve_triangular_native.h>
@ -261,14 +263,14 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
} }
} }
static void linalg_solve_out_mps_impl(const at::Tensor& A, static void linalg_solve_out_mps_impl(const Tensor& A,
const at::Tensor& B, const Tensor& B,
bool left, bool left,
bool check_errors, bool check_errors,
const at::Tensor& result, const Tensor& result,
const at::Tensor& LU, const Tensor& LU,
const at::Tensor& pivots, const Tensor& pivots,
const at::Tensor& info) { const Tensor& info) {
using namespace mps; using namespace mps;
TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()), TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()),
@ -436,6 +438,32 @@ static void linalg_solve_out_mps_impl(const at::Tensor& A,
} }
} }
static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) {
using namespace mps;
TORCH_CHECK(result.is_mps(), "Output tensor is not MPS");
TORCH_CHECK(!A.is_complex(), "linalg_inv: not supported for complex types yet!");
using CachedGraph = MPSUnaryCachedGraph;
MPSStream* stream = getCurrentMPSStream();
info.zero_();
if (A.numel() == 0) {
return;
}
if (!result.is_contiguous()) {
result.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
}
auto A_sizes = A.sizes();
int ndim = A.dim();
Tensor LU = empty_like(A);
Tensor identity = zeros_like(A);
Tensor pivots = empty({A_sizes.begin(), A_sizes.end() - 1}, A.options().dtype(kInt));
(ndim == 2 ? identity.diagonal() : identity.diagonal(0, -2, -1)).fill_(1);
linalg_solve_out_mps_impl(A, identity, true, check_errors, result, LU, pivots, info);
}
static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) { static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) {
using namespace mps; using namespace mps;
static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
@ -1427,4 +1455,8 @@ TORCH_IMPL_FUNC(linalg_lu_factor_ex_out_mps)
(const Tensor& A, bool pivot, bool check_errors, const Tensor& LU, const Tensor& pivots, const Tensor& info) { (const Tensor& A, bool pivot, bool check_errors, const Tensor& LU, const Tensor& pivots, const Tensor& info) {
mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, check_errors); mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, check_errors);
} }
TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) {
mps::linalg_inv_ex_out_mps_impl(A, check_errors, result, info);
}
} // namespace at::native } // namespace at::native

View File

@ -7815,18 +7815,19 @@ class TestMPS(TestCaseMPS):
# Test inverse # Test inverse
def test_inverse(self): def test_inverse(self):
def helper(n): def helper(n, atol=1e-5, rtol=1e-6):
cpu_input = torch.randn(n, n, device='cpu') cpu_input = torch.randn(n, n, device='cpu')
mps_input = cpu_input.to('mps') mps_input = cpu_input.to('mps')
cpu_result = torch.linalg.inv(cpu_input) cpu_result = torch.linalg.inv(cpu_input)
mps_result = torch.linalg.inv(mps_input) mps_result = torch.linalg.inv(mps_input)
self.assertEqual(cpu_result, mps_result) self.assertEqual(cpu_result, mps_result, atol=atol, rtol=rtol)
helper(2) helper(2)
helper(6) helper(6)
helper(3) helper(3)
helper(8) helper(8)
helper(1025, atol=1e-4)
# Test tril # Test tril
def test_tril(self): def test_tril(self):