diff --git a/aten/src/ATen/native/mps/kernels/LinearAlgebra.h b/aten/src/ATen/native/mps/kernels/LinearAlgebra.h
new file mode 100644
index 00000000000..5cd66b933c1
--- /dev/null
+++ b/aten/src/ATen/native/mps/kernels/LinearAlgebra.h
@@ -0,0 +1,16 @@
+#pragma once
+#include <c10/metal/common.h>
+
+template <unsigned N = c10::metal::max_ndim>
+struct OrgqrParams {
+  int32_t num_batch_dims;
+
+  uint32_t m;
+  uint32_t n;
+  uint32_t k;
+
+  ::c10::metal::array<uint32_t, N> A_strides;
+  ::c10::metal::array<uint32_t, N> tau_strides;
+  ::c10::metal::array<uint32_t, N> H_strides;
+  ::c10::metal::array<uint32_t, N> H_sizes;
+};
diff --git a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal
index 3673bd3cc48..c356dbf9ecb 100644
--- a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal
+++ b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal
@@ -1,3 +1,4 @@
+#include <ATen/native/mps/kernels/LinearAlgebra.h>
 #include
 #include
 #include
@@ -640,6 +641,164 @@ kernel void applyPivots(
   }
 }
 
+template <typename T>
+static T bool_to_float(bool b) {
+  return static_cast<T>(b);
+}
+
+template <>
+half2 bool_to_float<half2>(bool b) {
+  return half2(b ? 1 : 0, 0);
+}
+
+template <>
+float2 bool_to_float<float2>(bool b) {
+  return float2(b ? 1 : 0, 0);
+}
+
+template <typename T>
+static T calc_H_irc(
+    device T* A,
+    uint32_t A_stride_r,
+    uint32_t A_stride_c,
+    constant T* tau,
+    uint32_t tau_stride,
+    uint32_t r,
+    uint32_t c,
+    uint32_t i) {
+  T I_val = bool_to_float<T>(r == c);
+  T tau_val = tau[i * tau_stride];
+
+  T A_ci = c10::metal::conj(A[c * A_stride_r + i * A_stride_c]);
+  T A_ri = A[r * A_stride_r + i * A_stride_c];
+
+  T c_eq_i = bool_to_float<T>(c == i);
+  T r_eq_i = bool_to_float<T>(r == i);
+
+  T A_ci_ = (c > i) ? A_ci : c_eq_i;
+  T A_ri_ = (r > i) ? A_ri : r_eq_i;
+
+  return I_val - c10::metal::mul(tau_val, c10::metal::mul(A_ci_, A_ri_));
+}
+
+// Calculate (A @ B)[r, c], the element in the r-th row and c-th column of the
+// result of matrix multiplying A and B together. A and B must be size m-by-m
+// and have the same strides. The formula for this operation, written in Python
+// syntax, is:
+//   (A @ B)[r, c] = A[r, :].dot(B[:, c])
+template <typename T>
+static T calc_matmul_rc(
+    device T* A,
+    device T* B,
+    uint32_t stride_r,
+    uint32_t stride_c,
+    uint32_t m,
+    uint32_t r,
+    uint32_t c) {
+  T AB_rc = 0;
+  auto A_row_offset = r * stride_r;
+  auto B_col_offset = c * stride_c;
+
+  uint32_t A_col_offset = 0;
+  uint32_t B_row_offset = 0;
+
+  for (uint32_t j = 0; j < m;
+       j++, A_col_offset += stride_c, B_row_offset += stride_r) {
+    AB_rc += c10::metal::mul(
+        A[A_row_offset + A_col_offset], B[B_row_offset + B_col_offset]);
+  }
+  return AB_rc;
+}
+
+template <typename T>
+kernel void orgqr(
+    device T* A [[buffer(0)]],
+    constant T* tau [[buffer(1)]],
+    device T* H [[buffer(2)]],
+    device T* H_prod [[buffer(3)]],
+    constant OrgqrParams<>& params [[buffer(4)]],
+    uint tid [[thread_position_in_grid]]) {
+  constant auto& A_strides = params.A_strides;
+  constant auto& tau_strides = params.tau_strides;
+  constant auto& H_strides = params.H_strides;
+  constant auto& H_sizes = params.H_sizes;
+
+  auto num_batch_dims = params.num_batch_dims;
+  auto m = params.m;
+  auto n = params.n;
+  auto k = params.k;
+
+  auto m2 = m * m;
+  auto batch_idx = tid / m2;
+
+  // Find the matrices for this thread's batch index
+  uint32_t A_offset = 0;
+  uint32_t tau_offset = 0;
+  uint32_t H_offset = 0;
+
+  for (auto dim = num_batch_dims - 1; dim >= 0; dim--) {
+    auto dim_size = H_sizes[dim];
+    auto dim_idx = batch_idx % dim_size;
+
+    A_offset += dim_idx * A_strides[dim];
+    tau_offset += dim_idx * tau_strides[dim];
+    H_offset += dim_idx * H_strides[dim];
+
+    batch_idx /= dim_size;
+  }
+
+  A += A_offset;
+  tau += tau_offset;
+  H += H_offset;
+  H_prod += H_offset;
+
+  auto matrix_idx = tid % m2;
+  auto r = matrix_idx / m;
+  auto c = matrix_idx % m;
+  auto A_stride_r = A_strides[num_batch_dims];
+  auto A_stride_c = A_strides[num_batch_dims + 1];
+  auto tau_stride = tau_strides[num_batch_dims];
+  auto H_stride_r = H_strides[num_batch_dims];
+  auto H_stride_c = H_strides[num_batch_dims + 1];
+
+  // Find the element of H and H_prod that this thread will calculate
+  device T* H_elem_ptr = H + (r * H_stride_r + c * H_stride_c);
+  device T* H_prod_elem_ptr = H_prod + (r * H_stride_r + c * H_stride_c);
+
+  for (uint32_t i = 0; i < k; i++) {
+    // Calculate and write H_i
+
+    T H_irc = calc_H_irc(A, A_stride_r, A_stride_c, tau, tau_stride, r, c, i);
+
+    // Calculate element [r, c] of prod(H_0, ..., H_i)
+    if (i == 0) {
+      *H_prod_elem_ptr = H_irc;
+    } else {
+      *H_elem_ptr = H_irc;
+
+      // Need this sync because the below matmul requires all threads to finish
+      // writing their entries to `H_prod` and `H`.
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+
+      T H_prod_0_to_i_rc =
+          calc_matmul_rc(H_prod, H, H_stride_r, H_stride_c, m, r, c);
+
+      // Need this sync because the above matmul uses the current values in
+      // `H_prod`, and we don't want to overwrite those until all threads are
+      // finished using them.
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+
+      *H_prod_elem_ptr = H_prod_0_to_i_rc;
+    }
+  }
+
+  device T* A_elem_ptr = A + (r * A_stride_r + c * A_stride_c);
+
+  if (c < n) {
+    *A_elem_ptr = *H_prod_elem_ptr;
+  }
+}
+
 #define INSTANTIATE_MM_OPS(DTYPE)                                     \
   template [[host_name("matmul_" #DTYPE)]] kernel void matmul<DTYPE>( \
       constant DTYPE * mat1Data [[buffer(0)]],                        \
@@ -679,3 +838,19 @@ INSTANTIATE_MM_OPS(int);
 INSTANTIATE_MM_OPS(short);
 INSTANTIATE_MM_OPS(char);
 INSTANTIATE_MM_OPS(uchar);
+
+#define REGISTER_ORGQR(T)                              \
+  template [[host_name("orgqr_" #T)]]                  \
+  kernel void orgqr<T>(                                \
+      device T * A [[buffer(0)]],                      \
+      constant T * tau [[buffer(1)]],                  \
+      device T * H [[buffer(2)]],                      \
+      device T * H_prod [[buffer(3)]],                 \
+      constant OrgqrParams<> & params [[buffer(4)]],   \
+      uint tid [[thread_position_in_grid]]);
+
+REGISTER_ORGQR(float);
+REGISTER_ORGQR(half);
+REGISTER_ORGQR(bfloat);
+REGISTER_ORGQR(float2);
+REGISTER_ORGQR(half2);
diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
index 0fdcdedd6e6..aed417ca9ca 100644
--- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
+++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -8,6 +8,9 @@
 #include
 #include
 #include
+#include <ATen/native/mps/kernels/LinearAlgebra.h>
+
+#include <fmt/format.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include
@@ -28,6 +31,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -1235,6 +1239,69 @@ static void cholesky_stub_impl(const Tensor& out, const Tensor& info, bool upper
   }
 }
 
+static Tensor& orgqr_stub_impl(Tensor& self, const Tensor& tau) {
+  if (self.numel() == 0) {
+    return self;
+  }
+
+  auto m = self.size(-2);
+  auto n = self.size(-1);
+  auto k = tau.size(-1);
+
+  if (tau.numel() == 0) {
+    auto I = eye(m, self.scalar_type(), std::nullopt, self.device());
+    return self.copy_(I.slice(-1, 0, n));
+  }
+
+  auto num_batch_dims = self.dim() - 2;
+  auto batch_sizes = self.sizes().slice(0, num_batch_dims);
+
+  std::vector<int64_t> H_sizes(num_batch_dims + 2);
+  for (auto dim : c10::irange(num_batch_dims)) {
+    H_sizes[dim] = self.size(dim);
+  }
+  H_sizes[num_batch_dims] = m;
+  H_sizes[num_batch_dims + 1] = m;
+
+  auto H = at::empty(H_sizes, self.options().memory_format(MemoryFormat::Contiguous));
+  auto H_prod = at::empty_like(H);
+
+  OrgqrParams params;
+
+  params.num_batch_dims = num_batch_dims;
+  params.m = m;
+  params.n = n;
+  params.k = k;
+
+  for (const auto dim : c10::irange(self.dim())) {
+    params.A_strides[dim] = self.stride(dim);
+
+    if (dim < tau.dim()) {
+      params.tau_strides[dim] = tau.stride(dim);
+    }
+
+    params.H_strides[dim] = H.stride(dim);
+    params.H_sizes[dim] = H.size(dim);
+  }
+
+  auto num_threads = H.numel();
+  MPSStream* stream = getCurrentMPSStream();
+
+  dispatch_sync_with_rethrow(stream->queue(), ^() {
+    @autoreleasepool {
+      id<MTLComputeCommandEncoder> compute_encoder = stream->commandEncoder();
+      auto pipeline_state = lib.getPipelineStateForFunc(fmt::format("orgqr_{}", scalarToMetalTypeString(self)));
+      getMPSProfiler().beginProfileKernel(pipeline_state, "orgqr", {self, tau});
+      [compute_encoder setComputePipelineState:pipeline_state];
+      mtl_setArgs(compute_encoder, self, tau, H, H_prod, params);
+      mtl_dispatch1DJob(compute_encoder, pipeline_state, num_threads);
+      getMPSProfiler().endProfileKernel(pipeline_state);
+    }
+  });
+
+  return self;
+}
+
 } // namespace mps
 
 Tensor addr_mps(const Tensor& self, const Tensor& vec1, const Tensor& vec2, const Scalar& beta, const Scalar& alpha) {
@@ -1471,4 +1538,6 @@ TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const
 }
 
 REGISTER_DISPATCH(cholesky_stub, mps::cholesky_stub_impl)
+REGISTER_DISPATCH(orgqr_stub, mps::orgqr_stub_impl);
+
 } // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 07191ec7cb8..410a0fcd4d7 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -14362,12 +14362,12 @@
   python_module: linalg
   variants: function
   dispatch:
-    CPU, CUDA: linalg_householder_product
+    CPU, CUDA, MPS: linalg_householder_product
 
 - func: linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!)
   python_module: linalg
   dispatch:
-    CPU, CUDA: linalg_householder_product_out
+    CPU, CUDA, MPS: linalg_householder_product_out
 
 - func: linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
   python_module: linalg
diff --git a/c10/metal/utils.h b/c10/metal/utils.h
index 14c4b2b2cba..0c6a9724d22 100644
--- a/c10/metal/utils.h
+++ b/c10/metal/utils.h
@@ -328,6 +328,21 @@ struct pair {
   T2 second;
 };
 
+template <typename T>
+static T conj(T a) {
+  return a;
+}
+
+template <>
+half2 conj(half2 a) {
+  return half2(a.x, -a.y);
+}
+
+template <>
+float2 conj(float2 a) {
+  return float2(a.x, -a.y);
+}
+
 #define INSTANTIATE_FOR_ALL_TYPES(MACRO) \
   MACRO(float);                          \
   MACRO(half);                           \
diff --git a/torch/testing/_internal/common_mps.py b/torch/testing/_internal/common_mps.py
index 4628b8aa3e2..ce64f1d9cdd 100644
--- a/torch/testing/_internal/common_mps.py
+++ b/torch/testing/_internal/common_mps.py
@@ -86,6 +86,7 @@ if torch.backends.mps.is_available():
         "item",
         "kron",
         "linalg.diagonal",
+        "linalg.householder_product",
         "linalg.svd",
         "log10",
         "log1p",
@@ -322,7 +323,6 @@ if torch.backends.mps.is_available():
         "linalg.cond": None,
         "linalg.eigh": None,
         "linalg.eigvalsh": None,
-        "linalg.householder_product": None,
         "linalg.ldl_factor": None,
         "linalg.ldl_factor_ex": None,
         "linalg.ldl_solve": None,
@@ -684,6 +684,7 @@ if torch.backends.mps.is_available():
         "_upsample_bilinear2d_aa": None,  # `_upsample_bilinear2d_aa_backward_out` not implemented for MPS
         "_upsample_bicubic2d_aa": None,  # `_upsample_bilinear2d_aa_backward_out` not implemented for MPS
         "sparse.mmreduce": [torch.float32],  # csr not supported
+        "linalg.householder_product": None,
        "unique_consecutive": [torch.float16, torch.float32],
         "scalar_tensor": [torch.float16, torch.float32],
         "cdist": [torch.float32],
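
For reference (not part of the patch): the orgqr kernel above forms each Householder reflector H_i = I - tau[i] * v_i * v_i^H from the reflectors packed below the diagonal of A (implicit 1 on the diagonal, zeros above), accumulates the product H_0 @ H_1 @ ... @ H_{k-1}, and writes the first n columns back into A, which is what torch.linalg.householder_product returns. A minimal Python sketch of the same math, using the hypothetical helper name householder_product_reference, can be checked against the existing CPU implementation (the patch enables the same op on MPS devices):

import torch

def householder_product_reference(A, tau):
    # Rebuild Q = H_0 @ H_1 @ ... @ H_{k-1} from the packed reflectors in A and
    # the scalars in tau, mirroring what the Metal kernel computes per matrix
    # in the batch (single-matrix case only, for clarity).
    m, n = A.shape[-2], A.shape[-1]
    k = tau.shape[-1]
    Q = torch.eye(m, dtype=A.dtype, device=A.device)
    for i in range(k):
        v = A[:, i].clone()
        v[:i] = 0  # entries above row i are treated as zero
        v[i] = 1   # the diagonal entry of the reflector is an implicit 1
        H_i = torch.eye(m, dtype=A.dtype, device=A.device) - tau[i] * torch.outer(v, v.conj())
        Q = Q @ H_i
    return Q[:, :n]

# Sanity check against the built-in op on CPU in double precision.
A = torch.randn(5, 3, dtype=torch.float64)
A, tau = torch.geqrf(A)
print(torch.allclose(householder_product_reference(A, tau),
                     torch.linalg.householder_product(A, tau)))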