mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPSInductor] Implement atomic_add store mode (#151871)
Which fixes `GPUTests.test_index_put2_mps`, `GPUTests. test__unsafe_masked_index_put_accumulate_mps` and dozen of scatter/gather tests that relied on atomic_add store mode Pull Request resolved: https://github.com/pytorch/pytorch/pull/151871 Approved by: https://github.com/jansel, https://github.com/dcci ghstack dependencies: #151869
This commit is contained in:
parent
3aecf2dc52
commit
2f851ac8f8
|
|
@ -1854,7 +1854,6 @@ class CommonTemplate:
|
|||
),
|
||||
)
|
||||
|
||||
@xfail_if_mps
|
||||
def test__unsafe_masked_index_put_accumulate(self):
|
||||
def fn(a, mask, idx, values):
|
||||
return aten._unsafe_masked_index_put_accumulate(a, mask, idx, values)
|
||||
|
|
@ -5775,7 +5774,6 @@ class CommonTemplate:
|
|||
a = torch.rand((1, 1000000), device=self.device)
|
||||
self.common(f, (a,))
|
||||
|
||||
@xfail_if_mps # 100% results are wrong
|
||||
def test_gather_scatter(self):
|
||||
def fn(node_feat, edge_index):
|
||||
src_node_feat = node_feat[edge_index[0]]
|
||||
|
|
@ -7942,7 +7940,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
|||
fn, [torch.randn(1024, 4, 2), torch.arange(4), torch.randn(4, 1, 1)]
|
||||
)
|
||||
|
||||
@xfail_if_mps # 100% entries are wrong
|
||||
def test_index_put2(self):
|
||||
def fn(a, b, c):
|
||||
return torch.index_put(a, [b], c, True)
|
||||
|
|
@ -8386,7 +8383,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
|||
],
|
||||
)
|
||||
|
||||
@xfail_if_mps
|
||||
def test_scatter2(self):
|
||||
if self.device == "cuda":
|
||||
raise unittest.SkipTest("unstable on sm86")
|
||||
|
|
@ -8409,7 +8405,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
|||
check_lowp=check_lowp,
|
||||
)
|
||||
|
||||
@xfail_if_mps
|
||||
def test_scatter3(self):
|
||||
def fn(a, dim, index, b):
|
||||
return aten.scatter(a, dim, index, b, reduce="add")
|
||||
|
|
@ -8454,7 +8449,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
|||
check_lowp=check_lowp,
|
||||
)
|
||||
|
||||
@xfail_if_mps
|
||||
def test_scatter5(self):
|
||||
def fn(a, dim, index, b, reduce):
|
||||
a = a.clone()
|
||||
|
|
@ -8524,7 +8518,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
|||
check_lowp=check_lowp,
|
||||
)
|
||||
|
||||
@xfail_if_mps # All elements are wrong
|
||||
def test_scatter_add2(self):
|
||||
def fn(a, dim, index, b):
|
||||
return aten.scatter_add(a, dim, index, b)
|
||||
|
|
@ -8544,7 +8537,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
|||
check_lowp=check_lowp,
|
||||
)
|
||||
|
||||
@xfail_if_mps
|
||||
def test_scatter_add3(self):
|
||||
def fn(a, dim, index, b):
|
||||
return aten.scatter_add(a, dim, index, b)
|
||||
|
|
@ -8569,7 +8561,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
|||
check_lowp=check_lowp,
|
||||
)
|
||||
|
||||
@xfail_if_mps
|
||||
def test_scatter_reduce1(self):
|
||||
def fn(a, dim, index, b):
|
||||
return aten.scatter_reduce(a, dim, index, b, "sum")
|
||||
|
|
@ -8589,7 +8580,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
|||
check_lowp=check_lowp,
|
||||
)
|
||||
|
||||
@xfail_if_mps # 50% of elements are wrong
|
||||
def test_scatter_reduce2(self):
|
||||
def fn(a, dim, index, b, reduce):
|
||||
return aten.scatter_reduce(a, dim, index, b, reduce, include_self=False)
|
||||
|
|
@ -8611,7 +8601,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
|||
check_lowp=check_lowp,
|
||||
)
|
||||
|
||||
@xfail_if_mps
|
||||
def test_scatter_reduce3(self):
|
||||
def fn(a, dim, index, b, reduce):
|
||||
a = a.clone()
|
||||
|
|
|
|||
|
|
@ -529,7 +529,15 @@ class MetalKernel(SIMDKernel):
|
|||
var = self.args.output(name)
|
||||
index = self.prepare_indexing(index)
|
||||
dtype_str = self.dtype_to_str(V.graph.get_dtype(name))
|
||||
line = f"{var}[{self.index_to_str(index)}] = static_cast<{dtype_str}>({value});"
|
||||
cast_val = f"static_cast<{dtype_str}>({value})"
|
||||
if mode is None:
|
||||
line = f"{var}[{self.index_to_str(index)}] = {cast_val};"
|
||||
elif mode == "atomic_add":
|
||||
atomic_type = f"c10::metal::AtomicType<{dtype_str}>"
|
||||
cast_var = f"reinterpret_cast<device {atomic_type}::type *>({var})"
|
||||
line = f"{atomic_type}::atomic_add({cast_var}, {self.index_to_str(index)}, {cast_val});"
|
||||
else:
|
||||
raise RuntimeError(f"Unimplemented store mode {mode}")
|
||||
if self.inside_reduction:
|
||||
self.compute.writeline(DeferredLine(name, line))
|
||||
else:
|
||||
|
|
@ -785,6 +793,7 @@ class MetalKernel(SIMDKernel):
|
|||
with code.indent():
|
||||
code.splice(
|
||||
"""
|
||||
#include <c10/metal/atomic.h>
|
||||
#include <c10/metal/random.h>
|
||||
#include <c10/metal/special_math.h>
|
||||
#include <c10/metal/utils.h>
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user