mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[MPS] Make fused rms_norm traceable (#150661)"
This reverts commit 682f09ec51.
Reverted https://github.com/pytorch/pytorch/pull/150661 on behalf of https://github.com/malfet because "Has decomp" started to fail again ([comment](https://github.com/pytorch/pytorch/pull/150661#issuecomment-2812520408))
This commit is contained in:
parent
32c79da789
commit
e4fe67f623
|
|
@ -16,7 +16,6 @@
|
|||
#include <ATen/ops/empty_like.h>
|
||||
#include <ATen/ops/empty_like_native.h>
|
||||
#include <ATen/ops/layer_norm_native.h>
|
||||
#include <ATen/ops/_fused_rms_norm.h>
|
||||
#include <ATen/ops/native_batch_norm.h>
|
||||
#include <ATen/ops/native_layer_norm.h>
|
||||
#include <ATen/ops/native_layer_norm_backward_native.h>
|
||||
|
|
@ -28,6 +27,7 @@
|
|||
#endif
|
||||
|
||||
#ifdef USE_MPS
|
||||
#include <ATen/native/mps/operations/RMSNorm.h>
|
||||
#include <c10/core/GradMode.h>
|
||||
#endif
|
||||
|
||||
|
|
@ -281,7 +281,7 @@ Tensor rms_norm_symint(
|
|||
|
||||
if (!(GradMode::is_enabled() && any_inputs_require_grad) && !any_nested && is_input_fp && is_weight_fp) {
|
||||
auto eps_val = eps.value_or(std::numeric_limits<double>::epsilon());
|
||||
return at::_fused_rms_norm(input.contiguous(), normalized_shape.size(), weight.contiguous(), eps_val);
|
||||
return mps::rms_norm_mps_kernel(input.contiguous(), normalized_shape, weight.contiguous(), eps_val);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
14
aten/src/ATen/native/mps/operations/RMSNorm.h
Normal file
14
aten/src/ATen/native/mps/operations/RMSNorm.h
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
#pragma once
|
||||
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <c10/core/SymIntArrayRef.h>
|
||||
|
||||
namespace at::native::mps {
|
||||
|
||||
Tensor rms_norm_mps_kernel(
|
||||
const Tensor& input,
|
||||
c10::SymIntArrayRef normalized_shape,
|
||||
const Tensor& weight,
|
||||
const double eps);
|
||||
|
||||
} // namespace at::native::mps
|
||||
|
|
@ -4,14 +4,13 @@
|
|||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_fused_rms_norm_native.h>
|
||||
#include <ATen/ops/empty_like.h>
|
||||
#endif
|
||||
#include <ATen/native/mps/OperationUtils.h>
|
||||
#include <ATen/native/mps/operations/RMSNorm.h>
|
||||
#include <fmt/format.h>
|
||||
|
||||
namespace at::native {
|
||||
using namespace mps;
|
||||
namespace at::native::mps {
|
||||
|
||||
#ifndef PYTORCH_JIT_COMPILE_SHADERS
|
||||
static auto& lib = MetalShaderLibrary::getBundledLibrary();
|
||||
|
|
@ -19,9 +18,13 @@ static auto& lib = MetalShaderLibrary::getBundledLibrary();
|
|||
#include <ATen/native/mps/RMSNorm_metallib.h>
|
||||
#endif
|
||||
|
||||
Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, const Tensor& weight, const double eps) {
|
||||
Tensor rms_norm_mps_kernel(const Tensor& input,
|
||||
c10::SymIntArrayRef normalized_shape,
|
||||
const Tensor& weight,
|
||||
const double eps) {
|
||||
TORCH_CHECK(input.is_contiguous() && weight.is_contiguous(), "Expected contiguous input and weight tensors");
|
||||
auto output = at::empty_like(input);
|
||||
const int normalized_ndim = normalized_shape.size();
|
||||
const auto input_shape = input.sizes();
|
||||
const auto input_ndim = input.dim();
|
||||
const int axis = input_ndim - normalized_ndim;
|
||||
|
|
@ -61,4 +64,4 @@ Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, c
|
|||
return output;
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
} // namespace at::native::mps
|
||||
|
|
|
|||
|
|
@ -3301,10 +3301,6 @@
|
|||
dispatch:
|
||||
CompositeImplicitAutograd: rms_norm_symint
|
||||
|
||||
- func: _fused_rms_norm(Tensor input, int normalized_shape_ndim, Tensor weight, float eps) -> Tensor
|
||||
dispatch:
|
||||
MPS: _fused_rms_norm_mps
|
||||
|
||||
- func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
|
||||
variants: function, method
|
||||
dispatch:
|
||||
|
|
|
|||
|
|
@ -152,14 +152,6 @@ class MPSBasicTests(TestCase):
|
|||
|
||||
self.common(inc_, (torch.rand(1024),))
|
||||
|
||||
def test_rms_norm_nograd(self):
|
||||
# Regression test for https://github.com/pytorch/pytorch/issues/150629
|
||||
def fn(x, w):
|
||||
with torch.no_grad():
|
||||
return torch.nn.functional.rms_norm(x, x.shape, w)
|
||||
|
||||
self.common(fn, (torch.rand(10), torch.ones(10)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -2625,7 +2625,6 @@ make_fallback(aten.uniform, warn=False)
|
|||
make_fallback(aten.exponential.default, warn=False) # (fails accuracy on test_torch.py)
|
||||
make_fallback(aten._pdist_forward) # Has decomp. Needs benchmarks
|
||||
make_fallback(aten.soft_margin_loss_backward, warn=False) # py_impl?
|
||||
make_fallback(aten._fused_rms_norm, warn=False) # (MPS-only and faster than decomp)
|
||||
|
||||
|
||||
# 1.5) Easy or Impossible
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user