[MPS] Implemented masked_fill_scalar as shader (#147369)

- Move `pos_from_thread_index` and `offset_from_coord` from `UnfoldBackward.metal` into the `c10/metal/indexing.h` header
- The initial idea was to implement `StridedTensor` and `ConstStridedTensor` and use them to make the masked_fill kernel as simple as the following loop:
```metal
ConstStridedTensor<bool> mask(mask_data, sizes, mask_strides, ndim);
if (mask[thread_index]) {
  StridedTensor<T> input(input_data, sizes, input_strides, ndim);
  input[thread_index] = val;
}
```
But although it looks elegant and works correctly, performance-wise it is much slower than the existing MPS shader (see table below), as int64 divisions on the M2 GPU are really slow.

- Solved the performance issue by implementing three flavors of the same shader: `dense`, used when both input and mask are dense tensors of the same size; `broadcast`, used when `mask` can be expanded into the input tensor along its leading dimensions; and `strided`, a general-purpose fallback that still computes the position in each tensor only once. As a result, perf is even better than the existing MPS shader for dense and broadcastable tensors.
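
For illustration, here is a small sketch (mine, not part of the PR; shapes are arbitrary) of the input/mask layouts each flavor is expected to handle, with the flavor annotations inferred from the host-side dispatch logic in the `masked_fill__mps` diff below:
```python
import torch

# Small illustrative shapes; comments note which shader flavor the dispatch
# described above is expected to pick. This only constructs tensors and calls
# masked_fill_; it cannot query which kernel actually ran.
x = torch.rand(8, 16, 64, 64, device="mps")

# dense: input and mask are both contiguous and have the same shape
mask_dense = torch.zeros(8, 16, 64, 64, dtype=torch.bool, device="mps")
x.masked_fill_(mask_dense, -17.0)

# broadcast: contiguous mask whose shape matches the trailing dims of the
# input, so the mask memory simply tiles over the leading dimensions
mask_bcast = torch.ones(64, 64, device="mps").triu().bool()
x.masked_fill_(mask_bcast, -17.0)

# strided: general fallback, e.g. when the input is a non-contiguous view
x_view = x.transpose(-1, -2)
x_view.masked_fill_(mask_bcast, -17.0)
```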

Performance measured on an M2 Pro across different iterations of the same shader:

| dtype    | MPS       | int64-idx | int64-inlined | 32-bit strided | 32-bit broadcasted |
| -------- | --------- | --------- | ------------- | -------------- | ------------------ |
| float32  | 2.8 msec  | 41.6 msec | 26.9 msec     | 5 msec         | 2.4 msec           |
| float16  | 1.86 msec | 38.2 msec | 26.6 msec     | 4.6 msec       | 1.9 msec           |
| bfloat16 | 1.86 msec | 38.3 msec | 26.6 msec     | 4.6 msec       | 1.9 msec           |

And the benchmark script:
```python
import torch

from timeit import default_timer
from torch.utils.benchmark import Measurement, Timer

def bench_mask_fill(
    n,
    dtype=torch.float32,
) -> Measurement:
    t = Timer(
        stmt="x.masked_fill(y, -17.0); torch.mps.synchronize()",
        setup=f"x,y = torch.rand(1, 20, {n}, {n}, dtype={dtype}, device='mps'), torch.ones({n}, {n}, device='mps').triu().bool()",
        language="python", timer=default_timer
    )
    return t.blocked_autorange()

if __name__ == "__main__":
    n = 1024
    for dtype in [torch.float32, torch.float16, torch.bfloat16]:
        eager_t = bench_mask_fill(n, dtype)
        use_msec = eager_t.mean > 1e-4
        multiplier = 1e3 if use_msec else 1e6
        uname = "msec" if use_msec else "usec"
        print(f"torch.masked_fill_() {str(dtype):>14} {eager_t.mean*multiplier:>7.2f} {uname}")
```
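
The script above should exercise the `broadcast` flavor, since the input is contiguous and the mask covers only its two trailing dimensions. As a hypothetical follow-up (not part of the PR), the `strided` fallback could be timed the same way by making the input non-contiguous:
```python
import torch

from timeit import default_timer
from torch.utils.benchmark import Timer

n = 1024
# A transposed view is non-contiguous, so per the dispatch logic above the
# strided flavor should be selected.
t = Timer(
    stmt="x.masked_fill_(y, -17.0); torch.mps.synchronize()",
    setup=f"x = torch.rand(1, 20, {n}, {n}, device='mps').transpose(-1, -2); "
          f"y = torch.ones({n}, {n}, device='mps').triu().bool()",
    language="python",
    timer=default_timer,
)
print(t.blocked_autorange())
```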
Fixes https://github.com/pytorch/pytorch/issues/143477
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147369
Approved by: https://github.com/dcci
ghstack dependencies: #147977
Nikita Shulga, 2025-02-26 10:02:34 -08:00 (committed by PyTorch MergeBot)
commit 00732c3f7e, parent ebf6b9839c
6 changed files with 179 additions and 99 deletions


@@ -381,6 +381,10 @@ static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const Cont
[encoder setBytes:values.data() length:sizeof(typename Container::value_type) * values.size() atIndex:idx];
}
static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const MPSScalar& s, unsigned idx) {
[encoder setBytes:&s.value length:s.size atIndex:idx];
}
namespace detail {
template <typename T>
inline void mtl_setArg(id<MTLComputeCommandEncoder> encoder, const T& val, unsigned idx) {


@@ -1,7 +1,9 @@
#include <c10/metal/indexing.h>
#include <metal_atomic>
#include <metal_stdlib>
using namespace metal;
using namespace c10::metal;
struct IndexAB {
constant int64_t* indexArray;
@@ -315,3 +317,66 @@ index_put_accumulate_native_dtypes<atomic_int, int, ulong3>(
device void* outputData [[buffer(5)]],
constant uint32_t& num_indices [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]);
template <typename T>
kernel void masked_fill_scalar_dense(
device T* input,
constant bool* mask,
constant T& val,
uint thread_index [[thread_position_in_grid]]) {
if (mask[thread_index]) {
input[thread_index] = val;
}
}
template <typename T>
kernel void masked_fill_scalar_broadcast(
device T* input,
constant bool* mask,
constant T& val,
constant uint& mask_numel,
uint thread_index [[thread_position_in_grid]]) {
if (mask[thread_index % mask_numel]) {
input[thread_index] = val;
}
}
template <typename T>
kernel void masked_fill_scalar_strided(
device T* input,
constant bool* mask,
constant T& val,
constant long* sizes,
constant long* input_strides,
constant long* mask_strides,
device uint& ndim,
uint thread_index [[thread_position_in_grid]]) {
int pos[max_ndim];
pos_from_thread_index(int(thread_index), pos, sizes, ndim);
if (mask[offset_from_coord(pos, mask_strides, ndim)]) {
input[offset_from_coord(pos, input_strides, ndim)] = val;
}
}
#define REGISTER_MASKED_FILL_SCALAR(SIZE, DTYPE) \
template [[host_name("masked_fill_scalar_strided_" #SIZE)]] kernel void \
masked_fill_scalar_strided<DTYPE>( \
device DTYPE*, \
constant bool*, \
constant DTYPE&, \
constant long*, \
constant long*, \
constant long*, \
device uint&, \
uint); \
template [[host_name("masked_fill_scalar_dense_" #SIZE)]] kernel void \
masked_fill_scalar_dense<DTYPE>( \
device DTYPE*, constant bool*, constant DTYPE&, uint); \
template [[host_name("masked_fill_scalar_broadcast_" #SIZE)]] kernel void \
masked_fill_scalar_broadcast<DTYPE>( \
device DTYPE*, constant bool*, constant DTYPE&, constant uint&, uint)
REGISTER_MASKED_FILL_SCALAR(64bit, long);
REGISTER_MASKED_FILL_SCALAR(32bit, int);
REGISTER_MASKED_FILL_SCALAR(16bit, short);
REGISTER_MASKED_FILL_SCALAR(8bit, char);


@@ -1,27 +1,7 @@
#include <c10/metal/indexing.h>
#include <metal_stdlib>
using namespace metal;
// Given coordinates and strides, calculates offset from the start of the
// tensors
long offset_from_coord(thread long* idx, constant long* strides, uint ndim) {
long rc = 0;
for (uint i = 0; i < ndim; ++i) {
rc += idx[i] * strides[i];
}
return rc;
}
// Given thread index calculates position in the ndim tensor
void pos_from_thread_index(
long idx,
thread long* pos,
constant long* sizes,
uint ndim) {
for (uint i = 0; i < ndim; ++i) {
pos[i] = idx % sizes[i];
idx /= sizes[i];
}
}
using namespace c10::metal;
// Consider out = in.unfold(dim, size, step), then
// out.shape[dim] == (in.shape[dim] - size) / step + 1,
@@ -52,8 +32,8 @@ kernel void unfold_backward(
auto size = dim_size_step_ndim.y;
auto step = dim_size_step_ndim.z;
auto ndim = dim_size_step_ndim.w;
long pos[16];
pos_from_thread_index(thread_index, pos, output_sizes, ndim);
long pos[max_ndim];
pos_from_thread_index(long(thread_index), pos, output_sizes, ndim);
const auto output_offs = offset_from_coord(pos, output_strides, ndim);
const auto in_dim_size = max(1L, (output_sizes[dim] - size) / step + 1);
const auto out_dim_idx = pos[dim];


@@ -1,5 +1,7 @@
// Copyright © 2022 Apple Inc.
#include <limits>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch_v2.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/AccumulateType.h>
@@ -20,6 +22,7 @@
#include <c10/core/QScheme.h>
#include <c10/util/SmallVector.h>
#include <c10/util/irange.h>
#include <fmt/format.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -688,10 +691,28 @@ Tensor& index_select_out_mps(const Tensor& self, int64_t dim, const Tensor& inde
return output;
}
// Checks if one tensor is broadcastable into another
static bool is_dense_broadcastable(const Tensor& from, const Tensor& into) {
if (!from.is_contiguous() || !into.is_contiguous()) {
return false;
}
bool checking_squeezable_dims = false;
for (const auto dim : c10::irange(from.ndimension())) {
if (checking_squeezable_dims) {
if (from.size(-dim - 1) == 1) {
continue;
}
return false;
}
checking_squeezable_dims = from.size(-dim - 1) != into.size(-dim - 1);
}
return true;
}
Tensor& masked_fill__mps(Tensor& self, const Tensor& mask, const Scalar& value) {
using namespace mps;
if (self.numel() == 0) {
if (self.numel() == 0 || mask.numel() == 0) {
return self;
}
TORCH_CHECK(self.device() == mask.device(),
@@ -700,78 +721,38 @@ Tensor& masked_fill__mps(Tensor& self, const Tensor& mask, const Scalar& value)
" and self on ",
self.device());
TORCH_CHECK(mask.scalar_type() == kBool, "expected mask dtype to be Bool but got ", mask.scalar_type());
TORCH_CHECK(self.numel() <= std::numeric_limits<uint32_t>::max(),
"masked_fill not supported for tensors of more than 2**32 elements");
auto maybe_outnames = namedinference::broadcast_to_outnames(self, mask, "masked_fill_");
c10::MaybeOwned<Tensor> b_mask = expand_inplace(self, mask, "masked_fill_");
bool needs_output_copy = false;
Tensor output;
if (needsGather(self)) {
output = at::empty(self.sizes(), self.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
needs_output_copy = true;
}
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* maskTensor_ = nil;
MPSGraphTensor* valueTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
MPSDataType inputDataType = getMPSScalarType(self.scalar_type());
MPSDataType maskDataType = getMPSScalarType(b_mask->scalar_type());
MPSStream* stream = getCurrentMPSStream();
MPSScalar valueScalar = getMPSScalar(value, value.type());
@autoreleasepool {
string key = "masked_fill" + getTensorsStringKey({self, *b_mask}) + ":" + getMPSTypeString(value.type());
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputDataType, getMPSShape(self));
MPSGraphTensor* maskTensor = mpsGraphRankedPlaceHolder(mpsGraph, maskDataType, getMPSShape(*b_mask));
MPSGraphTensor* valueTensor = mpsGraphScalarPlaceHolder(mpsGraph, value);
MPSDataType valueType = getMPSScalarType(value.type());
MPSGraphTensor* castValueTensor = valueTensor;
if (valueType != inputDataType) {
castValueTensor = [mpsGraph castTensor:valueTensor toType:inputDataType name:@"castValueTensor"];
auto stream = getCurrentMPSStream();
const bool is_dense = self.is_contiguous() && b_mask->is_contiguous();
const bool is_dense_broadcast = is_dense_broadcastable(mask, self);
const auto flavor = is_dense ? "dense" : is_dense_broadcast ? "broadcast" : "strided";
auto fillPSO = lib.getPipelineStateForFunc(
fmt::format("masked_fill_scalar_{}_{}", flavor, getBitSizeString(self.scalar_type())));
dispatch_sync_with_rethrow(stream->queue(), ^() {
@autoreleasepool {
auto computeEncoder = stream->commandEncoder();
auto mpsScalar = getMPSScalar(value, self.scalar_type());
[computeEncoder setComputePipelineState:fillPSO];
if (is_dense) {
mtl_setArgs(computeEncoder, self, *b_mask, mpsScalar);
} else if (is_dense_broadcast) {
mtl_setArgs(computeEncoder, self, mask, mpsScalar, mask.numel());
} else {
mtl_setArgs(computeEncoder,
self,
*b_mask,
mpsScalar,
self.sizes(),
self.strides(),
b_mask->strides(),
self.ndimension());
}
MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:maskTensor
truePredicateTensor:castValueTensor
falsePredicateTensor:inputTensor
name:nil];
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->maskTensor_ = maskTensor;
newCachedGraph->valueTensor_ = valueTensor;
newCachedGraph->outputTensor_ = outputTensor;
});
Placeholder selfPlaceholder =
Placeholder(cachedGraph->inputTensor_, self, /*mpsShape*/ nil, /*gatherTensorData=*/true, inputDataType);
Placeholder maskPlaceholder =
Placeholder(cachedGraph->maskTensor_, *b_mask, /*mpsShape*/ nil, /*gatherTensorData=*/true, maskDataType);
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_,
needs_output_copy ? output : self,
/*mpsShape*/ nil,
/*gatherTensorData=*/false,
inputDataType);
// Create dictionary of inputs and outputs
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
maskPlaceholder.getMPSGraphTensor() : maskPlaceholder.getMPSGraphTensorData(),
cachedGraph->valueTensor_ : getMPSGraphTensorFromScalar(stream, valueScalar)
};
runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
}
if (needs_output_copy) {
self.copy_(output);
}
mtl_dispatch1DJob(computeEncoder, fillPSO, self.numel());
}
});
namedinference::propagate_names_if_nonempty(self, maybe_outnames);
return self;

c10/metal/indexing.h (new file)

@@ -0,0 +1,46 @@
#pragma once
#include <metal_stdlib>
namespace c10 {
namespace metal {
constant constexpr unsigned max_ndim = 16;
// Given coordinates and strides, calculates offset from the start of the
// tensors
template <typename T>
inline T offset_from_coord(
thread T idx[max_ndim],
constant long* strides,
uint ndim) {
T rc = 0;
for (uint i = 0; i < ndim; ++i) {
rc += idx[i] * T(strides[i]);
}
return rc;
}
// Given thread index calculates position in the ndim tensor
template <typename T>
inline void pos_from_thread_index(
T idx,
thread T pos[max_ndim],
constant long* sizes,
uint ndim) {
for (uint i = 0; i < ndim; ++i) {
pos[i] = idx % T(sizes[i]);
idx /= T(sizes[i]);
}
}
inline long offset_from_thread_index(
long idx,
constant long* sizes,
constant long* strides,
uint ndim) {
long pos[max_ndim];
pos_from_thread_index(idx, pos, sizes, ndim);
return offset_from_coord(pos, strides, ndim);
}
} // namespace metal
} // namespace c10


@@ -235,6 +235,7 @@ def mps_ops_modifier(ops):
'__radd__',
'__rmul__',
'__getitem__',
'_unsafe_masked_index',
'abs',
'add',
'alias_copy',
@@ -284,6 +285,7 @@ def mps_ops_modifier(ops):
'linalg.svd',
'mH',
'mT',
'masked_fill',
'masked_scatter',
'masked_select',
'meshgridlist_of_tensors',
@@ -363,7 +365,6 @@ def mps_ops_modifier(ops):
'__rdiv__',
'__rmatmul__',
'_chunk_cat',
'_unsafe_masked_index',
'acos',
'acosh',
'all',
@@ -441,7 +442,6 @@ def mps_ops_modifier(ops):
'logical_xor',
'logsumexp',
'long',
'masked_fill',
'masked.mean',
'masked.prod',
'masked.std',
@@ -934,9 +934,6 @@ def mps_ops_error_inputs_modifier(ops):
'scatter',
'scatter_add',
# unsupported complex dtypes
'masked_fill',
# MPS does not support tensor dimensions > 16
'amax',
'amin',
@@ -2283,6 +2280,13 @@ class TestMPS(TestCaseMPS):
dst2[i] = val
self.assertEqual(dst.to("cpu"), dst2, atol=0, rtol=0)
if MACOS_VERSION >= 14.0:
# Regression test for https://github.com/pytorch/pytorch/issues/143477
# Allocating 48x25x1024x1024 tensor crashes on MacOS-13
mask_bool = torch.triu(torch.ones(1024, 1024, device=device), diagonal=1).bool()
attn_scores = torch.rand(48, 25, 1024, 1024, device=device)
attn_scores.masked_fill_(mask_bool, 0)
def test_masked_fill__non_contiguous(self):
shape = (3, 5)