[MPS] fix inverse bug for N>1024 (#146754)
Fixes #138200

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146754
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
parent 60a45eb862
commit cfea55dbec
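For context before the diffs, here is a minimal reproduction sketch of the failure mode this commit fixes (sizes and tolerances are illustrative, chosen to mirror the updated test below; issue #138200 has the original report):

    import torch

    # Hypothetical repro: before this fix, the MPS result of linalg.inv
    # diverged from CPU for matrix sizes above 1024 (the new test below
    # exercises N = 1025).
    n = 1025
    a_cpu = torch.randn(n, n)
    a_mps = a_cpu.to("mps")  # requires an MPS-enabled (Apple silicon) build

    inv_cpu = torch.linalg.inv(a_cpu)
    inv_mps = torch.linalg.inv(a_mps)

    # Loose tolerance, mirroring the atol=1e-4 the updated test uses.
    print(torch.allclose(inv_cpu, inv_mps.cpu(), atol=1e-4, rtol=1e-6))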
@@ -1,61 +0,0 @@
-#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
-#include <ATen/native/mps/MPSGraphVenturaOps.h>
-#include <ATen/native/mps/OperationUtils.h>
-
-#ifndef AT_PER_OPERATOR_HEADERS
-#include <ATen/Functions.h>
-#include <ATen/NativeFunctions.h>
-#else
-#include <ATen/ops/linalg_inv_ex.h>
-#include <ATen/ops/linalg_inv_ex_native.h>
-#endif
-
-namespace at::native {
-
-TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) {
-  TORCH_CHECK(result.is_mps(), "Output tensor is not MPS");
-  TORCH_CHECK(!A.is_complex(), "linalg_inv: not supported for complex types yet!");
-  if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) {
-    TORCH_WARN_ONCE(
-        "torch.linalg_inv_ex.inverse is supported by MPS on MacOS 13+, please upgrade. Falling back to CPU.");
-    auto cpu_info = at::empty({0}, kInt, std::nullopt, kCPU, std::nullopt, std::nullopt);
-    auto cpu_result = result.to("cpu");
-    at::linalg_inv_ex_out(cpu_result, cpu_info, A.to("cpu"));
-    info.copy_(cpu_info);
-    result.copy_(cpu_result);
-    return;
-  }
-
-  using namespace mps;
-  using CachedGraph = MPSUnaryCachedGraph;
-
-  MPSStream* stream = getCurrentMPSStream();
-  info.zero_();
-
-  if (A.numel() == 0) {
-    return;
-  }
-
-  if (!result.is_contiguous()) {
-    result.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
-  }
-
-  @autoreleasepool {
-    string key = "inv_out_mps" + getTensorsStringKey({A});
-    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
-      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, A);
-      MPSGraphTensor* outputTensor = [mpsGraph inverseOfTensor:inputTensor name:nil];
-
-      newCachedGraph->inputTensor_ = inputTensor;
-      newCachedGraph->outputTensor_ = outputTensor;
-    });
-
-    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, A);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
-
-    auto feeds = dictionaryFromPlaceholders(inputPlaceholder);
-    runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
-  }
-}
-
-} // namespace at::native
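The file removed above implemented linalg_inv_ex_out_mps directly on MPSGraph's inverseOfTensor: kernel, which, per the linked issue, returned wrong results once N exceeded 1024. The rewrite in the next file drops the dedicated inverse kernel and instead computes the inverse through the existing LU-based solve, using the identity that inv(A) is the solution X of A @ X = I. A minimal sketch of that equivalence (plain PyTorch on CPU, illustrative size, assuming a well-conditioned A):

    import torch

    # inv(A) is the X that solves A @ X = I; this is the identity the new
    # MPS implementation relies on, calling its LU-based solve with an
    # identity right-hand side instead of a dedicated inverse kernel.
    a = torch.randn(8, 8, dtype=torch.float64)
    x = torch.linalg.solve(a, torch.eye(8, dtype=torch.float64))
    print(torch.allclose(x, torch.linalg.inv(a)))  # True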
@@ -2,6 +2,7 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/mps/MPSProfiler.h>
 #include <ATen/native/LinearAlgebra.h>
+#include <ATen/native/LinearAlgebraUtils.h>
 #include <ATen/native/Resize.h>
 // For MTLLanguageVersion_3_1
@@ -22,6 +23,7 @@
 #include <ATen/ops/cholesky_native.h>
 #include <ATen/ops/linalg_cholesky_ex_native.h>
 #include <ATen/ops/linalg_cholesky_native.h>
+#include <ATen/ops/linalg_inv_ex_native.h>
 #include <ATen/ops/linalg_lu_factor_ex_native.h>
 #include <ATen/ops/linalg_lu_factor_native.h>
 #include <ATen/ops/linalg_solve_triangular_native.h>
@@ -261,14 +263,14 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
   }
 }

-static void linalg_solve_out_mps_impl(const at::Tensor& A,
-                                      const at::Tensor& B,
+static void linalg_solve_out_mps_impl(const Tensor& A,
+                                      const Tensor& B,
                                       bool left,
                                       bool check_errors,
-                                      const at::Tensor& result,
-                                      const at::Tensor& LU,
-                                      const at::Tensor& pivots,
-                                      const at::Tensor& info) {
+                                      const Tensor& result,
+                                      const Tensor& LU,
+                                      const Tensor& pivots,
+                                      const Tensor& info) {
   using namespace mps;

   TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()),
@@ -436,6 +438,32 @@ static void linalg_solve_out_mps_impl(const at::Tensor& A,
   }
 }

+static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) {
+  using namespace mps;
+  TORCH_CHECK(result.is_mps(), "Output tensor is not MPS");
+  TORCH_CHECK(!A.is_complex(), "linalg_inv: not supported for complex types yet!");
+  using CachedGraph = MPSUnaryCachedGraph;
+
+  MPSStream* stream = getCurrentMPSStream();
+  info.zero_();
+
+  if (A.numel() == 0) {
+    return;
+  }
+
+  if (!result.is_contiguous()) {
+    result.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
+  }
+  auto A_sizes = A.sizes();
+  int ndim = A.dim();
+
+  Tensor LU = empty_like(A);
+  Tensor identity = zeros_like(A);
+  Tensor pivots = empty({A_sizes.begin(), A_sizes.end() - 1}, A.options().dtype(kInt));
+  (ndim == 2 ? identity.diagonal() : identity.diagonal(0, -2, -1)).fill_(1);
+  linalg_solve_out_mps_impl(A, identity, true, check_errors, result, LU, pivots, info);
+}
+
 static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) {
   using namespace mps;
   static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
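A detail worth calling out in the new linalg_inv_ex_out_mps_impl: rather than materializing an explicit eye for every batch element, it zeros a tensor shaped like A and fills the diagonal of the last two dimensions in place. A small Python sketch of the same batched-identity trick (names mirror the C++ above; shapes are illustrative):

    import torch

    A = torch.randn(4, 3, 3)          # batched input, shape (..., n, n)
    identity = torch.zeros_like(A)
    # diagonal(0, -2, -1) views the main diagonal of the last two dims,
    # so a single in-place fill_ handles both 2-D and batched inputs.
    identity.diagonal(0, -2, -1).fill_(1)
    print(torch.equal(identity, torch.eye(3).expand_as(A)))  # True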
@@ -1427,4 +1455,8 @@ TORCH_IMPL_FUNC(linalg_lu_factor_ex_out_mps)
 (const Tensor& A, bool pivot, bool check_errors, const Tensor& LU, const Tensor& pivots, const Tensor& info) {
   mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, check_errors);
 }
+
+TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) {
+  mps::linalg_inv_ex_out_mps_impl(A, check_errors, result, info);
+}
 } // namespace at::native
@@ -7815,18 +7815,19 @@ class TestMPS(TestCaseMPS):

     # Test inverse
     def test_inverse(self):
-        def helper(n):
+        def helper(n, atol=1e-5, rtol=1e-6):
             cpu_input = torch.randn(n, n, device='cpu')
             mps_input = cpu_input.to('mps')

             cpu_result = torch.linalg.inv(cpu_input)
             mps_result = torch.linalg.inv(mps_input)
-            self.assertEqual(cpu_result, mps_result)
+            self.assertEqual(cpu_result, mps_result, atol=atol, rtol=rtol)

         helper(2)
         helper(6)
         helper(3)
         helper(8)
+        helper(1025, atol=1e-4)

     # Test tril
     def test_tril(self):
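With a local MPS-enabled build, the updated case should be runnable with something like python test/test_mps.py -k test_inverse (standard unittest -k filtering; the exact invocation may differ by setup). The looser atol=1e-4 for N = 1025 is consistent with LU-based solves accumulating somewhat more rounding error as the matrix dimension grows.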