[MPS] Sparse mul enable tests and fix on MPS (#166164)

Apparently mul tests in test_sparse were disabled. The dense representation i.e. when nnz is not a scalar was broken on MPS. This PR fixes it and enables the tests in test_sparse.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166164
Approved by: https://github.com/malfet
This commit is contained in:
Isalia20 2025-10-24 18:30:26 +00:00 committed by PyTorch MergeBot
parent 0db6bcc015
commit fa6d911dda
3 changed files with 60 additions and 46 deletions

View File

@ -33,7 +33,7 @@ using namespace mps;
#ifndef PYTORCH_JIT_COMPILE_SHADERS
static auto& lib = MetalShaderLibrary::getBundledLibrary();
#else
#include <ATen/native/mps/Mul_metallib.h>
#include <ATen/native/mps/SparseTensorMath_metallib.h>
#endif
static Tensor& s_addmm_out_sparse_dense_mps(
@ -369,12 +369,7 @@ static SparseTensor& mul_out_dense_sparse_mps(
}
if (scalar_like) {
auto scalar = dense;
if (dense.numel() == 1 && dense.dim() > 0) {
scalar = dense.view({});
}
scalar = scalar.to(values.options());
auto out_vals = values.mul(scalar);
auto out_vals = values.mul(dense.to(values.options()));
if (out.scalar_type() != commonDtype) {
out_vals = out_vals.to(out.scalar_type());
}
@ -508,14 +503,14 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
const auto device = r_.device();
auto stream = getCurrentMPSStream();
auto lhs_indices = lhs._indices();
auto rhs_indices = rhs._indices();
auto lhs_values = lhs._values().to(commonDtype);
auto rhs_values = rhs._values().to(commonDtype);
auto lhs_indices = lhs._indices().contiguous();
auto rhs_indices = rhs._indices().contiguous();
auto lhs_values = lhs._values().to(commonDtype).contiguous();
auto rhs_values = rhs._values().to(commonDtype).contiguous();
// Flatten sparse indices to keys
auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes());
auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes());
auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes().slice(0, ndim_i));
auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes().slice(0, ndim_i));
// Intersect sorted keys (search the shorter in the longer)
const bool A_is_lhs = (lhs_nnz <= rhs_nnz);
@ -546,35 +541,54 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
auto out_indices = at::empty({ndim_i, static_cast<int64_t>(M)}, at::device(device).dtype(at::kLong));
auto lhs_match = outA_idx.narrow(0, 0, M);
auto rhs_match = outB_idx.narrow(0, 0, M);
auto out_val_sizes = lhs_values.sizes().vec();
out_val_sizes[0] = static_cast<int64_t>(M);
auto dense_sizes_vec = lhs.sizes().slice(ndim_i).vec();
int64_t cols64 = 1;
for (auto s : dense_sizes_vec) cols64 *= s;
const uint32_t cols = static_cast<uint32_t>(std::max<int64_t>(cols64, 1));
auto to2d = [&](Tensor t, int64_t nnz) -> Tensor {
const int64_t t_cols = t.numel() / nnz;
if (t_cols == cols64) {
return t.view({nnz, cols64});
}
return t.view({nnz, 1}).expand({nnz, cols64}).contiguous();
};
// make both sides 2d [nnz, cols] buffers so the kernel can index it
auto lhs_vals2d = to2d(lhs_values, lhs_nnz);
auto rhs_vals2d = to2d(rhs_values, rhs_nnz);
std::vector<int64_t> out_val_sizes;
out_val_sizes.reserve(1 + dense_sizes_vec.size());
out_val_sizes.push_back(static_cast<int64_t>(M));
out_val_sizes.insert(out_val_sizes.end(), dense_sizes_vec.begin(), dense_sizes_vec.end());
auto out_values = at::empty(out_val_sizes, lhs_values.options());
const uint32_t cols = static_cast<uint32_t>(
lhs_values.numel() / std::max<int64_t>(1, lhs_nnz));
if (M > 0) {
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc(
"fused_gather_mul_kernel_" + mps::scalarToMetalTypeString(lhs_values));
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto pso = lib.getPipelineStateForFunc(
"fused_gather_mul_kernel_" + mps::scalarToMetalTypeString(lhs_values));
auto enc = stream->commandEncoder();
[enc setComputePipelineState:pso];
const uint32_t tew = pso.threadExecutionWidth;
const uint32_t gridW = std::max<uint32_t>(cols, 1u);
const uint32_t tgW = std::min(gridW, tew);
MTLSize grid = MTLSizeMake(gridW, 1, M);
MTLSize tgs = MTLSizeMake(tgW, 1, 1);
const uint32_t tew = pso.threadExecutionWidth;
uint32_t tgW = std::min(cols, tew);
MTLSize grid = MTLSizeMake(cols, 1, M);
MTLSize tgs = MTLSizeMake(tgW, 1, 1);
mtl_setArgs(enc,
lhs_values, rhs_values,
lhs_match, rhs_match,
lhs_indices, out_indices,
out_values,
std::array<uint32_t, 2>{static_cast<uint32_t>(ndim_i), static_cast<uint32_t>(lhs_nnz)},
std::array<uint32_t, 2>{M, cols});
[enc dispatchThreads:grid threadsPerThreadgroup:tgs];
}
});
mtl_setArgs(enc,
lhs_vals2d, rhs_vals2d,
lhs_match, rhs_match,
lhs_indices, out_indices,
out_values,
std::array<uint32_t, 2>{static_cast<uint32_t>(ndim_i), static_cast<uint32_t>(lhs_nnz)},
std::array<uint32_t, 2>{M, cols});
[enc dispatchThreads:grid threadsPerThreadgroup:tgs];
}
});
}
if (r_.scalar_type() != commonDtype) {
out_values = out_values.to(r_.scalar_type());

View File

@ -195,9 +195,9 @@ kernel void fused_gather_mul_kernel(
const ulong offR = (ulong)iR * (ulong)view_cols + (ulong)col;
const ulong offO = (ulong)k * (ulong)view_cols + (ulong)col;
const float a = (float)lhs_vals[offL];
const float b = (float)rhs_vals[offR];
out_vals[offO] = (T)(a * b);
const auto a = static_cast<accum_t<T>>(lhs_vals[offL]);
const auto b = static_cast<accum_t<T>>(rhs_vals[offR]);
out_vals[offO] = static_cast<T>(mul(a, b));
}
// One thread per match copies the indices column

View File

@ -1712,14 +1712,14 @@ class TestSparse(TestSparseBase):
a = self._gen_sparse(sparse_dims, nnz, with_shape, dtype, device, coalesced)[0].requires_grad_(True)
b = self._gen_sparse(sparse_dims, nnz, with_shape, dtype, device, coalesced)[0].requires_grad_(True)
self.assertEqual((a * b).to_dense(), a.to_dense() * b.to_dense(), masked=True)
self.assertEqual((a * b).to_dense(), a.to_dense() * b.to_dense())
gradcheck(lambda x, y: (x * y).to_dense(), [a, b], eps=1e-4)
# Issues with 0-dim indices/values
gradcheck(lambda x, y: torch.sparse.sum(x * y).to_dense(), [a, b], masked=True, eps=1e-4)
gradcheck(lambda x, y: torch.sparse.sum(x * y).to_dense(), [a, b], masked=True, eps=3e-4, atol=5e-5)
# TODO: Re-enable these
# test_shape(2, 3, [2, 3, 4, 5])
# test_shape(2, 3, [2, 2, 0])
test_shape(2, 3, [2, 3, 4, 5])
test_shape(2, 3, [2, 2, 0])
test_shape(2, 3, [4, 5])
@coalescedonoff
@dtypes(torch.double)