[MPS] Add linalg.householder_product for MPS (#166090)
Fixes #166089
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166090
Approved by: https://github.com/malfet
This commit is contained in:
parent 6038e476e8
commit c9b49e506e
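
As a usage note: with this change, torch.linalg.householder_product should accept MPS tensors just like CPU/CUDA ones. A minimal sketch, assuming an MPS-enabled build (the factorization is done on CPU only to keep the example self-contained):

    import torch

    # Householder reflectors and scaling factors from a QR factorization.
    A = torch.randn(5, 3)
    reflectors, tau = torch.geqrf(A)

    # Rebuild the orthogonal factor Q on the MPS device using the new kernel.
    Q = torch.linalg.householder_product(reflectors.to("mps"), tau.to("mps"))

    # Q has orthonormal columns, so Q^T Q is approximately the identity.
    print(torch.allclose(Q.T @ Q, torch.eye(3, device="mps"), atol=1e-5))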

aten/src/ATen/native/mps/kernels/LinearAlgebra.h (new file, 16 lines)
@@ -0,0 +1,16 @@
#pragma once

#include <c10/metal/common.h>

template <unsigned N = c10::metal::max_ndim>
struct OrgqrParams {
  int32_t num_batch_dims;

  uint32_t m;
  uint32_t n;
  uint32_t k;

  ::c10::metal::array<uint32_t, N> A_strides;
  ::c10::metal::array<uint32_t, N> tau_strides;
  ::c10::metal::array<uint32_t, N> H_strides;
  ::c10::metal::array<uint32_t, N> H_sizes;
};

@@ -1,3 +1,4 @@
#include <ATen/native/mps/kernels/LinearAlgebra.h>
#include <c10/metal/utils.h>
#include <metal_array>
#include <metal_simdgroup>

@@ -640,6 +641,164 @@ kernel void applyPivots(
  }
}

template <typename T>
static T bool_to_float(bool b) {
  return static_cast<T>(b);
}

template <>
half2 bool_to_float(bool b) {
  return half2(b ? 1 : 0, 0);
}

template <>
float2 bool_to_float(bool b) {
  return float2(b ? 1 : 0, 0);
}
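
// Compute element [r, c] of the Householder reflector
//   H_i = I - tau_i * v_i * v_i^H,
// where the reflector vector v_i is stored implicitly in column i of A:
// v_i[j] = 0 for j < i, v_i[i] = 1, and v_i[j] = A[j, i] for j > i. The
// ternaries below reconstruct v_i[r] and conj(v_i[c]) on the fly instead of
// materializing v_i.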
template <typename T>
static T calc_H_irc(
    device T* A,
    uint32_t A_stride_r,
    uint32_t A_stride_c,
    constant T* tau,
    uint32_t tau_stride,
    uint32_t r,
    uint32_t c,
    uint32_t i) {
  T I_val = bool_to_float<T>(r == c);
  T tau_val = tau[i * tau_stride];

  T A_ci = c10::metal::conj(A[c * A_stride_r + i * A_stride_c]);
  T A_ri = A[r * A_stride_r + i * A_stride_c];

  T c_eq_i = bool_to_float<T>(c == i);
  T r_eq_i = bool_to_float<T>(r == i);

  T A_ci_ = (c > i) ? A_ci : c_eq_i;
  T A_ri_ = (r > i) ? A_ri : r_eq_i;

  return I_val - c10::metal::mul(tau_val, c10::metal::mul(A_ci_, A_ri_));
}

// Calculate (A @ B)[r, c], the element in the r-th row and c-th column of the
// result of matrix multiplying A and B together. A and B must be size m-by-m
// and have the same strides. The formula for this operation, written in Python
// syntax, is:
//   (A @ B)[r, c] = A[r, :].dot(B[:, c])
template <typename T>
static T calc_matmul_rc(
    device T* A,
    device T* B,
    uint32_t stride_r,
    uint32_t stride_c,
    uint32_t m,
    uint32_t r,
    uint32_t c) {
  T AB_rc = 0;
  auto A_row_offset = r * stride_r;
  auto B_col_offset = c * stride_c;

  uint32_t A_col_offset = 0;
  uint32_t B_row_offset = 0;

  for (uint32_t j = 0; j < m;
       j++, A_col_offset += stride_c, B_row_offset += stride_r) {
    AB_rc += c10::metal::mul(
        A[A_row_offset + A_col_offset], B[B_row_offset + B_col_offset]);
  }
  return AB_rc;
}
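
// Each thread computes one element [r, c] of the running product
// H_0 @ H_1 @ ... @ H_{k-1} for one matrix in the batch (grid size =
// batch_size * m * m). On every iteration i the thread writes its element of
// H_i, the product so far is multiplied by H_i, and the barriers keep the
// element writes and the matmul reads in step. When the loop finishes, the
// first n columns of the accumulated product are copied back into A.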
template <typename T>
kernel void orgqr(
    device T* A [[buffer(0)]],
    constant T* tau [[buffer(1)]],
    device T* H [[buffer(2)]],
    device T* H_prod [[buffer(3)]],
    constant OrgqrParams<>& params [[buffer(4)]],
    uint tid [[thread_position_in_grid]]) {
  constant auto& A_strides = params.A_strides;
  constant auto& tau_strides = params.tau_strides;
  constant auto& H_strides = params.H_strides;
  constant auto& H_sizes = params.H_sizes;

  auto num_batch_dims = params.num_batch_dims;
  auto m = params.m;
  auto n = params.n;
  auto k = params.k;

  auto m2 = m * m;
  auto batch_idx = tid / m2;

  // Find the matrices for this thread's batch index
  uint32_t A_offset = 0;
  uint32_t tau_offset = 0;
  uint32_t H_offset = 0;

  for (auto dim = num_batch_dims - 1; dim >= 0; dim--) {
    auto dim_size = H_sizes[dim];
    auto dim_idx = batch_idx % dim_size;

    A_offset += dim_idx * A_strides[dim];
    tau_offset += dim_idx * tau_strides[dim];
    H_offset += dim_idx * H_strides[dim];

    batch_idx /= dim_size;
  }

  A += A_offset;
  tau += tau_offset;
  H += H_offset;
  H_prod += H_offset;

  auto matrix_idx = tid % m2;
  auto r = matrix_idx / m;
  auto c = matrix_idx % m;
  auto A_stride_r = A_strides[num_batch_dims];
  auto A_stride_c = A_strides[num_batch_dims + 1];
  auto tau_stride = tau_strides[num_batch_dims];
  auto H_stride_r = H_strides[num_batch_dims];
  auto H_stride_c = H_strides[num_batch_dims + 1];

  // Find the element of H and H_prod that this thread will calculate
  device T* H_elem_ptr = H + (r * H_stride_r + c * H_stride_c);
  device T* H_prod_elem_ptr = H_prod + (r * H_stride_r + c * H_stride_c);

  for (uint32_t i = 0; i < k; i++) {
    // Calculate and write H_i
    T H_irc = calc_H_irc(A, A_stride_r, A_stride_c, tau, tau_stride, r, c, i);

    // Calculate element [r, c] of prod(H_0, ..., H_i)
    if (i == 0) {
      *H_prod_elem_ptr = H_irc;
    } else {
      *H_elem_ptr = H_irc;

      // Need this sync because the below matmul requires all threads to finish
      // writing their entries to `H_prod` and `H`.
      threadgroup_barrier(mem_flags::mem_threadgroup);

      T H_prod_0_to_i_rc =
          calc_matmul_rc(H_prod, H, H_stride_r, H_stride_c, m, r, c);

      // Need this sync because the above matmul uses the current values in
      // `H_prod`, and we don't want to overwrite those until all threads are
      // finished using them.
      threadgroup_barrier(mem_flags::mem_threadgroup);

      *H_prod_elem_ptr = H_prod_0_to_i_rc;
    }
  }

  device T* A_elem_ptr = A + (r * A_stride_r + c * A_stride_c);

  if (c < n) {
    *A_elem_ptr = *H_prod_elem_ptr;
  }
}

#define INSTANTIATE_MM_OPS(DTYPE)                                     \
  template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>( \
      constant DTYPE * mat1Data [[buffer(0)]],                        \

@@ -679,3 +838,19 @@ INSTANTIATE_MM_OPS(int);
INSTANTIATE_MM_OPS(short);
INSTANTIATE_MM_OPS(char);
INSTANTIATE_MM_OPS(uchar);
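
// Instantiate orgqr for the dtypes supported on MPS; float2/half2 carry the
// complex cases, with (real, imag) packed into the two vector lanes.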
#define REGISTER_ORGQR(T)                            \
  template [[host_name("orgqr_" #T)]]                \
  kernel void orgqr<T>(                              \
      device T * A [[buffer(0)]],                    \
      constant T * tau [[buffer(1)]],                \
      device T * H [[buffer(2)]],                    \
      device T * H_prod [[buffer(3)]],               \
      constant OrgqrParams<> & params [[buffer(4)]], \
      uint tid [[thread_position_in_grid]]);

REGISTER_ORGQR(float);
REGISTER_ORGQR(half);
REGISTER_ORGQR(bfloat);
REGISTER_ORGQR(float2);
REGISTER_ORGQR(half2);

@@ -8,6 +8,9 @@
#include <ATen/native/Resize.h>
#include <ATen/native/mps/MPSGraphSequoiaOps.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/kernels/LinearAlgebra.h>

#include <fmt/format.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>

@@ -28,6 +31,7 @@
#include <ATen/ops/linalg_solve_triangular_native.h>
#include <ATen/ops/lu_unpack_native.h>
#include <ATen/ops/mm_native.h>
#include <ATen/ops/orgqr_native.h>
#include <ATen/ops/slice.h>
#include <ATen/ops/stack.h>
#include <ATen/ops/triangular_solve_native.h>

@@ -1235,6 +1239,69 @@ static void cholesky_stub_impl(const Tensor& out, const Tensor& info, bool upper
  }
}
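
// Host-side driver for the orgqr kernel: it allocates m-by-m workspaces H and
// H_prod for every batch entry, packs sizes and strides into OrgqrParams, and
// dispatches one thread per element of H. The kernel writes the first n
// columns of the accumulated Householder product back into `self`.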
static Tensor& orgqr_stub_impl(Tensor& self, const Tensor& tau) {
  if (self.numel() == 0) {
    return self;
  }

  auto m = self.size(-2);
  auto n = self.size(-1);
  auto k = tau.size(-1);

  if (tau.numel() == 0) {
    auto I = eye(m, self.scalar_type(), std::nullopt, self.device());
    return self.copy_(I.slice(-1, 0, n));
  }

  auto num_batch_dims = self.dim() - 2;
  auto batch_sizes = self.sizes().slice(0, num_batch_dims);

  std::vector<int64_t> H_sizes(num_batch_dims + 2);
  for (auto dim : c10::irange(num_batch_dims)) {
    H_sizes[dim] = self.size(dim);
  }
  H_sizes[num_batch_dims] = m;
  H_sizes[num_batch_dims + 1] = m;

  auto H = at::empty(H_sizes, self.options().memory_format(MemoryFormat::Contiguous));
  auto H_prod = at::empty_like(H);

  OrgqrParams params;

  params.num_batch_dims = num_batch_dims;
  params.m = m;
  params.n = n;
  params.k = k;

  for (const auto dim : c10::irange(self.dim())) {
    params.A_strides[dim] = self.stride(dim);

    if (dim < tau.dim()) {
      params.tau_strides[dim] = tau.stride(dim);
    }

    params.H_strides[dim] = H.stride(dim);
    params.H_sizes[dim] = H.size(dim);
  }

  auto num_threads = H.numel();
  MPSStream* stream = getCurrentMPSStream();

  dispatch_sync_with_rethrow(stream->queue(), ^() {
    @autoreleasepool {
      id<MTLComputeCommandEncoder> compute_encoder = stream->commandEncoder();
      auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("orgqr_{}", scalarToMetalTypeString(self)));
      getMPSProfiler().beginProfileKernel(pipeline_state, "orgqr", {self, tau});
      [compute_encoder setComputePipelineState:pipeline_state];
      mtl_setArgs(compute_encoder, self, tau, H, H_prod, params);
      mtl_dispatch1DJob(compute_encoder, pipeline_state, num_threads);
      getMPSProfiler().endProfileKernel(pipeline_state);
    }
  });

  return self;
}

} // namespace mps

Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) {

@@ -1471,4 +1538,6 @@ TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const
}

REGISTER_DISPATCH(cholesky_stub, mps::cholesky_stub_impl)
REGISTER_DISPATCH(orgqr_stub, mps::orgqr_stub_impl);

} // namespace at::native

@@ -14362,12 +14362,12 @@
  python_module: linalg
  variants: function
  dispatch:
    CPU, CUDA: linalg_householder_product
    CPU, CUDA, MPS: linalg_householder_product

- func: linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg
  dispatch:
    CPU, CUDA: linalg_householder_product_out
    CPU, CUDA, MPS: linalg_householder_product_out

- func: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
  python_module: linalg

@@ -328,6 +328,21 @@ struct pair {
  T2 second;
};
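
// `conj` is the complex conjugate used by the orgqr kernel: for real scalar
// types it is the identity, while half2/float2 store a complex value as
// (real, imag) and so negate the imaginary lane.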
template <typename T>
static T conj(T a) {
  return a;
}

template <>
half2 conj(half2 a) {
  return half2(a.x, -a.y);
}

template <>
float2 conj(float2 a) {
  return float2(a.x, -a.y);
}

#define INSTANTIATE_FOR_ALL_TYPES(MACRO) \
  MACRO(float);                          \
  MACRO(half);                           \

@@ -86,6 +86,7 @@ if torch.backends.mps.is_available():
"item",
"kron",
"linalg.diagonal",
"linalg.householder_product",
"linalg.svd",
"log10",
"log1p",

@@ -322,7 +323,6 @@ if torch.backends.mps.is_available():
"linalg.cond": None,
"linalg.eigh": None,
"linalg.eigvalsh": None,
"linalg.householder_product": None,
"linalg.ldl_factor": None,
"linalg.ldl_factor_ex": None,
"linalg.ldl_solve": None,

@@ -684,6 +684,7 @@ if torch.backends.mps.is_available():
"_upsample_bilinear2d_aa": None, # `_upsample_bilinear2d_aa_backward_out` not implemented for MPS
"_upsample_bicubic2d_aa": None, # `_upsample_bilinear2d_aa_backward_out` not implemented for MPS
"sparse.mmreduce": [torch.float32], # csr not supported
"linalg.householder_product": None,
"unique_consecutive": [torch.float16, torch.float32],
"scalar_tensor": [torch.float16, torch.float32],
"cdist": [torch.float32],