Revert "[MPS] Make fused rms_norm traceable (#150661)"

This reverts commit 682f09ec51.

Reverted https://github.com/pytorch/pytorch/pull/150661 on behalf of https://github.com/malfet due to Has decomp started to fail again ([comment](https://github.com/pytorch/pytorch/pull/150661#issuecomment-2812520408))
This commit is contained in:
PyTorch MergeBot 2025-04-17 11:06:05 +00:00
parent 32c79da789
commit e4fe67f623
6 changed files with 24 additions and 20 deletions

View File

@ -16,7 +16,6 @@
#include <ATen/ops/empty_like.h>
#include <ATen/ops/empty_like_native.h>
#include <ATen/ops/layer_norm_native.h>
#include <ATen/ops/_fused_rms_norm.h>
#include <ATen/ops/native_batch_norm.h>
#include <ATen/ops/native_layer_norm.h>
#include <ATen/ops/native_layer_norm_backward_native.h>
@@ -28,6 +27,7 @@
#endif
#ifdef USE_MPS
#include <ATen/native/mps/operations/RMSNorm.h>
#include <c10/core/GradMode.h>
#endif
@@ -281,7 +281,7 @@ Tensor rms_norm_symint(
if (!(GradMode::is_enabled() && any_inputs_require_grad) && !any_nested && is_input_fp && is_weight_fp) {
auto eps_val = eps.value_or(std::numeric_limits<double>::epsilon());
return at::_fused_rms_norm(input.contiguous(), normalized_shape.size(), weight.contiguous(), eps_val);
return mps::rms_norm_mps_kernel(input.contiguous(), normalized_shape, weight.contiguous(), eps_val);
}
}
#endif

View File

@@ -0,0 +1,14 @@
#pragma once
#include <ATen/core/Tensor.h>
#include <c10/core/SymIntArrayRef.h>
namespace at::native::mps {
Tensor rms_norm_mps_kernel(
const Tensor& input,
c10::SymIntArrayRef normalized_shape,
const Tensor& weight,
const double eps);
} // namespace at::native::mps

View File

@@ -4,14 +4,13 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_rms_norm_native.h>
#include <ATen/ops/empty_like.h>
#endif
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/operations/RMSNorm.h>
#include <fmt/format.h>
namespace at::native {
using namespace mps;
namespace at::native::mps {
#ifndef PYTORCH_JIT_COMPILE_SHADERS
static auto& lib = MetalShaderLibrary::getBundledLibrary();
@@ -19,9 +18,13 @@ static auto& lib = MetalShaderLibrary::getBundledLibrary();
#include <ATen/native/mps/RMSNorm_metallib.h>
#endif
Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, const Tensor& weight, const double eps) {
Tensor rms_norm_mps_kernel(const Tensor& input,
c10::SymIntArrayRef normalized_shape,
const Tensor& weight,
const double eps) {
TORCH_CHECK(input.is_contiguous() && weight.is_contiguous(), "Expected contiguous input and weight tensors");
auto output = at::empty_like(input);
const int normalized_ndim = normalized_shape.size();
const auto input_shape = input.sizes();
const auto input_ndim = input.dim();
const int axis = input_ndim - normalized_ndim;
@@ -61,4 +64,4 @@ Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, c
return output;
}
} // namespace at::native
} // namespace at::native::mps

View File

@@ -3301,10 +3301,6 @@
dispatch:
CompositeImplicitAutograd: rms_norm_symint
- func: _fused_rms_norm(Tensor input, int normalized_shape_ndim, Tensor weight, float eps) -> Tensor
dispatch:
MPS: _fused_rms_norm_mps
- func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
variants: function, method
dispatch:

View File

@@ -152,14 +152,6 @@ class MPSBasicTests(TestCase):
self.common(inc_, (torch.rand(1024),))
def test_rms_norm_nograd(self):
# Regression test for https://github.com/pytorch/pytorch/issues/150629
def fn(x, w):
with torch.no_grad():
return torch.nn.functional.rms_norm(x, x.shape, w)
self.common(fn, (torch.rand(10), torch.ones(10)))
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@@ -2625,7 +2625,6 @@ make_fallback(aten.uniform, warn=False)
make_fallback(aten.exponential.default, warn=False) # (fails accuracy on test_torch.py)
make_fallback(aten._pdist_forward) # Has decomp. Needs benchmarks
make_fallback(aten.soft_margin_loss_backward, warn=False) # py_impl?
make_fallback(aten._fused_rms_norm, warn=False) # (MPS-only and faster than decomp)
# 1.5) Easy or Impossible