[MPSInductor] Naive welford_reduce implementation (#150824)

Literal Python-to-Metal translation of
85549fe6de/torch/_inductor/runtime/triton_helpers.py (L217-L225)

Fixed missing barrier in `welford_combine`
This is sufficient to make `GPUTests.test_batch_norm_2d_2_mps` pass
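
For context, the combine step being ported merges two partial Welford states; here is a minimal Python sketch of that rule (a paraphrase of the linked Triton helper, not the verbatim source):

```python
def welford_combine(a, b):
    """Merge two partial Welford states a = (mean, m2, weight), b = (mean, m2, weight)."""
    delta = b[0] - a[0]
    new_weight = a[2] + b[2]
    # Guard against merging two empty partitions (weight == 0).
    w2_over_w = b[2] / new_weight if new_weight != 0 else 0.0
    return (
        a[0] + delta * w2_over_w,                        # merged mean
        a[1] + b[1] + delta * delta * a[2] * w2_over_w,  # merged m2
        new_weight,                                      # merged weight
    )
```

This is the standard parallel-variance update (Chan et al.), which is exactly what the `float3 welford_combine` Metal function below computes componentwise.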

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150824
Approved by: https://github.com/dcci, https://github.com/jansel
ghstack dependencies: #151042
Author: Nikita Shulga
Date: 2025-04-11 19:33:35 -07:00
Committed by: PyTorch MergeBot
Parent: 32f0f414ab
Commit: 397d37acc5
3 changed files with 45 additions and 3 deletions


@@ -92,6 +92,7 @@ opmath_t<T> threadgroup_prod(
 
 template <typename T>
 float2 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
+  ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
   float m = data[0];
   float m2 = 0;
   for (unsigned idx = 1; idx < size; ++idx) {
@@ -102,6 +103,28 @@ float2 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
   return float2(m, m2);
 }
+
+// Each vec3 type is a tuple of (mean, m2, weight)
+template <typename T>
+float3 welford_combine(T a, T b) {
+  float delta = b.x - a.x;
+  float new_weight = a.z + b.z;
+  auto w2_over_w = new_weight != 0 ? b.z / new_weight : 0.0;
+  return float3(
+      a.x + delta * w2_over_w,
+      a.y + b.y + delta * delta * a.z * w2_over_w,
+      new_weight);
+}
+
+template <typename T>
+float3 threadgroup_welford_combine(threadgroup T* data, unsigned size) {
+  ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
+  float3 rc = data[0];
+  for (unsigned idx = 1; idx < size; ++idx) {
+    rc = welford_combine(rc, data[idx]);
+  }
+  return rc;
+}
 
 template <typename T>
 T threadgroup_max(threadgroup T* data, unsigned size) {
   // TODO: This should be moved to the callee
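
As a sanity check of the two helpers, the following illustrative Python (reusing the `welford_combine` sketch from the commit message above) reduces each fixed-size chunk naively and then folds the per-chunk states serially, mirroring what `threadgroup_welford_reduce` and `threadgroup_welford_combine` do over a threadgroup buffer:

```python
import random
import statistics

random.seed(0)
data = [random.uniform(-1.0, 1.0) for _ in range(1024)]

# Naively reduce each 256-element "threadgroup" chunk to a single
# (mean, m2, weight) state, as threadgroup_welford_reduce does.
states = []
for start in range(0, len(data), 256):
    chunk = data[start : start + 256]
    m, m2 = chunk[0], 0.0
    for k, x in enumerate(chunk[1:], start=2):
        delta = x - m
        m += delta / k
        m2 += delta * (x - m)
    states.append((m, m2, float(len(chunk))))

# Serially fold the partial states, as threadgroup_welford_combine does.
mean, m2, weight = states[0]
for s in states[1:]:
    mean, m2, weight = welford_combine((mean, m2, weight), s)

# Compare against direct two-pass statistics; m2/weight is the population variance.
assert abs(mean - statistics.fmean(data)) < 1e-12
assert abs(m2 / weight - statistics.pvariance(data)) < 1e-9
```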


@@ -175,6 +175,7 @@ for test_name in [
     "test_argmax_argmin2",
     "test_avg_pool2d5",
     "test_avg_pool2d8",
+    "test_batch_norm_2d_2",
     "test_bernoulli1",
     "test_builtins_round",
     "test_builtins_round_float_ndigits_neg",


@@ -510,19 +510,22 @@ class MetalKernel(SIMDKernel):
     def _new_idxvar(
         self,
-        dtype: torch.dtype,
+        dtype: Union[str, torch.dtype],
         elem_count: Optional[int] = None,
         default_value: Optional[Any] = None,
         is_threadgroup: bool = True,
         bounds: ValueRanges[Any] = ValueRanges.unknown(),
     ) -> CSEVariable:
+        if isinstance(dtype, torch.dtype):
+            dtype = self.dtype_to_str(dtype)
         var_name = f"tmp_acc_{next(self.acc_var_ids)}"
         var = V.kernel.create_cse_var(var_name, bounds, dtype)
         var_def = "threadgroup " if is_threadgroup else ""
-        var_def += f"{self.dtype_to_str(dtype)} {var_name}"
+        var_def += f"{dtype} {var_name}"
         if elem_count:
             var_def += f"[{elem_count}]"
         if default_value is not None:
             assert not is_threadgroup, "Thread group var can not have default value"
             var_def += f" = {default_value}"
         self.indexing_code.writeline(var_def + self.suffix)
         return var
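
The `Union[str, torch.dtype]` widening exists because the Welford buffer's element type, `float3`, has no `torch.dtype` counterpart. A simplified, hypothetical model of the resulting declaration logic (the dtype mapping below is an abbreviated stand-in for `dtype_to_str`, not the kernel's actual table):

```python
from typing import Optional, Union

import torch

def metal_decl(
    dtype: Union[str, torch.dtype],
    name: str,
    elem_count: Optional[int] = None,
    is_threadgroup: bool = True,
) -> str:
    """Model of the declaration _new_idxvar now emits: torch dtypes are
    mapped to Metal type names, raw strings like "float3" pass through."""
    if isinstance(dtype, torch.dtype):
        # Abbreviated stand-in for MetalKernel.dtype_to_str.
        dtype = {torch.float32: "float", torch.half: "half"}[dtype]
    decl = ("threadgroup " if is_threadgroup else "") + f"{dtype} {name}"
    if elem_count:
        decl += f"[{elem_count}]"
    return decl + ";"

print(metal_decl(torch.float32, "tmp_acc_0", 32))  # threadgroup float tmp_acc_0[32];
print(metal_decl("float3", "tmp_acc_1", 32))       # threadgroup float3 tmp_acc_1[32];
```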
@@ -534,7 +537,8 @@ class MetalKernel(SIMDKernel):
         reduction_type: ReductionType,
         value: Union[CSEVariable, tuple[CSEVariable, ...]],
     ) -> Union[CSEVariable, tuple[CSEVariable, ...]]:
-        """Codegen a reduction operation"""
+        """Codegen a reduction operation.
+        Only sum and prod operations are somewhat reasonably optimized"""
         # Establish reduction buffer size and index expression
         reduction_idx = ""
         acc_buf_size = 1
@@ -641,6 +645,20 @@
             return OpsWrapper._unwrap(
                 (f"{wf_res}.x", f"{wf_res}.y", self.features.reduction_numel)
             )
+        if reduction_type == "welford_combine":
+            assert not self.multistage_reduction, (
+                f"Multistage reduction not yet supported for {reduction_type}"
+            )
+            assert isinstance(value, tuple), "Input to welford combine must be tuple"
+            acc_buf = self._new_idxvar("float3", acc_buf_size)
+            self.compute.splice(
+                f"{acc_buf}[{reduction_idx}] = float3({value[0]}, {value[1]}, {value[2]});"
+            )
+            wf_res = self.cse.generate(
+                self.compute,
+                f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})",
+            )
+            return OpsWrapper._unwrap((f"{wf_res}.x", f"{wf_res}.y", f"{wf_res}.z"))
         raise NotImplementedError(reduction_type)
 
     def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry) -> None:
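
Substituting hypothetical values into the f-strings above shows roughly what this branch emits into the kernel body (the names `tmp_acc_0`, `r0`, and the partial-state variables are illustrative, not actual codegen output):

```python
# Hypothetical codegen inputs: the threadgroup buffer name, the per-thread
# reduction index expression, the buffer size, and the (mean, m2, weight)
# partial-state CSE variables.
acc_buf, reduction_idx, acc_buf_size = "tmp_acc_0", "r0", 1024
value = ("mean_p", "m2_p", "weight_p")

# Store this thread's partial state into the threadgroup buffer...
store = f"{acc_buf}[{reduction_idx}] = float3({value[0]}, {value[1]}, {value[2]});"
# ...then serially fold the whole buffer with the Metal helper.
call = f"c10::metal::threadgroup_welford_combine({acc_buf}, {acc_buf_size})"

print(store)  # tmp_acc_0[r0] = float3(mean_p, m2_p, weight_p);
print(call)   # c10::metal::threadgroup_welford_combine(tmp_acc_0, 1024)
```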