Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[MPS] Sparse mul enable tests and fix on MPS (#166164)
Apparently mul tests in test_sparse were disabled. The dense representation, i.e. when the nnz values are not scalars, was broken on MPS. This PR fixes it and enables the tests in test_sparse.py.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166164
Approved by: https://github.com/malfet
This commit is contained in:
parent 0db6bcc015
commit fa6d911dda
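
For context, the failing case is elementwise mul of two sparse COO tensors whose values carry a dense block per index (hybrid layout, dense_dim > 0), so each nnz entry is a vector or matrix rather than a single scalar. A minimal Python sketch of that input shape, assuming an MPS-enabled build (illustrative only, not taken from the PR's test suite):

import torch

# Sketch of the "dense representation" case: each of the nnz entries carries
# a length-4 dense block instead of a single scalar value.
device = "mps" if torch.backends.mps.is_available() else "cpu"

i = torch.tensor([[0, 1, 1], [2, 0, 2]])              # 2 sparse dims, nnz = 3
v = torch.randn(3, 4)                                  # one dense block per entry
a = torch.sparse_coo_tensor(i, v, (2, 3, 4), device=device).coalesce()
b = torch.sparse_coo_tensor(i, 2 * v, (2, 3, 4), device=device).coalesce()

out = a * b                                            # sparse * sparse elementwise mul
assert torch.allclose(out.to_dense(), a.to_dense() * b.to_dense())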

@@ -33,7 +33,7 @@ using namespace mps;
 #ifndef PYTORCH_JIT_COMPILE_SHADERS
 static auto& lib = MetalShaderLibrary::getBundledLibrary();
 #else
-#include <ATen/native/mps/Mul_metallib.h>
+#include <ATen/native/mps/SparseTensorMath_metallib.h>
 #endif
 
 static Tensor& s_addmm_out_sparse_dense_mps(

@@ -369,12 +369,7 @@ static SparseTensor& mul_out_dense_sparse_mps(
   }
 
   if (scalar_like) {
-    auto scalar = dense;
-    if (dense.numel() == 1 && dense.dim() > 0) {
-      scalar = dense.view({});
-    }
-    scalar = scalar.to(values.options());
-    auto out_vals = values.mul(scalar);
+    auto out_vals = values.mul(dense.to(values.options()));
     if (out.scalar_type() != commonDtype) {
       out_vals = out_vals.to(out.scalar_type());
     }
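
The simplification above relies on ordinary broadcasting: multiplying the [nnz, ...] values tensor by a one-element dense tensor gives the same result whether that tensor is first viewed as 0-dim or kept with its size-1 dims. A standalone Python sketch (illustrative only, not code from the PR):

import torch

# values stands in for sparse._values() with nnz = 3 and a dense block of 4.
values = torch.randn(3, 4)
scalar_0d = torch.tensor(2.0)        # what the removed dense.view({}) produced
scalar_nd = torch.full((1, 1), 2.0)  # a "scalar-like" dense tensor with dim() > 0

# Broadcasting makes both forms equivalent, so the explicit view({}) is unnecessary.
assert torch.equal(values.mul(scalar_0d), values.mul(scalar_nd))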

@@ -508,14 +503,14 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
   const auto device = r_.device();
   auto stream = getCurrentMPSStream();
 
-  auto lhs_indices = lhs._indices();
-  auto rhs_indices = rhs._indices();
-  auto lhs_values = lhs._values().to(commonDtype);
-  auto rhs_values = rhs._values().to(commonDtype);
+  auto lhs_indices = lhs._indices().contiguous();
+  auto rhs_indices = rhs._indices().contiguous();
+  auto lhs_values = lhs._values().to(commonDtype).contiguous();
+  auto rhs_values = rhs._values().to(commonDtype).contiguous();
 
   // Flatten sparse indices to keys
-  auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes());
-  auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes());
+  auto lhs_keys = flatten_indices(lhs_indices, lhs.sizes().slice(0, ndim_i));
+  auto rhs_keys = flatten_indices(rhs_indices, rhs.sizes().slice(0, ndim_i));
 
   // Intersect sorted keys (search the shorter in the longer)
   const bool A_is_lhs = (lhs_nnz <= rhs_nnz);
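
flatten_indices collapses each nnz index column into a single integer key so that the two sorted key lists can be intersected; the change passes only the first ndim_i (sparse) sizes, since trailing dense dims never take part in the index arithmetic. A hypothetical Python equivalent (flatten_keys is an illustrative helper, not the actual ATen function):

import torch

def flatten_keys(indices: torch.Tensor, sparse_sizes) -> torch.Tensor:
    # Row-major linearization over the sparse dims only.
    keys = torch.zeros(indices.size(1), dtype=torch.long)
    for dim, size in enumerate(sparse_sizes):
        keys = keys * size + indices[dim]
    return keys

idx = torch.tensor([[0, 1, 1],
                    [2, 0, 2]])        # ndim_i = 2, nnz = 3
print(flatten_keys(idx, (2, 3)))       # tensor([2, 3, 5])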

@@ -546,35 +541,54 @@ SparseTensor& mul_out_sparse_mps(const Tensor& t_, const Tensor& src_, SparseTen
   auto out_indices = at::empty({ndim_i, static_cast<int64_t>(M)}, at::device(device).dtype(at::kLong));
   auto lhs_match = outA_idx.narrow(0, 0, M);
   auto rhs_match = outB_idx.narrow(0, 0, M);
-  auto out_val_sizes = lhs_values.sizes().vec();
-  out_val_sizes[0] = static_cast<int64_t>(M);
+  auto dense_sizes_vec = lhs.sizes().slice(ndim_i).vec();
+  int64_t cols64 = 1;
+  for (auto s : dense_sizes_vec) cols64 *= s;
+  const uint32_t cols = static_cast<uint32_t>(std::max<int64_t>(cols64, 1));
+
+  auto to2d = [&](Tensor t, int64_t nnz) -> Tensor {
+    const int64_t t_cols = t.numel() / nnz;
+    if (t_cols == cols64) {
+      return t.view({nnz, cols64});
+    }
+    return t.view({nnz, 1}).expand({nnz, cols64}).contiguous();
+  };
+
+  // make both sides 2d [nnz, cols] buffers so the kernel can index it
+  auto lhs_vals2d = to2d(lhs_values, lhs_nnz);
+  auto rhs_vals2d = to2d(rhs_values, rhs_nnz);
+
+  std::vector<int64_t> out_val_sizes;
+  out_val_sizes.reserve(1 + dense_sizes_vec.size());
+  out_val_sizes.push_back(static_cast<int64_t>(M));
+  out_val_sizes.insert(out_val_sizes.end(), dense_sizes_vec.begin(), dense_sizes_vec.end());
   auto out_values = at::empty(out_val_sizes, lhs_values.options());
 
-  const uint32_t cols = static_cast<uint32_t>(
-      lhs_values.numel() / std::max<int64_t>(1, lhs_nnz));
   if (M > 0) {
     dispatch_sync_with_rethrow(stream->queue(), ^() {
       @autoreleasepool {
         auto pso = lib.getPipelineStateForFunc(
             "fused_gather_mul_kernel_" + mps::scalarToMetalTypeString(lhs_values));
         auto enc = stream->commandEncoder();
         [enc setComputePipelineState:pso];
+
         const uint32_t tew = pso.threadExecutionWidth;
-        uint32_t tgW = std::min(cols, tew);
-        MTLSize grid = MTLSizeMake(cols, 1, M);
+        const uint32_t gridW = std::max<uint32_t>(cols, 1u);
+        const uint32_t tgW = std::min(gridW, tew);
+        MTLSize grid = MTLSizeMake(gridW, 1, M);
         MTLSize tgs = MTLSizeMake(tgW, 1, 1);
 
         mtl_setArgs(enc,
-                    lhs_values, rhs_values,
+                    lhs_vals2d, rhs_vals2d,
                     lhs_match, rhs_match,
                     lhs_indices, out_indices,
                     out_values,
                     std::array<uint32_t, 2>{static_cast<uint32_t>(ndim_i), static_cast<uint32_t>(lhs_nnz)},
                     std::array<uint32_t, 2>{M, cols});
         [enc dispatchThreads:grid threadsPerThreadgroup:tgs];
       }
     });
   }
 
   if (r_.scalar_type() != commonDtype) {
     out_values = out_values.to(r_.scalar_type());
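
The to2d helper above hands each operand's values to the Metal kernel as a flat [nnz, cols] buffer, where cols is the product of the dense dims; values whose flattened column count does not already match cols are expanded column-wise. A rough Python analogue (illustrative only, not the ATen code):

import torch

def to2d(values: torch.Tensor, nnz: int, cols: int) -> torch.Tensor:
    t_cols = values.numel() // nnz
    if t_cols == cols:
        return values.reshape(nnz, cols)                              # already [nnz, cols] worth of data
    return values.reshape(nnz, 1).expand(nnz, cols).contiguous()      # broadcast a single column

lhs_vals = torch.randn(3, 4, 5)     # nnz = 3, dense block 4x5, so cols = 20
rhs_vals = torch.randn(3)           # one value per entry
print(to2d(lhs_vals, 3, 20).shape)  # torch.Size([3, 20])
print(to2d(rhs_vals, 3, 20).shape)  # torch.Size([3, 20])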

@@ -195,9 +195,9 @@ kernel void fused_gather_mul_kernel(
   const ulong offR = (ulong)iR * (ulong)view_cols + (ulong)col;
   const ulong offO = (ulong)k * (ulong)view_cols + (ulong)col;
 
-  const float a = (float)lhs_vals[offL];
-  const float b = (float)rhs_vals[offR];
-  out_vals[offO] = (T)(a * b);
+  const auto a = static_cast<accum_t<T>>(lhs_vals[offL]);
+  const auto b = static_cast<accum_t<T>>(rhs_vals[offR]);
+  out_vals[offO] = static_cast<T>(mul(a, b));
 }
 
 // One thread per match copies the indices column
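
The cast change above forms the product in an accumulation type through a mul() helper instead of pushing every operand through float. One way to see why a bare float cast is too narrow, assuming the kernel is also instantiated for complex element types (an assumption, not visible in this hunk):

import torch

# A real-valued multiply after casting to float matches elementwise mul only
# for real dtypes; complex values need a true complex product.
a = torch.tensor(1 + 2j, dtype=torch.complex64)
b = torch.tensor(3 - 1j, dtype=torch.complex64)
print(a * b)            # tensor(5.+5.j), the complex product
print(a.real * b.real)  # tensor(3.), all a float-only path would see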

@@ -1712,14 +1712,14 @@ class TestSparse(TestSparseBase):
             a = self._gen_sparse(sparse_dims, nnz, with_shape, dtype, device, coalesced)[0].requires_grad_(True)
             b = self._gen_sparse(sparse_dims, nnz, with_shape, dtype, device, coalesced)[0].requires_grad_(True)
 
-            self.assertEqual((a * b).to_dense(), a.to_dense() * b.to_dense(), masked=True)
+            self.assertEqual((a * b).to_dense(), a.to_dense() * b.to_dense())
             gradcheck(lambda x, y: (x * y).to_dense(), [a, b], eps=1e-4)
             # Issues with 0-dim indices/values
-            gradcheck(lambda x, y: torch.sparse.sum(x * y).to_dense(), [a, b], masked=True, eps=1e-4)
+            gradcheck(lambda x, y: torch.sparse.sum(x * y).to_dense(), [a, b], masked=True, eps=3e-4, atol=5e-5)
 
-        # TODO: Re-enable these
-        # test_shape(2, 3, [2, 3, 4, 5])
-        # test_shape(2, 3, [2, 2, 0])
+        test_shape(2, 3, [2, 3, 4, 5])
+        test_shape(2, 3, [2, 2, 0])
         test_shape(2, 3, [4, 5])
 
     @coalescedonoff
     @dtypes(torch.double)