[MPSInductor] Implement atomic_add store mode (#151871)

This fixes `GPUTests.test_index_put2_mps`, `GPUTests.test__unsafe_masked_index_put_accumulate_mps`, and about a dozen scatter/gather tests that relied on the atomic_add store mode.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151871
Approved by: https://github.com/jansel, https://github.com/dcci
ghstack dependencies: #151869
This commit is contained in:
Nikita Shulga 2025-04-22 14:11:26 -07:00 committed by PyTorch MergeBot
parent 3aecf2dc52
commit 2f851ac8f8
2 changed files with 10 additions and 12 deletions

View File

@ -1854,7 +1854,6 @@ class CommonTemplate:
),
)
@xfail_if_mps
def test__unsafe_masked_index_put_accumulate(self):
def fn(a, mask, idx, values):
return aten._unsafe_masked_index_put_accumulate(a, mask, idx, values)
@ -5775,7 +5774,6 @@ class CommonTemplate:
a = torch.rand((1, 1000000), device=self.device)
self.common(f, (a,))
@xfail_if_mps # 100% results are wrong
def test_gather_scatter(self):
def fn(node_feat, edge_index):
src_node_feat = node_feat[edge_index[0]]
@ -7942,7 +7940,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
fn, [torch.randn(1024, 4, 2), torch.arange(4), torch.randn(4, 1, 1)]
)
@xfail_if_mps # 100% entries are wrong
def test_index_put2(self):
def fn(a, b, c):
return torch.index_put(a, [b], c, True)
@ -8386,7 +8383,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
],
)
@xfail_if_mps
def test_scatter2(self):
if self.device == "cuda":
raise unittest.SkipTest("unstable on sm86")
@ -8409,7 +8405,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
check_lowp=check_lowp,
)
@xfail_if_mps
def test_scatter3(self):
def fn(a, dim, index, b):
return aten.scatter(a, dim, index, b, reduce="add")
@ -8454,7 +8449,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
check_lowp=check_lowp,
)
@xfail_if_mps
def test_scatter5(self):
def fn(a, dim, index, b, reduce):
a = a.clone()
@ -8524,7 +8518,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
check_lowp=check_lowp,
)
@xfail_if_mps # All elements are wrong
def test_scatter_add2(self):
def fn(a, dim, index, b):
return aten.scatter_add(a, dim, index, b)
@ -8544,7 +8537,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
check_lowp=check_lowp,
)
@xfail_if_mps
def test_scatter_add3(self):
def fn(a, dim, index, b):
return aten.scatter_add(a, dim, index, b)
@ -8569,7 +8561,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
check_lowp=check_lowp,
)
@xfail_if_mps
def test_scatter_reduce1(self):
def fn(a, dim, index, b):
return aten.scatter_reduce(a, dim, index, b, "sum")
@ -8589,7 +8580,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
check_lowp=check_lowp,
)
@xfail_if_mps # 50% of elements are wrong
def test_scatter_reduce2(self):
def fn(a, dim, index, b, reduce):
return aten.scatter_reduce(a, dim, index, b, reduce, include_self=False)
@ -8611,7 +8601,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
check_lowp=check_lowp,
)
@xfail_if_mps
def test_scatter_reduce3(self):
def fn(a, dim, index, b, reduce):
a = a.clone()

View File

@ -529,7 +529,15 @@ class MetalKernel(SIMDKernel):
var = self.args.output(name)
index = self.prepare_indexing(index)
dtype_str = self.dtype_to_str(V.graph.get_dtype(name))
line = f"{var}[{self.index_to_str(index)}] = static_cast<{dtype_str}>({value});"
cast_val = f"static_cast<{dtype_str}>({value})"
if mode is None:
line = f"{var}[{self.index_to_str(index)}] = {cast_val};"
elif mode == "atomic_add":
atomic_type = f"c10::metal::AtomicType<{dtype_str}>"
cast_var = f"reinterpret_cast<device {atomic_type}::type *>({var})"
line = f"{atomic_type}::atomic_add({cast_var}, {self.index_to_str(index)}, {cast_val});"
else:
raise RuntimeError(f"Unimplemented store mode {mode}")
if self.inside_reduction:
self.compute.writeline(DeferredLine(name, line))
else:
@ -785,6 +793,7 @@ class MetalKernel(SIMDKernel):
with code.indent():
code.splice(
"""
#include <c10/metal/atomic.h>
#include <c10/metal/random.h>
#include <c10/metal/special_math.h>
#include <c10/metal/utils.h>