diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index 6dd9833f917..f638730a18f 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -4065,8 +4065,8 @@ class CPUReproTests(TestCase):
         compiled_m = torch.compile(mod, dynamic=dynamic)
         actual, code = run_and_get_cpp_code(compiled_m, x)
         self.assertEqual(expected, actual)
-        # 2 generated kernels (one for var_mean, the other for result)
-        check_metrics_vec_kernel_count(2)
+        # 3 generated kernels (first one for var_mean, last two for result)
+        check_metrics_vec_kernel_count(3)
 
         # check loop split optimization
         if fmt == torch.channels_last:
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 94452ab4b79..696a7f649f8 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -3187,27 +3187,44 @@ def native_group_norm(
         + f"but got input of shape {input.shape} and num_groups = {num_groups}",
     )
 
+    computation_dtype = utils.get_computation_dtype(input.dtype)
+    input_acc = _maybe_convert_to_dtype(input, computation_dtype)
     # num_channels / num_groups and flattened inner dimension are the reduction axes
     reduction_dims = [2, 3]
     input_reshaped = torch.reshape(
-        input,
+        input_acc,
         [batch_size, num_groups, num_channels // num_groups, flattened_inner_size],
     )
-    out, mean, rstd = _normalize(input_reshaped, reduction_dims, eps)
-    out = out.view(input.shape)
-
-    broadcast_dims = [0] + list(range(2, input.ndim))
-    unsqueeze_bias = None
-    if bias is not None:
-        unsqueeze_bias = _unsqueeze_multiple(bias, broadcast_dims)
-    unsqueeze_weight = None
+    reduction_dims = utils.canonicalize_dims(input_reshaped.ndim, reduction_dims)
+    biased_var, mean = torch.var_mean(
+        input_reshaped, dim=reduction_dims, unbiased=False, keepdim=True
+    )
+    rstd = torch.rsqrt(biased_var + eps)
     if weight is not None:
-        unsqueeze_weight = _unsqueeze_multiple(weight, broadcast_dims)
+        weight_reshaped = torch.reshape(
+            weight, [1, num_groups, num_channels // num_groups, 1]
+        )
+        w = rstd * weight_reshaped
+        b = -mean * w
+    else:
+        w = rstd
+        b = -mean * rstd
+    if bias is not None:
+        bias_reshaped = torch.reshape(
+            bias, [1, num_groups, num_channels // num_groups, 1]
+        )
+        b = b + bias_reshaped
-    if unsqueeze_weight is not None:
-        out = out * unsqueeze_weight
-    if unsqueeze_bias is not None:
-        out = out + unsqueeze_bias
+    if input.device.type == "cpu" and weight is not None:
+        w = w.contiguous().as_strided([batch_size, num_channels], [num_channels, 1])
+        b = b.contiguous().as_strided([batch_size, num_channels], [num_channels, 1])
+        broadcast_dims = list(range(2, input.ndim))
+        unsqueeze_w = _unsqueeze_multiple(w, broadcast_dims)
+        unsqueeze_b = _unsqueeze_multiple(b, broadcast_dims)
+        out = input_acc * unsqueeze_w + unsqueeze_b
+    else:
+        out = input_reshaped * w + b
+        out = out.view(input.shape)
 
     out = _maybe_convert_to_dtype(out, input.dtype)  # type: ignore[assignment]
     mean = _maybe_convert_to_dtype(mean, input.dtype)  # type: ignore[assignment]