Revert "Fused RMSNorm implementation (#153666)"

This reverts commit e1aee86646.

Reverted https://github.com/pytorch/pytorch/pull/153666 on behalf of https://github.com/davidberard98 due to causing build failures on main branch [GH job link](https://github.com/pytorch/pytorch/actions/runs/16007148842/job/45156382001) [HUD commit link](e1aee86646) ([comment](https://github.com/pytorch/pytorch/pull/153666#issuecomment-3025146176))
PyTorch MergeBot 2025-07-01 18:46:45 +00:00
parent 3a5677a380
commit 6401d1d53d
14 changed files with 184 additions and 839 deletions

View File

@@ -158,7 +158,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
OP_DECOMPOSE(kron);
OP_DECOMPOSE(l1_loss);
m.impl("layer_norm", native::layer_norm_symint);
m.impl("_fused_rms_norm", native::rms_norm_composite);
OP_DECOMPOSE2(ldexp, Tensor);
OP_DECOMPOSE2(less_equal, Tensor );
OP_DECOMPOSE2(less, Tensor );
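For context, the line dropped from this hunk had registered the composite kernel under the FuncTorchBatchedDecomposition key, so that functorch transforms decompose the fused op instead of requiring a dedicated batching rule. A minimal sketch of the kind of call this serves, using only the public rms_norm API (illustrative; not taken from this diff):

import torch
import torch.nn.functional as F
from torch.func import vmap

xs = torch.randn(5, 4, 16)      # a batch of five (4, 16) inputs
weight = torch.randn(16)
# vmap maps rms_norm over the leading dim; under functorch the op is decomposed
ys = vmap(lambda x: F.rms_norm(x, (16,), weight))(xs)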

File diff suppressed because it is too large.

View File

@@ -261,11 +261,30 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
return outputs;
}
std::tuple<Tensor, Tensor> rms_norm_composite(
Tensor rms_norm_symint(
const Tensor& input,
IntArrayRef normalized_shape,
c10::SymIntArrayRef normalized_shape,
const std::optional<Tensor>& weight_opt /* optional */,
std::optional<double> eps) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
_check_rms_norm_inputs_symint(input, normalized_shape, weight);
#ifdef USE_MPS
if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) {
const Tensor weight = weight_opt.value();
const bool any_nested = input.is_nested() || weight.is_nested();
const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad();
const bool is_input_fp = isFloatingType(input.scalar_type());
const bool is_weight_fp = isFloatingType(weight.scalar_type());
if (!(GradMode::is_enabled() && any_inputs_require_grad) && !any_nested && is_input_fp && is_weight_fp) {
auto eps_val = eps.value_or(std::numeric_limits<double>::epsilon());
return at::_fused_rms_norm(input.contiguous(), normalized_shape.size(), weight.contiguous(), eps_val);
}
}
#endif
std::vector<int64_t> dims_to_reduce;
for (const auto i : c10::irange(normalized_shape.size())) {
@@ -302,60 +321,10 @@ std::tuple<Tensor, Tensor> rms_norm_composite(
upcasted_result = upcasted_result.mul(weight_opt.value());
}
// if nested do not make contiguous
if(input.is_nested() || (weight_opt.has_value() && weight_opt.value().is_nested())){
return std::make_tuple(upcasted_result, rqrst_input);
}
if(input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast || input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast3d){
return std::make_tuple(upcasted_result, rqrst_input);
}
return std::make_tuple(upcasted_result.contiguous(), rqrst_input.contiguous());
return upcasted_result;
});
return std::make_tuple(
std::get<0>(result).type_as(input), // Cast normalized result to original input type
std::get<1>(result) // rsqrt_val
);
return result.type_as(input);
}
Tensor rms_norm_symint(
const Tensor& input,
c10::SymIntArrayRef normalized_shape,
const std::optional<Tensor>& weight_opt /* optional */,
const std::optional<double> eps) {
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
_check_rms_norm_inputs_symint(input, normalized_shape, weight);
// composite fallback for channels last
if(input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast || input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast3d){
return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
}
// composite fallback for complex datatypes
if(input.is_complex()){
return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
}
#ifdef USE_MPS
if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) {
const Tensor weight = weight_opt.value();
const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad();
if (!(GradMode::is_enabled() && any_inputs_require_grad)) {
return std::get<0>(at::_fused_rms_norm(input.contiguous(), IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
}
}
if (input.device().type() == DeviceType::MPS){
return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
}
#endif
return std::get<0>(at::_fused_rms_norm(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
}
} // namespace at::native
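For orientation, the composite path shown above boils down to the following reference computation. This is a minimal eager-mode sketch rather than the reverted code; the float32 upcast and the eps fallback are illustrative assumptions:

import torch

def rms_norm_reference(x, normalized_shape, weight=None, eps=None):
    # Reduce over the trailing dims covered by normalized_shape.
    dims = list(range(x.dim() - len(normalized_shape), x.dim()))
    if eps is None:
        eps = torch.finfo(torch.float32).eps  # assumed default, mirroring the epsilon fallback above
    upcast = x.float()                        # compute in a wider dtype
    rstd = torch.rsqrt(upcast.pow(2).mean(dim=dims, keepdim=True) + eps)
    out = upcast * rstd
    if weight is not None:
        out = out * weight.float()
    return out.type_as(x)                     # cast back to the input dtype, matching result.type_as(input)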

View File

@@ -106,12 +106,6 @@ void layer_norm_cpu_out(
int64_t M,
int64_t N);
std::tuple<Tensor, Tensor> rms_norm_composite(
const Tensor& input,
IntArrayRef normalized_shape,
const std::optional<Tensor>& weight_opt /* optional */,
std::optional<double> eps);
Tensor rms_norm_symint(
const Tensor& input,
c10::SymIntArrayRef normalized_shape,

View File

@@ -19,14 +19,7 @@ static auto& lib = MetalShaderLibrary::getBundledLibrary();
#include <ATen/native/mps/RMSNorm_metallib.h>
#endif
std::tuple<Tensor, Tensor> _fused_rms_norm_mps(const Tensor& input,
IntArrayRef normalized_shape,
const std::optional<Tensor>& weight_opt,
const std::optional<double> eps) {
const Tensor weight = weight_opt.value().contiguous();
const int64_t normalized_ndim = normalized_shape.size();
auto eps_val = eps.value_or(std::numeric_limits<double>::epsilon());
Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, const Tensor& weight, const double eps) {
TORCH_CHECK(input.is_contiguous() && weight.is_contiguous(), "Expected contiguous input and weight tensors");
auto output = at::empty_like(input);
const auto input_shape = input.sizes();
@@ -48,7 +41,7 @@ std::tuple<Tensor, Tensor> _fused_rms_norm_mps(const Tensor& input,
const std::string kernel = fmt::format("{}_{}", name, scalarToMetalTypeString(output));
id<MTLComputePipelineState> rms_norm_pso = lib.getPipelineStateForFunc(kernel);
[computeEncoder setComputePipelineState:rms_norm_pso];
mtl_setArgs(computeEncoder, input, weight, output, eps_val, N, 1);
mtl_setArgs(computeEncoder, input, weight, output, eps, N, 1);
const auto maxThreadsPerGroup = static_cast<size_t>([rms_norm_pso maxTotalThreadsPerThreadgroup]);
size_t threadgroup_size = maxThreadsPerGroup;
@@ -65,7 +58,7 @@ std::tuple<Tensor, Tensor> _fused_rms_norm_mps(const Tensor& input,
}
});
return std::make_tuple(output, Tensor());
return output;
}
} // namespace at::native

View File

@@ -3315,15 +3315,9 @@
dispatch:
CompositeImplicitAutograd: rms_norm_symint
- func: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor)
- func: _fused_rms_norm(Tensor input, int normalized_shape_ndim, Tensor weight, float eps) -> Tensor
dispatch:
CUDA: _fused_rms_norm_cuda
MPS: _fused_rms_norm_mps
CompositeImplicitAutograd: rms_norm_composite
- func: _fused_rms_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor rstd, Tensor? weight, bool[2] output_mask) -> (Tensor, Tensor)
dispatch:
CUDA: _fused_rms_norm_backward_cuda
- func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
variants: function, method
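The schema change above is internal either way; the supported entry point remains torch.nn.functional.rms_norm, which routes through rms_norm_symint (CompositeImplicitAutograd) and, per the restored dispatch table, may reach the fused MPS kernel when the input is eligible. A hedged usage sketch (the device check is a runtime assumption, not part of this diff):

import torch
import torch.nn.functional as F

device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.randn(4, 8, 16, device=device)
w = torch.randn(16, device=device)
y = F.rms_norm(x, normalized_shape=(16,), weight=w, eps=1e-6)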

View File

@@ -374,6 +374,7 @@ aten::_fused_adamw_.tensor_lr
aten::_fused_moving_avg_obs_fq_helper
aten::_fused_moving_avg_obs_fq_helper.out
aten::_fused_moving_avg_obs_fq_helper_functional
aten::_fused_rms_norm
aten::_fused_sdp_choice
aten::_fused_sgd
aten::_fused_sgd.out

View File

@@ -15,7 +15,7 @@ from torch._dispatch.python import enable_python_dispatcher
from torch._export.utils import _is_cia_op
from torch._ops import DispatchKey
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import SM70OrLater, tf32_off
from torch.testing._internal.common_cuda import tf32_off
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyCPU,
@@ -1226,33 +1226,6 @@ class DecompOneOffTests(TestCase):
for o_ref, o in zip(out_ref, out):
self.assertEqual(o_ref.dtype, o.dtype)
@onlyCUDA
@unittest.skipIf(not SM70OrLater, "triton")
def test_rms_norm_decomp_cuda(self, device):
@torch.compile
def rms_norm_sinh(a, b, c):
output = torch.nn.functional.rms_norm(a, b, c)
return torch.sinh(output)
normalized_shape_arg = (3, 3, 3)
input_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True)
weight_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True)
def forward_pass_fn():
return rms_norm_sinh(input_tensor, normalized_shape_arg, weight_tensor)
model_output, generated_codes = torch._inductor.utils.run_fw_bw_and_get_code(
forward_pass_fn
)
# check RMSNorm was fused with sinh
self.assertTrue(
"triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0]
)
self.assertTrue(
"triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1]
)
instantiate_device_type_tests(DecompOneOffTests, globals())

View File

@@ -1267,11 +1267,6 @@
mean: not_implemented("native_layer_norm_backward mean")
rstd: not_implemented("native_layer_norm_backward rstd")
- name: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor)
input, weight: "GradMode::is_enabled() || grads[1].defined() ? infinitely_differentiable_native_rms_norm_backward(grads[0], grads[1], input, normalized_shape, result1, weight, grad_input_mask) : (grads[0].defined() ? _fused_rms_norm_backward(grads[0], input, normalized_shape, result1, weight, grad_input_mask) : std::tuple<Tensor, Tensor>())"
result0: rms_norm_jvp(input_p, input_t, weight_p, weight_t, result1, normalized_shape)
result1: rms_norm_rstd_jvp(input_p, input_t, result1, normalized_shape)
- name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].device().is_xpu() ? grads[0] : grads[0].contiguous(grads[0].device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), input.device().is_xpu() ? input : input.contiguous(input.device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>())"
result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group)

View File

@@ -418,7 +418,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
aten.native_dropout_backward,
aten.native_group_norm_backward,
aten.native_layer_norm_backward,
aten._fused_rms_norm_backward,
aten.new_empty,
aten.new_full,
aten.new_ones,

View File

@@ -1743,81 +1743,6 @@ def native_layer_norm_backward_out(
return grad_input
@register_decomposition(aten._fused_rms_norm_backward.default)
def _fused_rms_norm_backward(
grad_out: Tensor,
input: Tensor,
normalized_shape: list[int],
rstd: Tensor,
weight: Optional[Tensor],
output_mask: list[bool],
) -> tuple[Optional[Tensor], Optional[Tensor]]:
input_shape = input.shape
input_ndim = input.dim()
computation_dtype = utils.get_computation_dtype(input.dtype)
grad_out_cast = grad_out.to(
computation_dtype, memory_format=torch.contiguous_format
)
input_cast = input.to(computation_dtype, memory_format=torch.contiguous_format)
weight_cast = (
weight.to(computation_dtype, memory_format=torch.contiguous_format)
if weight is not None
else None
)
assert grad_out_cast is not None
axis = input_ndim - len(normalized_shape)
inner_dims = input_shape[axis:]
outer_dims = input_shape[:axis]
inner_dim_indices: list[int] = []
outer_dim_indices: list[int] = []
for i in range(input_ndim):
if i >= axis:
inner_dim_indices.append(i)
else:
outer_dim_indices.append(i)
N = prod(inner_dims) # type: ignore[arg-type]
M = prod(outer_dims) # type: ignore[arg-type]
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):
return (
input.new_zeros(input_shape) if output_mask[0] else None,
input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
)
rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: ignore[union-attr]
if weight_cast is not None:
grad_x_hat = grad_out_cast * weight_cast
else:
grad_x_hat = grad_out_cast
d_input: Optional[Tensor] = None
d_weight: Optional[Tensor] = None
x_hat = input_cast * rstd
if output_mask[0]:
sum_val = torch.sum(x_hat * grad_x_hat, dim=inner_dim_indices, keepdim=True)
d_input = (grad_x_hat - (x_hat / N) * sum_val) * rstd
if output_mask[1] and weight_cast is not None:
d_weight_full_shape = grad_out_cast * x_hat
if len(outer_dim_indices) > 0:
d_weight = torch.sum(
d_weight_full_shape, dim=outer_dim_indices, keepdim=False
)
else:
d_weight = d_weight_full_shape
return (
_maybe_cast(d_input, input.dtype),
_maybe_cast(d_weight, input.dtype),
)
def native_batch_norm_helper(
input: Tensor,
weight: Optional[Tensor],
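In closed form, the _fused_rms_norm_backward decomposition removed above computes the standard RMSNorm gradients. Writing r for the saved rstd, \hat{x} = x \cdot r, N for the number of normalized elements per row, and g for grad_out:

\[
g_{\hat{x}} = g \odot \gamma \quad (\text{or simply } g \text{ when no weight is given})
\]
\[
\frac{\partial L}{\partial x}
  = r\Big(g_{\hat{x}} - \frac{\hat{x}}{N}\sum_{j \in \text{inner}} \hat{x}_j\,(g_{\hat{x}})_j\Big),
\qquad
\frac{\partial L}{\partial \gamma} = \sum_{\text{outer}} g \odot \hat{x}
\]

This is exactly what the d_input and d_weight branches evaluate after upcasting to the computation dtype.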

View File

@@ -5022,103 +5022,6 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};
}
std::tuple<Tensor, Tensor> infinitely_differentiable_native_rms_norm_backward(
const Tensor& dY,
const Tensor& drstd,
const Tensor& input,
IntArrayRef normalized_shape,
const Tensor& rstd,
const std::optional<Tensor>& weight_opt,
std::array<bool, 2> grad_input_mask) {
c10::MaybeOwned<at::Tensor> weight_maybe_owned =
at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;
const auto input_shape = input.sizes();
const auto input_ndim = input.dim();
const int normalized_ndim = normalized_shape.size();
const int axis = input_ndim - normalized_ndim;
int64_t N_rms = 1;
for (int i = 0; i < normalized_ndim; ++i) {
N_rms *= input_shape[axis + i];
}
Tensor dX;
Tensor dgamma;
std::vector<int64_t> rstd_view_shape = rstd.sizes().vec();
for (int i = 0;
i < std::max(static_cast<int>(normalized_ndim - rstd.dim()), 0);
++i) {
rstd_view_shape.push_back(1);
}
Tensor rstd_broadcast = rstd.view(rstd_view_shape);
Tensor rstd_pow3 = rstd_broadcast.pow(3);
Tensor grad_x_hat;
if (dY.defined()) {
if (weight.defined()) {
grad_x_hat = dY * weight;
} else {
grad_x_hat = dY;
}
}
if (grad_input_mask[0]) {
Tensor dX_from_dY_path;
Tensor dX_from_drstd_path;
std::vector<int64_t> inner_sum_dims;
inner_sum_dims.reserve(normalized_ndim);
for (int i = 0; i < normalized_ndim; ++i) {
inner_sum_dims.push_back(axis + i);
}
if (dY.defined() && grad_x_hat.defined()) {
Tensor sum_input_times_grad_x_hat =
sum(input * grad_x_hat, inner_sum_dims, /*keepdim=*/true);
dX_from_dY_path = rstd_broadcast * grad_x_hat -
(input * rstd_pow3 / static_cast<double>(N_rms)) *
sum_input_times_grad_x_hat;
}
if (drstd.defined()) {
Tensor drstd_broadcast = drstd.view(rstd_view_shape);
dX_from_drstd_path =
-(input * rstd_pow3 / static_cast<double>(N_rms)) * drstd_broadcast;
}
if (dX_from_dY_path.defined() && dX_from_drstd_path.defined()) {
dX = dX_from_dY_path + dX_from_drstd_path;
} else if (dX_from_dY_path.defined()) {
dX = dX_from_dY_path;
} else if (dX_from_drstd_path.defined()) {
dX = dX_from_drstd_path;
}
}
if (grad_input_mask[1] && weight.defined()) {
if (dY.defined()) {
Tensor x_hat = input * rstd_broadcast;
Tensor dgamma_full_shape = dY * x_hat;
if (axis > 0) {
std::vector<int64_t> outer_sum_dims;
outer_sum_dims.reserve(axis);
for (int i = 0; i < axis; ++i) {
outer_sum_dims.push_back(i);
}
dgamma = sum(dgamma_full_shape, outer_sum_dims, /*keepdim=*/false);
} else {
dgamma = dgamma_full_shape;
}
}
}
return std::make_tuple(dX, dgamma);
}
std::tuple<Tensor, Tensor, Tensor>
infinitely_differentiable_native_group_norm_backward(
const Tensor& dY,
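The rstd-cubed factors in infinitely_differentiable_native_rms_norm_backward above follow directly from differentiating the saved statistic:

\[
r = \Big(\tfrac{1}{N}\sum_i x_i^2 + \varepsilon\Big)^{-1/2}
\quad\Longrightarrow\quad
\frac{\partial r}{\partial x_i} = -\frac{x_i}{N}\,r^3
\]
\[
dX_{dY\text{-path}} = r\,g_{\hat{x}} - \frac{x\,r^3}{N}\sum_{j \in \text{inner}} x_j\,(g_{\hat{x}})_j,
\qquad
dX_{d\text{rstd-path}} = -\frac{x\,r^3}{N}\,d\text{rstd}
\]

which matches the two dX branches built above.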
@@ -6473,98 +6376,6 @@ Tensor layer_norm_jvp(
bias_t.defined() ? bias_t.view(view_size_affine) : bias_t);
}
Tensor rms_norm_jvp(
const Tensor& input_p,
const Tensor& input_t,
const Tensor& weight_p,
const Tensor& weight_t,
const Tensor& saved_rstd,
IntArrayRef normalized_shape) {
auto dims = std::vector<int64_t>{};
auto view_size = input_t.sizes().vec();
auto view_size_affine = input_t.sizes().vec();
int64_t numel = 1;
for (const auto i : c10::irange(view_size.size())) {
if (i < view_size.size() - normalized_shape.size()) {
view_size_affine[i] = 1;
} else {
numel *= input_t.size(static_cast<int64_t>(i));
view_size[i] = 1;
dims.push_back(static_cast<int64_t>(i));
}
}
auto rstd_p = saved_rstd.view(view_size);
Tensor rstd_t;
if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) ||
input_t._is_zerotensor()) {
rstd_t = -rstd_p.pow(3) * (input_t) * (input_p);
} else {
rstd_t = input_t * input_p;
rstd_t *= -rstd_p.pow(3);
}
rstd_t = rstd_t.sum(dims, true);
rstd_t /= numel;
Tensor result_t;
if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) ||
input_t._is_zerotensor()) {
result_t = (input_t)*rstd_p + (input_p)*rstd_t;
} else {
result_t = input_t * rstd_p;
auto temp = input_p * rstd_t;
result_t += temp;
}
std::optional<Tensor> result_p = std::nullopt;
if (weight_p.defined()) {
result_p = std::optional<Tensor>(input_p * rstd_p);
}
return _affine_jvp(
result_p,
result_t,
weight_p.defined() ? weight_p.view(view_size_affine) : weight_p,
weight_t.defined() ? weight_t.view(view_size_affine) : weight_t,
Tensor());
}
Tensor rms_norm_rstd_jvp(
const Tensor& input_p,
const Tensor& input_t,
const Tensor& saved_rstd,
IntArrayRef normalized_shape) {
auto dims = std::vector<int64_t>{};
auto view_size = input_t.sizes().vec();
auto view_size_affine = input_t.sizes().vec();
int64_t numel = 1;
for (const auto i : c10::irange(view_size.size())) {
if (i < view_size.size() - normalized_shape.size()) {
view_size_affine[i] = 1;
} else {
numel *= input_t.size(static_cast<int64_t>(i));
view_size[i] = 1;
dims.push_back(static_cast<int64_t>(i));
}
}
auto rstd_p = saved_rstd.view(view_size);
Tensor rstd_t;
if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) ||
input_t._is_zerotensor()) {
rstd_t = -rstd_p.pow(3) * (input_t) * (input_p);
} else {
rstd_t = input_t * input_p;
rstd_t *= -rstd_p.pow(3);
}
rstd_t = rstd_t.sum(dims, true);
rstd_t /= numel;
return rstd_t;
}
Tensor group_norm_jvp(
const Tensor& input_p,
const Tensor& input_t,
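In forward-mode terms, the two JVP helpers above (rms_norm_jvp and rms_norm_rstd_jvp) implement the tangents, with \dot{x} denoting the input tangent:

\[
\dot{r} = -\frac{r^{3}}{N}\sum_{j \in \text{inner}} x_j\,\dot{x}_j,
\qquad
\dot{y}_{\text{pre-affine}} = \dot{x}\,r + x\,\dot{r}
\]

with the optional weight folded in afterwards via _affine_jvp.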

View File

@@ -828,15 +828,6 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
c10::SymIntArrayRef normalized_shape,
std::array<bool, 3> output_mask);
std::tuple<Tensor, Tensor> infinitely_differentiable_native_rms_norm_backward(
const Tensor& dY,
const Tensor& drstd,
const Tensor& input,
IntArrayRef normalized_shape,
const Tensor& rstd,
const std::optional<Tensor>& weight_opt,
std::array<bool, 2> grad_input_mask);
std::tuple<Tensor, Tensor> householder_product_backward(
const Tensor& grad,
const Tensor& result,
@@ -976,20 +967,6 @@ Tensor layer_norm_jvp(
const Tensor& saved_invstd,
c10::SymIntArrayRef normalized_shape);
Tensor rms_norm_jvp(
const Tensor& input_p,
const Tensor& input_t,
const Tensor& weight_p,
const Tensor& weight_t,
const Tensor& saved_rstd,
IntArrayRef normalized_shape);
Tensor rms_norm_rstd_jvp(
const Tensor& input_p,
const Tensor& input_t,
const Tensor& saved_rstd,
IntArrayRef normalized_shape);
Tensor group_norm_jvp(
const Tensor& input_p,
const Tensor& input_t,

View File

@@ -820,7 +820,6 @@ def get_testing_overrides() -> dict[Callable, Callable]:
torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1,
torch.native_dropout: lambda input, p, train: -1,
torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
torch._fused_rms_norm: lambda input, normalized_shape, weight=None, eps=1e-05: -1,
torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1,
torch.native_channel_shuffle: lambda input, groups: -1,