mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS] fix inverse bug for N>1024 (#146754)
Fixes #138200.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146754
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
This commit is contained in:
parent
60a45eb862
commit
cfea55dbec
|
|
@ -1,61 +0,0 @@
|
||||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
|
||||||
#include <ATen/native/mps/MPSGraphVenturaOps.h>
|
|
||||||
#include <ATen/native/mps/OperationUtils.h>
|
|
||||||
|
|
||||||
#ifndef AT_PER_OPERATOR_HEADERS
|
|
||||||
#include <ATen/Functions.h>
|
|
||||||
#include <ATen/NativeFunctions.h>
|
|
||||||
#else
|
|
||||||
#include <ATen/ops/linalg_inv_ex.h>
|
|
||||||
#include <ATen/ops/linalg_inv_ex_native.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace at::native {
|
|
||||||
|
|
||||||
// Writes the inverse of `A` into `result` using MPSGraph's inverseOfTensor: op.
//
// A            - input tensor; complex dtypes are rejected.
// check_errors - forwarded to the CPU kernel on the fallback path only.
// result       - output tensor, must live on the MPS device.
// info         - status tensor; zero-filled on the MPS path, copied from the
//                CPU kernel's output on the fallback path.
//
// When the OS is older than macOS 13.3 the work is delegated to
// at::linalg_inv_ex_out on the CPU and both outputs are copied back.
// NOTE(review): on the MPS path `check_errors` is never consulted and `info`
// stays at zero, so singular inputs are presumably not reported here — confirm.
TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) {
  TORCH_CHECK(result.is_mps(), "Output tensor is not MPS");
  TORCH_CHECK(!A.is_complex(), "linalg_inv: not supported for complex types yet!");
  if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) {
    // The gate above is the 13.3+ check, so the message names the same version.
    TORCH_WARN_ONCE(
        "torch.linalg_inv_ex.inverse is supported by MPS on MacOS 13.3+, please upgrade. Falling back to CPU.");
    auto cpu_info = at::empty({0}, kInt, std::nullopt, kCPU, std::nullopt, std::nullopt);
    auto cpu_result = result.to("cpu");
    // Pass check_errors through so the fallback honours the caller's request
    // instead of silently using the default.
    at::linalg_inv_ex_out(cpu_result, cpu_info, A.to("cpu"), check_errors);
    info.copy_(cpu_info);
    result.copy_(cpu_result);
    return;
  }

  using namespace mps;
  using CachedGraph = MPSUnaryCachedGraph;

  MPSStream* stream = getCurrentMPSStream();
  info.zero_();

  // Nothing to invert for an empty input.
  if (A.numel() == 0) {
    return;
  }

  // Give the output contiguous strides before it is bound to the graph output.
  if (!result.is_contiguous()) {
    result.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
  }

  @autoreleasepool {
    // The graph is looked up (or built once) under a key derived from A.
    string key = "inv_out_mps" + getTensorsStringKey({A});
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, A);
      MPSGraphTensor* outputTensor = [mpsGraph inverseOfTensor:inputTensor name:nil];

      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->outputTensor_ = outputTensor;
    });

    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, A);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);

    auto feeds = dictionaryFromPlaceholders(inputPlaceholder);
    runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
  }
}
|
|
||||||
|
|
||||||
} // namespace at::native
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||||
#include <ATen/mps/MPSProfiler.h>
|
#include <ATen/mps/MPSProfiler.h>
|
||||||
|
#include <ATen/native/LinearAlgebra.h>
|
||||||
#include <ATen/native/LinearAlgebraUtils.h>
|
#include <ATen/native/LinearAlgebraUtils.h>
|
||||||
#include <ATen/native/Resize.h>
|
#include <ATen/native/Resize.h>
|
||||||
// For MTLLanguageVersion_3_1
|
// For MTLLanguageVersion_3_1
|
||||||
|
|
@ -22,6 +23,7 @@
|
||||||
#include <ATen/ops/cholesky_native.h>
|
#include <ATen/ops/cholesky_native.h>
|
||||||
#include <ATen/ops/linalg_cholesky_ex_native.h>
|
#include <ATen/ops/linalg_cholesky_ex_native.h>
|
||||||
#include <ATen/ops/linalg_cholesky_native.h>
|
#include <ATen/ops/linalg_cholesky_native.h>
|
||||||
|
#include <ATen/ops/linalg_inv_ex_native.h>
|
||||||
#include <ATen/ops/linalg_lu_factor_ex_native.h>
|
#include <ATen/ops/linalg_lu_factor_ex_native.h>
|
||||||
#include <ATen/ops/linalg_lu_factor_native.h>
|
#include <ATen/ops/linalg_lu_factor_native.h>
|
||||||
#include <ATen/ops/linalg_solve_triangular_native.h>
|
#include <ATen/ops/linalg_solve_triangular_native.h>
|
||||||
|
|
@ -261,14 +263,14 @@ static void linalg_lu_factor_ex_out_mps_impl(const Tensor& A,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void linalg_solve_out_mps_impl(const at::Tensor& A,
|
static void linalg_solve_out_mps_impl(const Tensor& A,
|
||||||
const at::Tensor& B,
|
const Tensor& B,
|
||||||
bool left,
|
bool left,
|
||||||
bool check_errors,
|
bool check_errors,
|
||||||
const at::Tensor& result,
|
const Tensor& result,
|
||||||
const at::Tensor& LU,
|
const Tensor& LU,
|
||||||
const at::Tensor& pivots,
|
const Tensor& pivots,
|
||||||
const at::Tensor& info) {
|
const Tensor& info) {
|
||||||
using namespace mps;
|
using namespace mps;
|
||||||
|
|
||||||
TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()),
|
TORCH_CHECK(!c10::isComplexType(A.scalar_type()) && !c10::isComplexType(LU.scalar_type()),
|
||||||
|
|
@ -436,6 +438,32 @@ static void linalg_solve_out_mps_impl(const at::Tensor& A,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) {
|
||||||
|
using namespace mps;
|
||||||
|
TORCH_CHECK(result.is_mps(), "Output tensor is not MPS");
|
||||||
|
TORCH_CHECK(!A.is_complex(), "linalg_inv: not supported for complex types yet!");
|
||||||
|
using CachedGraph = MPSUnaryCachedGraph;
|
||||||
|
|
||||||
|
MPSStream* stream = getCurrentMPSStream();
|
||||||
|
info.zero_();
|
||||||
|
|
||||||
|
if (A.numel() == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!result.is_contiguous()) {
|
||||||
|
result.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
|
||||||
|
}
|
||||||
|
auto A_sizes = A.sizes();
|
||||||
|
int ndim = A.dim();
|
||||||
|
|
||||||
|
Tensor LU = empty_like(A);
|
||||||
|
Tensor identity = zeros_like(A);
|
||||||
|
Tensor pivots = empty({A_sizes.begin(), A_sizes.end() - 1}, A.options().dtype(kInt));
|
||||||
|
(ndim == 2 ? identity.diagonal() : identity.diagonal(0, -2, -1)).fill_(1);
|
||||||
|
linalg_solve_out_mps_impl(A, identity, true, check_errors, result, LU, pivots, info);
|
||||||
|
}
|
||||||
|
|
||||||
static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) {
|
static Tensor& mm_out_mps_impl(const Tensor& self, const Tensor& other, Tensor& output) {
|
||||||
using namespace mps;
|
using namespace mps;
|
||||||
static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
|
static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
|
||||||
|
|
@ -1427,4 +1455,8 @@ TORCH_IMPL_FUNC(linalg_lu_factor_ex_out_mps)
|
||||||
(const Tensor& A, bool pivot, bool check_errors, const Tensor& LU, const Tensor& pivots, const Tensor& info) {
|
(const Tensor& A, bool pivot, bool check_errors, const Tensor& LU, const Tensor& pivots, const Tensor& info) {
|
||||||
mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, check_errors);
|
mps::linalg_lu_factor_ex_out_mps_impl(A, pivot, LU, pivots, info, check_errors);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TORCH_IMPL_FUNC entry point for linalg_inv_ex on MPS: every argument is
// passed through unchanged to mps::linalg_inv_ex_out_mps_impl, which does the work.
TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)
(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) {
  mps::linalg_inv_ex_out_mps_impl(A, check_errors, result, info);
}
|
||||||
} // namespace at::native
|
} // namespace at::native
|
||||||
|
|
|
||||||
|
|
@ -7815,18 +7815,19 @@ class TestMPS(TestCaseMPS):
|
||||||
|
|
||||||
# Test inverse
|
# Test inverse
|
||||||
def test_inverse(self):
    """Check that torch.linalg.inv on MPS matches the CPU result for random square matrices."""

    def helper(n, atol=1e-5, rtol=1e-6):
        # One random n x n matrix, evaluated on both devices.
        reference_matrix = torch.randn(n, n, device='cpu')
        device_matrix = reference_matrix.to('mps')

        expected = torch.linalg.inv(reference_matrix)
        actual = torch.linalg.inv(device_matrix)

        self.assertEqual(expected, actual, atol=atol, rtol=rtol)

    # (size, tolerance overrides); order kept so the random draws are unchanged.
    cases = (
        (2, {}),
        (6, {}),
        (3, {}),
        (8, {}),
        (1025, {'atol': 1e-4}),
    )
    for size, overrides in cases:
        helper(size, **overrides)
|
||||||
|
|
||||||
# Test tril
|
# Test tril
|
||||||
def test_tril(self):
|
def test_tril(self):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user