Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Revert "Fused RMSNorm implementation (#153666)"
This reverts commit e1aee86646. Reverted https://github.com/pytorch/pytorch/pull/153666 on behalf of https://github.com/davidberard98 due to build failures on the main branch [GH job link](https://github.com/pytorch/pytorch/actions/runs/16007148842/job/45156382001) [HUD commit link](e1aee86646) ([comment](https://github.com/pytorch/pytorch/pull/153666#issuecomment-3025146176))
This commit is contained in:
parent 3a5677a380
commit 6401d1d53d

@@ -158,7 +158,6 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
  OP_DECOMPOSE(kron);
  OP_DECOMPOSE(l1_loss);
  m.impl("layer_norm", native::layer_norm_symint);
  m.impl("_fused_rms_norm", native::rms_norm_composite);
  OP_DECOMPOSE2(ldexp, Tensor);
  OP_DECOMPOSE2(less_equal, Tensor );
  OP_DECOMPOSE2(less, Tensor );

File diff suppressed because it is too large
@@ -261,11 +261,30 @@ std::tuple<Tensor, Tensor, Tensor> math_native_layer_norm(
  return outputs;
}

std::tuple<Tensor, Tensor> rms_norm_composite(
Tensor rms_norm_symint(
    const Tensor& input,
    IntArrayRef normalized_shape,
    c10::SymIntArrayRef normalized_shape,
    const std::optional<Tensor>& weight_opt /* optional */,
    std::optional<double> eps) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  _check_rms_norm_inputs_symint(input, normalized_shape, weight);

#ifdef USE_MPS
  if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) {
    const Tensor weight = weight_opt.value();
    const bool any_nested = input.is_nested() || weight.is_nested();
    const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad();
    const bool is_input_fp = isFloatingType(input.scalar_type());
    const bool is_weight_fp = isFloatingType(weight.scalar_type());

    if (!(GradMode::is_enabled() && any_inputs_require_grad) && !any_nested && is_input_fp && is_weight_fp) {
      auto eps_val = eps.value_or(std::numeric_limits<double>::epsilon());
      return at::_fused_rms_norm(input.contiguous(), normalized_shape.size(), weight.contiguous(), eps_val);
    }
  }
#endif

  std::vector<int64_t> dims_to_reduce;
  for (const auto i : c10::irange(normalized_shape.size())) {

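For reference, the removed rms_norm_composite path and the restored rms_norm_symint fast path both compute root-mean-square normalization over the trailing normalized_shape dimensions. With N the number of normalized elements, \gamma the optional weight, and \varepsilon falling back to machine epsilon when eps is unset, the operation is

    y = \gamma \odot \frac{x}{\sqrt{\tfrac{1}{N}\sum_{i=1}^{N} x_i^{2} + \varepsilon}}
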
@@ -302,60 +321,10 @@ std::tuple<Tensor, Tensor> rms_norm_composite(
      upcasted_result = upcasted_result.mul(weight_opt.value());
    }

    // if nested do not make contiguous
    if(input.is_nested() || (weight_opt.has_value() && weight_opt.value().is_nested())){
      return std::make_tuple(upcasted_result, rqrst_input);
    }

    if(input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast || input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast3d){
      return std::make_tuple(upcasted_result, rqrst_input);
    }

    return std::make_tuple(upcasted_result.contiguous(), rqrst_input.contiguous());
    return upcasted_result;
  });
  return std::make_tuple(
    std::get<0>(result).type_as(input), // Cast normalized result to original input type
    std::get<1>(result) // rsqrt_val
  );

  return result.type_as(input);

}


Tensor rms_norm_symint(
    const Tensor& input,
    c10::SymIntArrayRef normalized_shape,
    const std::optional<Tensor>& weight_opt /* optional */,
    const std::optional<double> eps) {

  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  _check_rms_norm_inputs_symint(input, normalized_shape, weight);

  // composite fallback for channels last
  if(input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast || input.suggest_memory_format() == c10::MemoryFormat::ChannelsLast3d){
    return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
  }

  // composite fallback for complex datatypes
  if(input.is_complex()){
    return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
  }

#ifdef USE_MPS
  if (input.device().type() == DeviceType::MPS && weight_opt.has_value()) {
    const Tensor weight = weight_opt.value();
    const bool any_inputs_require_grad = input.requires_grad() || weight.requires_grad();

    if (!(GradMode::is_enabled() && any_inputs_require_grad)) {
      return std::get<0>(at::_fused_rms_norm(input.contiguous(), IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
    }
  }

  if (input.device().type() == DeviceType::MPS){
    return std::get<0>(rms_norm_composite(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
  }
#endif

  return std::get<0>(at::_fused_rms_norm(input, IntArrayRef(reinterpret_cast<const int64_t*>(normalized_shape.data()), normalized_shape.size()), weight_opt, eps));
}

} // namespace at::native

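In eager PyTorch terms, the composite path shown above boils down to the following sketch. rms_norm_reference is a hypothetical helper written for illustration, not code from this diff, and eps is passed explicitly to sidestep the differing defaults between versions.

import torch

def rms_norm_reference(x, normalized_shape, weight=None, eps=1e-6):
    # Reduce over the trailing len(normalized_shape) dimensions, mirroring
    # the dims_to_reduce loop in rms_norm_composite.
    dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
    # Upcast reduced-precision inputs, as the composite kernel does.
    upcast = x.float() if x.dtype in (torch.float16, torch.bfloat16) else x
    rstd = torch.rsqrt(upcast.pow(2).mean(dim=dims, keepdim=True) + eps)
    out = upcast * rstd
    if weight is not None:
        out = out * weight
    return out.type_as(x)

x = torch.randn(2, 4, 8)
w = torch.randn(8)
torch.testing.assert_close(
    rms_norm_reference(x, (8,), w, eps=1e-6),
    torch.nn.functional.rms_norm(x, (8,), w, eps=1e-6),
)
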
@@ -106,12 +106,6 @@ void layer_norm_cpu_out(
    int64_t M,
    int64_t N);

std::tuple<Tensor, Tensor> rms_norm_composite(
    const Tensor& input,
    IntArrayRef normalized_shape,
    const std::optional<Tensor>& weight_opt /* optional */,
    std::optional<double> eps);

Tensor rms_norm_symint(
    const Tensor& input,
    c10::SymIntArrayRef normalized_shape,

@@ -19,14 +19,7 @@ static auto& lib = MetalShaderLibrary::getBundledLibrary();
#include <ATen/native/mps/RMSNorm_metallib.h>
#endif

std::tuple<Tensor, Tensor> _fused_rms_norm_mps(const Tensor& input,
                                               IntArrayRef normalized_shape,
                                               const std::optional<Tensor>& weight_opt,
                                               const std::optional<double> eps) {
  const Tensor weight = weight_opt.value().contiguous();
  const int64_t normalized_ndim = normalized_shape.size();
  auto eps_val = eps.value_or(std::numeric_limits<double>::epsilon());

Tensor _fused_rms_norm_mps(const Tensor& input, const int64_t normalized_ndim, const Tensor& weight, const double eps) {
  TORCH_CHECK(input.is_contiguous() && weight.is_contiguous(), "Expected contiguous input and weight tensors");
  auto output = at::empty_like(input);
  const auto input_shape = input.sizes();

@@ -48,7 +41,7 @@ std::tuple<Tensor, Tensor> _fused_rms_norm_mps(const Tensor& input,
    const std::string kernel = fmt::format("{}_{}", name, scalarToMetalTypeString(output));
    id<MTLComputePipelineState> rms_norm_pso = lib.getPipelineStateForFunc(kernel);
    [computeEncoder setComputePipelineState:rms_norm_pso];
    mtl_setArgs(computeEncoder, input, weight, output, eps_val, N, 1);
    mtl_setArgs(computeEncoder, input, weight, output, eps, N, 1);

    const auto maxThreadsPerGroup = static_cast<size_t>([rms_norm_pso maxTotalThreadsPerThreadgroup]);
    size_t threadgroup_size = maxThreadsPerGroup;

@@ -65,7 +58,7 @@ std::tuple<Tensor, Tensor> _fused_rms_norm_mps(const Tensor& input,
    }
  });

  return std::make_tuple(output, Tensor());
  return output;
}

} // namespace at::native

@@ -3315,15 +3315,9 @@
  dispatch:
    CompositeImplicitAutograd: rms_norm_symint

- func: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor)
- func: _fused_rms_norm(Tensor input, int normalized_shape_ndim, Tensor weight, float eps) -> Tensor
  dispatch:
    CUDA: _fused_rms_norm_cuda
    MPS: _fused_rms_norm_mps
    CompositeImplicitAutograd: rms_norm_composite

- func: _fused_rms_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor rstd, Tensor? weight, bool[2] output_mask) -> (Tensor, Tensor)
  dispatch:
    CUDA: _fused_rms_norm_backward_cuda

- func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor
  variants: function, method

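The schema change above is the core of what gets reverted: the PR's variant took the full normalized_shape, an optional weight, and an optional eps and returned (output, rstd), while the restored variant takes only the number of normalized dimensions plus a required weight and eps and returns a single tensor. As a hedged way to see which variant a given build registers (aten::_fused_rms_norm is an internal operator and may not exist in every build):

import torch

# Prints the registered schema for the internal op in the current build,
# e.g. which argument list and return type _fused_rms_norm was compiled with.
print(torch.ops.aten._fused_rms_norm.default._schema)
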
@@ -374,6 +374,7 @@ aten::_fused_adamw_.tensor_lr
aten::_fused_moving_avg_obs_fq_helper
aten::_fused_moving_avg_obs_fq_helper.out
aten::_fused_moving_avg_obs_fq_helper_functional
aten::_fused_rms_norm
aten::_fused_sdp_choice
aten::_fused_sgd
aten::_fused_sgd.out

@@ -15,7 +15,7 @@ from torch._dispatch.python import enable_python_dispatcher
from torch._export.utils import _is_cia_op
from torch._ops import DispatchKey
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import SM70OrLater, tf32_off
from torch.testing._internal.common_cuda import tf32_off
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCPU,

@@ -1226,33 +1226,6 @@ class DecompOneOffTests(TestCase):
        for o_ref, o in zip(out_ref, out):
            self.assertEqual(o_ref.dtype, o.dtype)

    @onlyCUDA
    @unittest.skipIf(not SM70OrLater, "triton")
    def test_rms_norm_decomp_cuda(self, device):
        @torch.compile
        def rms_norm_sinh(a, b, c):
            output = torch.nn.functional.rms_norm(a, b, c)
            return torch.sinh(output)

        normalized_shape_arg = (3, 3, 3)
        input_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True)
        weight_tensor = torch.randn(3, 3, 3, device=device, requires_grad=True)

        def forward_pass_fn():
            return rms_norm_sinh(input_tensor, normalized_shape_arg, weight_tensor)

        model_output, generated_codes = torch._inductor.utils.run_fw_bw_and_get_code(
            forward_pass_fn
        )

        # check RMSNorm was fused with sinh
        self.assertTrue(
            "triton_per_fused_add_mean_mul_pow_rsqrt_sinh" in generated_codes[0]
        )
        self.assertTrue(
            "triton_per_fused__fused_rms_norm_backward_cosh_mul" in generated_codes[1]
        )


instantiate_device_type_tests(DecompOneOffTests, globals())

@@ -1267,11 +1267,6 @@
  mean: not_implemented("native_layer_norm_backward mean")
  rstd: not_implemented("native_layer_norm_backward rstd")

- name: _fused_rms_norm(Tensor input, int[] normalized_shape, Tensor? weight, float? eps) -> (Tensor, Tensor)
  input, weight: "GradMode::is_enabled() || grads[1].defined() ? infinitely_differentiable_native_rms_norm_backward(grads[0], grads[1], input, normalized_shape, result1, weight, grad_input_mask) : (grads[0].defined() ? _fused_rms_norm_backward(grads[0], input, normalized_shape, result1, weight, grad_input_mask) : std::tuple<Tensor, Tensor>())"
  result0: rms_norm_jvp(input_p, input_t, weight_p, weight_t, result1, normalized_shape)
  result1: rms_norm_rstd_jvp(input_p, input_t, result1, normalized_shape)

- name: native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)
  input, weight, bias: "GradMode::is_enabled() || grads[1].defined() || grads[2].defined() ? infinitely_differentiable_native_group_norm_backward(grads[0], grads[1], grads[2], input, result1, result2, weight, N, C, HxW, group, eps, grad_input_mask) : (grads[0].defined() ? native_group_norm_backward_symint(grads[0].device().is_xpu() ? grads[0] : grads[0].contiguous(grads[0].device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), input.device().is_xpu() ? input : input.contiguous(input.device().is_cpu() ? input.suggest_memory_format() : c10::MemoryFormat::Contiguous), result1, result2, weight, N, C, HxW, group, grad_input_mask) : std::tuple<Tensor, Tensor, Tensor>())"
  result0: group_norm_jvp(input_p, input_t, weight_p, weight_t, bias_p, bias_t, result1, result2, group)

@@ -418,7 +418,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
        aten.native_dropout_backward,
        aten.native_group_norm_backward,
        aten.native_layer_norm_backward,
        aten._fused_rms_norm_backward,
        aten.new_empty,
        aten.new_full,
        aten.new_ones,

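This is the table that routes aten._fused_rms_norm_backward into the core ATen decompositions; the revert drops the entry again. A small sketch for checking whether a given build still registers it (torch._decomp is an internal module, so treat this as illustrative rather than a stable API):

import torch
from torch._decomp import core_aten_decompositions

# True if a decomposition for the backward op is registered in this build.
decomps = core_aten_decompositions()
print(torch.ops.aten._fused_rms_norm_backward.default in decomps)
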
@@ -1743,81 +1743,6 @@ def native_layer_norm_backward_out(
    return grad_input


@register_decomposition(aten._fused_rms_norm_backward.default)
def _fused_rms_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    normalized_shape: list[int],
    rstd: Tensor,
    weight: Optional[Tensor],
    output_mask: list[bool],
) -> tuple[Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_ndim = input.dim()
    computation_dtype = utils.get_computation_dtype(input.dtype)

    grad_out_cast = grad_out.to(
        computation_dtype, memory_format=torch.contiguous_format
    )
    input_cast = input.to(computation_dtype, memory_format=torch.contiguous_format)
    weight_cast = (
        weight.to(computation_dtype, memory_format=torch.contiguous_format)
        if weight is not None
        else None
    )
    assert grad_out_cast is not None

    axis = input_ndim - len(normalized_shape)
    inner_dims = input_shape[axis:]
    outer_dims = input_shape[:axis]
    inner_dim_indices: list[int] = []
    outer_dim_indices: list[int] = []
    for i in range(input_ndim):
        if i >= axis:
            inner_dim_indices.append(i)
        else:
            outer_dim_indices.append(i)

    N = prod(inner_dims)  # type: ignore[arg-type]
    M = prod(outer_dims)  # type: ignore[arg-type]
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0):
        return (
            input.new_zeros(input_shape) if output_mask[0] else None,
            input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
        )

    rstd = _unsqueeze_to_dim(rstd, input_cast.dim())  # type: ignore[union-attr]
    if weight_cast is not None:
        grad_x_hat = grad_out_cast * weight_cast
    else:
        grad_x_hat = grad_out_cast

    d_input: Optional[Tensor] = None
    d_weight: Optional[Tensor] = None

    x_hat = input_cast * rstd

    if output_mask[0]:
        sum_val = torch.sum(x_hat * grad_x_hat, dim=inner_dim_indices, keepdim=True)
        d_input = (grad_x_hat - (x_hat / N) * sum_val) * rstd

    if output_mask[1] and weight_cast is not None:
        d_weight_full_shape = grad_out_cast * x_hat
        if len(outer_dim_indices) > 0:
            d_weight = torch.sum(
                d_weight_full_shape, dim=outer_dim_indices, keepdim=False
            )
        else:
            d_weight = d_weight_full_shape

    return (
        _maybe_cast(d_input, input.dtype),
        _maybe_cast(d_weight, input.dtype),
    )


def native_batch_norm_helper(
    input: Tensor,
    weight: Optional[Tensor],

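The removed decomposition computes the standard RMSNorm gradients: with x_hat = input * rstd and g_hat = grad_out * weight, d_input = (g_hat - x_hat * mean(x_hat * g_hat)) * rstd over the normalized dimensions, and d_weight sums grad_out * x_hat over the outer dimensions. A short sketch (hypothetical names, double precision, last-dim normalization only) that checks these formulas against autograd:

import torch

torch.manual_seed(0)
x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
w = torch.randn(8, dtype=torch.float64, requires_grad=True)
g = torch.randn(4, 8, dtype=torch.float64)
eps = 1e-6

# Forward pass written out by hand: y = x * rstd * w.
rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
y = x * rstd * w
dx_auto, dw_auto = torch.autograd.grad(y, (x, w), g)

# Manual gradients following the decomposition above.
x_hat = (x * rstd).detach()
g_hat = g * w.detach()
dx_manual = (g_hat - x_hat * (x_hat * g_hat).mean(dim=-1, keepdim=True)) * rstd.detach()
dw_manual = (g * x_hat).sum(dim=0)

torch.testing.assert_close(dx_auto, dx_manual)
torch.testing.assert_close(dw_auto, dw_manual)
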
@@ -5022,103 +5022,6 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
  return std::tuple<Tensor, Tensor, Tensor>{gI, gG, ggO};
}

std::tuple<Tensor, Tensor> infinitely_differentiable_native_rms_norm_backward(
    const Tensor& dY,
    const Tensor& drstd,
    const Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& rstd,
    const std::optional<Tensor>& weight_opt,
    std::array<bool, 2> grad_input_mask) {
  c10::MaybeOwned<at::Tensor> weight_maybe_owned =
      at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  const auto input_shape = input.sizes();
  const auto input_ndim = input.dim();
  const int normalized_ndim = normalized_shape.size();
  const int axis = input_ndim - normalized_ndim;

  int64_t N_rms = 1;
  for (int i = 0; i < normalized_ndim; ++i) {
    N_rms *= input_shape[axis + i];
  }

  Tensor dX;
  Tensor dgamma;

  std::vector<int64_t> rstd_view_shape = rstd.sizes().vec();
  for (int i = 0;
       i < std::max(static_cast<int>(normalized_ndim - rstd.dim()), 0);
       ++i) {
    rstd_view_shape.push_back(1);
  }
  Tensor rstd_broadcast = rstd.view(rstd_view_shape);
  Tensor rstd_pow3 = rstd_broadcast.pow(3);
  Tensor grad_x_hat;

  if (dY.defined()) {
    if (weight.defined()) {
      grad_x_hat = dY * weight;
    } else {
      grad_x_hat = dY;
    }
  }

  if (grad_input_mask[0]) {
    Tensor dX_from_dY_path;
    Tensor dX_from_drstd_path;

    std::vector<int64_t> inner_sum_dims;
    inner_sum_dims.reserve(normalized_ndim);
    for (int i = 0; i < normalized_ndim; ++i) {
      inner_sum_dims.push_back(axis + i);
    }

    if (dY.defined() && grad_x_hat.defined()) {
      Tensor sum_input_times_grad_x_hat =
          sum(input * grad_x_hat, inner_sum_dims, /*keepdim=*/true);
      dX_from_dY_path = rstd_broadcast * grad_x_hat -
          (input * rstd_pow3 / static_cast<double>(N_rms)) *
          sum_input_times_grad_x_hat;
    }

    if (drstd.defined()) {
      Tensor drstd_broadcast = drstd.view(rstd_view_shape);
      dX_from_drstd_path =
          -(input * rstd_pow3 / static_cast<double>(N_rms)) * drstd_broadcast;
    }

    if (dX_from_dY_path.defined() && dX_from_drstd_path.defined()) {
      dX = dX_from_dY_path + dX_from_drstd_path;
    } else if (dX_from_dY_path.defined()) {
      dX = dX_from_dY_path;
    } else if (dX_from_drstd_path.defined()) {
      dX = dX_from_drstd_path;
    }
  }

  if (grad_input_mask[1] && weight.defined()) {
    if (dY.defined()) {
      Tensor x_hat = input * rstd_broadcast;
      Tensor dgamma_full_shape = dY * x_hat;

      if (axis > 0) {
        std::vector<int64_t> outer_sum_dims;
        outer_sum_dims.reserve(axis);
        for (int i = 0; i < axis; ++i) {
          outer_sum_dims.push_back(i);
        }
        dgamma = sum(dgamma_full_shape, outer_sum_dims, /*keepdim=*/false);
      } else {
        dgamma = dgamma_full_shape;
      }
    }
  }

  return std::make_tuple(dX, dgamma);
}

std::tuple<Tensor, Tensor, Tensor>
infinitely_differentiable_native_group_norm_backward(
    const Tensor& dY,

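In the notation of the function above (r the broadcast rstd, N_rms the number of normalized elements, \hat g = dY \odot \gamma), the gradients it assembles read roughly as

    \mathrm{d}X = r\,\hat g \;-\; \frac{x\,r^{3}}{N_{\mathrm{rms}}} \sum_{\mathrm{inner}} x\,\hat g \;-\; \frac{x\,r^{3}}{N_{\mathrm{rms}}}\,\mathrm{d}rstd,
    \qquad
    \mathrm{d}\gamma = \sum_{\mathrm{outer}} dY \odot (x\,r),

where the drstd term only contributes when a gradient flows into the saved rstd (the double-backward case).
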
@@ -6473,98 +6376,6 @@ Tensor layer_norm_jvp(
      bias_t.defined() ? bias_t.view(view_size_affine) : bias_t);
}

Tensor rms_norm_jvp(
    const Tensor& input_p,
    const Tensor& input_t,
    const Tensor& weight_p,
    const Tensor& weight_t,
    const Tensor& saved_rstd,
    IntArrayRef normalized_shape) {
  auto dims = std::vector<int64_t>{};
  auto view_size = input_t.sizes().vec();
  auto view_size_affine = input_t.sizes().vec();

  int64_t numel = 1;
  for (const auto i : c10::irange(view_size.size())) {
    if (i < view_size.size() - normalized_shape.size()) {
      view_size_affine[i] = 1;
    } else {
      numel *= input_t.size(static_cast<int64_t>(i));
      view_size[i] = 1;
      dims.push_back(static_cast<int64_t>(i));
    }
  }

  auto rstd_p = saved_rstd.view(view_size);

  Tensor rstd_t;
  if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) ||
      input_t._is_zerotensor()) {
    rstd_t = -rstd_p.pow(3) * (input_t) * (input_p);
  } else {
    rstd_t = input_t * input_p;
    rstd_t *= -rstd_p.pow(3);
  }
  rstd_t = rstd_t.sum(dims, true);
  rstd_t /= numel;

  Tensor result_t;
  if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) ||
      input_t._is_zerotensor()) {
    result_t = (input_t)*rstd_p + (input_p)*rstd_t;
  } else {
    result_t = input_t * rstd_p;
    auto temp = input_p * rstd_t;
    result_t += temp;
  }

  std::optional<Tensor> result_p = std::nullopt;
  if (weight_p.defined()) {
    result_p = std::optional<Tensor>(input_p * rstd_p);
  }

  return _affine_jvp(
      result_p,
      result_t,
      weight_p.defined() ? weight_p.view(view_size_affine) : weight_p,
      weight_t.defined() ? weight_t.view(view_size_affine) : weight_t,
      Tensor());
}

Tensor rms_norm_rstd_jvp(
    const Tensor& input_p,
    const Tensor& input_t,
    const Tensor& saved_rstd,
    IntArrayRef normalized_shape) {
  auto dims = std::vector<int64_t>{};
  auto view_size = input_t.sizes().vec();
  auto view_size_affine = input_t.sizes().vec();

  int64_t numel = 1;
  for (const auto i : c10::irange(view_size.size())) {
    if (i < view_size.size() - normalized_shape.size()) {
      view_size_affine[i] = 1;
    } else {
      numel *= input_t.size(static_cast<int64_t>(i));
      view_size[i] = 1;
      dims.push_back(static_cast<int64_t>(i));
    }
  }

  auto rstd_p = saved_rstd.view(view_size);
  Tensor rstd_t;
  if (areAnyTensorSubclassLike({input_t, input_p, rstd_p}) ||
      input_t._is_zerotensor()) {
    rstd_t = -rstd_p.pow(3) * (input_t) * (input_p);
  } else {
    rstd_t = input_t * input_p;
    rstd_t *= -rstd_p.pow(3);
  }
  rstd_t = rstd_t.sum(dims, true);
  rstd_t /= numel;
  return rstd_t;
}

Tensor group_norm_jvp(
    const Tensor& input_p,
    const Tensor& input_t,

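Both removed JVP helpers share the tangent of the saved rstd; writing \dot x for input_t, x for input_p, and r for the broadcast rstd, the computation above corresponds to

    \dot r = -\frac{r^{3}}{N} \sum_{\mathrm{inner}} x\,\dot x,
    \qquad
    \dot{\hat x} = \dot x\, r + x\, \dot r,

with rms_norm_jvp then folding in the weight tangent via _affine_jvp.
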
@@ -828,15 +828,6 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_double_backward(
    c10::SymIntArrayRef normalized_shape,
    std::array<bool, 3> output_mask);

std::tuple<Tensor, Tensor> infinitely_differentiable_native_rms_norm_backward(
    const Tensor& dY,
    const Tensor& drstd,
    const Tensor& input,
    IntArrayRef normalized_shape,
    const Tensor& rstd,
    const std::optional<Tensor>& weight_opt,
    std::array<bool, 2> grad_input_mask);

std::tuple<Tensor, Tensor> householder_product_backward(
    const Tensor& grad,
    const Tensor& result,

@@ -976,20 +967,6 @@ Tensor layer_norm_jvp(
    const Tensor& saved_invstd,
    c10::SymIntArrayRef normalized_shape);

Tensor rms_norm_jvp(
    const Tensor& input_p,
    const Tensor& input_t,
    const Tensor& weight_p,
    const Tensor& weight_t,
    const Tensor& saved_rstd,
    IntArrayRef normalized_shape);

Tensor rms_norm_rstd_jvp(
    const Tensor& input_p,
    const Tensor& input_t,
    const Tensor& saved_rstd,
    IntArrayRef normalized_shape);

Tensor group_norm_jvp(
    const Tensor& input_p,
    const Tensor& input_t,

@@ -820,7 +820,6 @@ def get_testing_overrides() -> dict[Callable, Callable]:
        torch._native_batch_norm_legit: lambda input, weight, bias, training, momentum, eps: -1,
        torch.native_dropout: lambda input, p, train: -1,
        torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
        torch._fused_rms_norm: lambda input, normalized_shape, weight=None, eps=1e-05: -1,
        torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
        torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1,
        torch.native_channel_shuffle: lambda input, groups: -1,