Revert "[MPSInductor] Naive welford_reduce implementation (#150824)"
This reverts commit 5edfb4c4fa.
Reverted https://github.com/pytorch/pytorch/pull/150824 on behalf of https://github.com/malfet due to I should have waited for lint ([comment](https://github.com/pytorch/pytorch/pull/150824#issuecomment-2798249264))
parent ca2e8cd352
commit 83f14c0b06
@@ -92,7 +92,6 @@ opmath_t<T> threadgroup_prod(
 template <typename T>
 float2 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
-  ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
   float m = data[0];
   float m2 = 0;
   for (unsigned idx = 1; idx < size; ++idx) {
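Note: the loop body of threadgroup_welford_reduce is truncated in this rendering. The function performs a sequential Welford update over the threadgroup buffer; below is a minimal Python sketch of that recurrence (the textbook form, not a verbatim transcription of the Metal source) with a quick numeric check:

# Running-mean / running-m2 recurrence, mirroring the m/m2 locals above.
def welford_reduce(data):
    m = float(data[0])  # running mean
    m2 = 0.0            # running sum of squared deviations
    for idx in range(1, len(data)):
        delta = data[idx] - m
        m += delta / (idx + 1)
        m2 += delta * (data[idx] - m)
    return m, m2        # mirrors `return float2(m, m2);`

data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
mean, m2 = welford_reduce(data)
assert abs(mean - 5.0) < 1e-12            # direct mean of data is 5.0
assert abs(m2 / len(data) - 4.0) < 1e-12  # population variance is 4.0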
@@ -103,28 +102,6 @@ float2 threadgroup_welford_reduce(threadgroup T* data, unsigned size) {
   return float2(m, m2);
 }
 
-// Each vec3type is tuple of mean, m2 and weight
-template <typename T>
-float3 welford_combine(T a, T b) {
-  float delta = b.x - a.x;
-  float new_weight = a.z + b.z;
-  auto w2_over_w = new_weight != 0 ? b.z / new_weight : 0.0;
-  return float3(
-      a.x + delta * w2_over_w,
-      a.y + b.y + delta * delta * a.z * w2_over_w,
-      new_weight);
-}
-
-template <typename T>
-float3 threadgroup_welford_combine(threadgroup T* data, unsigned size) {
-  ::metal::threadgroup_barrier(::metal::mem_flags::mem_threadgroup);
-  float3 rc = data[0];
-  for (unsigned idx = 1; idx < size; ++idx) {
-    rc = welford_combine(rc, data[idx]);
-  }
-  return rc;
-}
-
 template <typename T>
 T threadgroup_max(threadgroup T* data, unsigned size) {
   // TODO: This should be moved to the callee
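The deleted welford_combine is the merge step of Chan et al.'s parallel variance algorithm: each operand is a (mean, m2, weight) triple, and two partial results combine into one. A Python sketch of the same rule (tuples stand in for float3; names are illustrative), checked against a direct computation:

def welford_combine(a, b):
    delta = b[0] - a[0]
    new_weight = a[2] + b[2]
    w2_over_w = b[2] / new_weight if new_weight != 0 else 0.0
    return (
        a[0] + delta * w2_over_w,                        # combined mean
        a[1] + b[1] + delta * delta * a[2] * w2_over_w,  # combined m2
        new_weight,
    )

def partial(chunk):  # (mean, m2, weight) of one chunk
    m = sum(chunk) / len(chunk)
    return (m, sum((x - m) ** 2 for x in chunk), float(len(chunk)))

data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
mean, m2, n = welford_combine(partial(data[:3]), partial(data[3:]))
assert abs(mean - 5.0) < 1e-12 and abs(m2 / n - 4.0) < 1e-12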
@@ -175,7 +175,6 @@ for test_name in [
     "test_argmax_argmin2",
     "test_avg_pool2d5",
     "test_avg_pool2d8",
-    "test_batch_norm_2d_2",
     "test_bernoulli1",
     "test_builtins_round",
     "test_builtins_round_float_ndigits_neg",
@@ -510,22 +510,19 @@ class MetalKernel(SIMDKernel):
 
     def _new_idxvar(
         self,
-        dtype: Union[str | torch.dtype],
+        dtype: torch.dtype,
         elem_count: Optional[int] = None,
         default_value: Optional[Any] = None,
         is_threadgroup: bool = True,
         bounds: ValueRanges[Any] = ValueRanges.unknown(),
     ) -> CSEVariable:
-        if isinstance(dtype, torch.dtype):
-            dtype = self.dtype_to_str(dtype)
         var_name = f"tmp_acc_{next(self.acc_var_ids)}"
         var = V.kernel.create_cse_var(var_name, bounds, dtype)
         var_def = "threadgroup " if is_threadgroup else ""
-        var_def += f"{dtype} {var_name}"
+        var_def += f"{self.dtype_to_str(dtype)} {var_name}"
         if elem_count:
            var_def += f"[{elem_count}]"
         if default_value is not None:
-            assert not is_threadgroup, "Thread group var can not have default value"
             var_def += f" = {default_value}"
         self.indexing_code.writeline(var_def + self.suffix)
         return var
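The reverted signature change let _new_idxvar accept a raw Metal type string such as "float3" in addition to a torch.dtype (used by the welford_combine lowering below). Here is a standalone mock of the declaration string the helper emits; the CSE/kernel plumbing is omitted and _DTYPE_TO_METAL is an abridged stand-in for the real dtype mapping:

import itertools
import torch
from typing import Optional, Union

_acc_ids = itertools.count()
_DTYPE_TO_METAL = {torch.float32: "float", torch.int32: "int"}  # abridged

def new_idxvar_decl(
    dtype: Union[str, torch.dtype],
    elem_count: Optional[int] = None,
    is_threadgroup: bool = True,
) -> str:
    # Eagerly normalize torch dtypes to Metal type names, as the reverted
    # version did; plain strings like "float3" pass straight through.
    if isinstance(dtype, torch.dtype):
        dtype = _DTYPE_TO_METAL[dtype]
    name = f"tmp_acc_{next(_acc_ids)}"
    decl = ("threadgroup " if is_threadgroup else "") + f"{dtype} {name}"
    if elem_count:
        decl += f"[{elem_count}]"
    return decl + ";"

print(new_idxvar_decl(torch.float32, 8))  # threadgroup float tmp_acc_0[8];
print(new_idxvar_decl("float3", 32))      # threadgroup float3 tmp_acc_1[32];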
@@ -644,19 +641,6 @@ class MetalKernel(SIMDKernel):
             return OpsWrapper._unwrap(
                 (f"{wf_res}.x", f"{wf_res}.y", self.features.reduction_numel)
             )
-        if reduction_type == "welford_combine":
-            assert not self.multistage_reduction, (
-                f"Multistage reduction not yet supported for {reduction_type}"
-            )
-            acc_buf = self._new_idxvar("float3", acc_buf_size)
-            self.compute.splice(
-                f"{acc_buf}[{reduction_idx}] = float3({value[0]}, {value[1]}, {value[2]});"
-            )
-            wf_res = self.cse.generate(
-                self.compute,
-                f"c10::metal::threadgroup_{reduction_type}({acc_buf}, {acc_buf_size})",
-            )
-            return OpsWrapper._unwrap((f"{wf_res}.x", f"{wf_res}.y", f"{wf_res}.z"))
         raise NotImplementedError(reduction_type)
 
     def codegen_iteration_ranges_entry(self, entry: IterationRangesEntry) -> None:
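For reference, the deleted lowering staged each thread's partial (mean, m2, weight) into a threadgroup float3 buffer and reduced it with the c10::metal helper. A sketch of the Metal source those f-strings produce, with illustrative stand-ins for acc_buf, reduction_idx, value, and the buffer size:

# Illustrative values only; in the real kernel these come from the
# MetalKernel codegen state.
acc_buf, acc_buf_size, reduction_idx = "tmp_acc_0", 32, "r0"
value = ("mean_in", "m2_in", "weight_in")

store = f"{acc_buf}[{reduction_idx}] = float3({value[0]}, {value[1]}, {value[2]});"
combine = f"c10::metal::threadgroup_welford_combine({acc_buf}, {acc_buf_size})"

print(store)    # tmp_acc_0[r0] = float3(mean_in, m2_in, weight_in);
print(combine)  # c10::metal::threadgroup_welford_combine(tmp_acc_0, 32)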