Add Half support for weight_norm on CPU (#148878)

Fixes #148867.
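For illustration, a minimal ATen/C++ sketch of the call this change enables (expository only, not part of the PR; assumes linking against libtorch). Before this change, Half inputs on CPU were excluded from the fused weight-norm path:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // v holds the direction, g the per-row magnitude; both Half, on CPU.
  auto v = at::randn({5, 4}, at::kHalf);
  auto g = at::randn({5, 1}, at::kHalf);
  // The fused CPU kernel now handles Half for dim == 0 or the last dim.
  auto w = at::_weight_norm(v, g, /*dim=*/0);
  std::cout << w.scalar_type() << "\n";  // Half
  return 0;
}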

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148878
Approved by: https://github.com/leslie-fang-intel, https://github.com/cyyever, https://github.com/albanD
Author: CaoE
Date: 2025-04-08 01:12:29 +00:00
Committed by: PyTorch MergeBot
Commit: d7f3cd0ac3 (parent: 5228986c39)

3 changed files with 46 additions and 37 deletions

aten/src/ATen/native/WeightNorm.cpp

@@ -53,8 +53,8 @@ std::tuple<Tensor,Tensor> weight_norm_cpu(
     int64_t dim) {
   auto w = at::empty_like(v, at::MemoryFormat::Contiguous);
 
-  // align with cuda behavior, keep norm in 'Float' when g is 'BFloat16'
-  const auto dtype = g.scalar_type() == at::ScalarType::BFloat16 ?
+  // align with cuda behavior, keep norm in 'Float' when g is 'BFloat16'/'Half'
+  const auto dtype = (g.scalar_type() == at::ScalarType::BFloat16 || g.scalar_type() == at::ScalarType::Half) ?
       at::ScalarType::Float : g.scalar_type();
   auto norm = at::empty_strided(g.sizes(), g.strides(), g.options().dtype(dtype));
   weight_norm_stub(kCPU, w, norm, v, g, dim);
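The float32 norm buffer matters even more for Half than for BFloat16: both have reduced precision, but Half also has a small range (max finite value 65504), so a squared-sum stored in Half overflows quickly. A quick ATen illustration (expository only, not part of the PR):

// ||v||^2 for 256 elements of value 16 is 256 * 16 * 16 = 65536, which
// exceeds Half's max finite value (65504), so a Half result becomes inf,
// while a float32 accumulator represents it exactly.
auto v = at::full({256}, 16.0, at::kHalf);
auto half_norm2  = v.pow(2).sum();                 // Half output: inf
auto float_norm2 = v.to(at::kFloat).pow(2).sum();  // 65536.f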
@@ -93,10 +93,7 @@ Tensor _weight_norm
   auto v = v_in.contiguous();
   auto g = g_in.contiguous();
 
-  auto has_half_dtype = v.scalar_type() == at::ScalarType::Half
-      || g.scalar_type() == at::ScalarType::Half;
-
-  bool can_use_fused = !has_half_dtype && ((dim == 0) || (dim == v.dim() - 1));
+  bool can_use_fused = (dim == 0) || (dim == v.dim() - 1);
 
   if (can_use_fused) {
     // weight_norm does not have a derivative defined for it, so this will route back through
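With the fused kernels now covering Half, the has_half_dtype escape hatch is no longer needed: Half tensors with dim == 0 or dim == v.dim() - 1 take the fused path, and only other dims fall back to the decomposed expression, which (paraphrasing the surrounding ATen source) is roughly:

// Non-fused fallback when dim is neither first nor last:
// return v * (g / at::norm_except_dim(v, 2, dim));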

aten/src/ATen/native/cpu/WeightNormKernel.cpp

@@ -48,7 +48,8 @@ void weight_norm_first_dim_kernel(
 }
 
 template <typename scalar_t>
-inline void sum_norm_per_row(
+inline std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, void>
+sum_norm_per_row(
     scalar_t* out_ptr,
     const scalar_t* v_ptr,
     int64_t size) {
@@ -61,16 +62,18 @@ inline void sum_norm_per_row(
       size);
 }
 
-inline void sum_norm_per_row(
+template <typename scalar_t>
+inline std::enable_if_t<is_reduced_floating_point_v<scalar_t>, void>
+sum_norm_per_row(
     float* out_ptr,
-    const BFloat16* v_ptr,
+    const scalar_t* v_ptr,
     int64_t size) {
-  using bVec = vec::Vectorized<BFloat16>;
+  using bVec = vec::Vectorized<scalar_t>;
   using fVec = vec::Vectorized<float>;
   int64_t d = 0;
   for (; d < size - (size % bVec::size()); d += bVec::size()) {
     bVec v_bvec = bVec::loadu(v_ptr + d);
-    auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec);
+    auto [v_fvec0, v_fvec1] = vec::convert_to_float<scalar_t>(v_bvec);
     fVec out_fvec0 = fVec::loadu(out_ptr + d) + v_fvec0 * v_fvec0;
     fVec out_fvec1 = fVec::loadu(out_ptr + d + fVec::size()) + v_fvec1 * v_fvec1;
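The change above converts the BFloat16-only overload into a template pair selected via SFINAE: std::enable_if_t keeps exactly one overload viable per scalar type, so float/double accumulate in their own type while Half and BFloat16 accumulate into float32. A self-contained sketch of the mechanism, using a stand-in type rather than ATen's actual definitions:

#include <cstdint>
#include <cstdio>
#include <type_traits>

// Stand-in for at::Half / at::BFloat16; a real type would pack 16 bits.
struct half_like { float value; };

template <typename T>
constexpr bool is_reduced_floating_point_v = std::is_same_v<T, half_like>;

// Viable only for full-precision types: accumulator shares the input type.
template <typename scalar_t>
std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, void>
sum_norm_per_row(scalar_t* out_ptr, const scalar_t* v_ptr, int64_t size) {
  for (int64_t d = 0; d < size; ++d) out_ptr[d] += v_ptr[d] * v_ptr[d];
}

// Viable only for reduced precision: accumulator is always float32.
template <typename scalar_t>
std::enable_if_t<is_reduced_floating_point_v<scalar_t>, void>
sum_norm_per_row(float* out_ptr, const scalar_t* v_ptr, int64_t size) {
  for (int64_t d = 0; d < size; ++d) {
    float v = v_ptr[d].value;
    out_ptr[d] += v * v;
  }
}

int main() {
  half_like v[2] = {{3.f}, {4.f}};
  float out[2] = {0.f, 0.f};
  sum_norm_per_row(out, v, 2);             // resolves to the reduced-precision overload
  std::printf("%g %g\n", out[0], out[1]);  // prints: 9 16
}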
@@ -84,7 +87,8 @@ inline void sum_norm_per_row(
 }
 
 template <typename scalar_t>
-inline void apply_norm_per_row(
+inline std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, void>
+apply_norm_per_row(
     scalar_t* w_ptr,
     const scalar_t* v_ptr,
     const scalar_t* a_ptr,
@@ -98,21 +102,23 @@ inline void apply_norm_per_row(
       size);
 }
 
-inline void apply_norm_per_row(
-    BFloat16* w_ptr,
-    const BFloat16* v_ptr,
+template <typename scalar_t>
+inline std::enable_if_t<is_reduced_floating_point_v<scalar_t>, void>
+apply_norm_per_row(
+    scalar_t* w_ptr,
+    const scalar_t* v_ptr,
     const float* a_ptr,
     int64_t size) {
-  using bVec = vec::Vectorized<BFloat16>;
+  using bVec = vec::Vectorized<scalar_t>;
   using fVec = vec::Vectorized<float>;
   int64_t d = 0;
   for (; d < size - (size % bVec::size()); d += bVec::size()) {
     bVec v_bvec = bVec::loadu(v_ptr + d);
-    auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec);
+    auto [v_fvec0, v_fvec1] = vec::convert_to_float<scalar_t>(v_bvec);
     fVec w_fvec0 = fVec::loadu(a_ptr + d) * v_fvec0;
     fVec w_fvec1 = fVec::loadu(a_ptr + d + fVec::size()) * v_fvec1;
-    bVec w_bvec = convert_float_bfloat16(w_fvec0, w_fvec1);
+    bVec w_bvec = vec::convert_from_float<scalar_t>(w_fvec0, w_fvec1);
     w_bvec.store(w_ptr + d);
   }
   for(; d < size; ++d) {
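A note on the two-vector pattern in these loops: Vectorized<Half> and Vectorized<BFloat16> carry twice as many lanes as Vectorized<float> (e.g. 16 vs. 8 under AVX2), so each reduced-precision load widens into a pair of float vectors, and vec::convert_to_float / vec::convert_from_float generalize the former convert_bfloat16_float / convert_float_bfloat16 helpers over both 16-bit types. A sketch of the round trip, assuming PyTorch's internal vec headers:

#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>

void widen_demo(const at::Half* src, float* dst) {
  using namespace at::vec;
  // One 16-bit vector load...
  Vectorized<at::Half> h = Vectorized<at::Half>::loadu(src);
  // ...becomes two float32 vectors covering the same elements.
  auto [f0, f1] = convert_to_float<at::Half>(h);
  f0.store(dst);
  f1.store(dst + Vectorized<float>::size());
}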
@@ -222,7 +228,8 @@ void weight_norm_backward_first_dim_kernel(
 }
 
 template <typename scalar_t>
-inline void sum_product_per_row(
+inline std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, void>
+sum_product_per_row(
     scalar_t* out_ptr,
     const scalar_t* grad_w_ptr,
     const scalar_t* v_ptr,
@@ -237,19 +244,21 @@ inline void sum_product_per_row(
       size);
 }
 
-inline void sum_product_per_row(
+template <typename scalar_t>
+inline std::enable_if_t<is_reduced_floating_point_v<scalar_t>, void>
+sum_product_per_row(
     float* out_ptr,
-    const BFloat16* grad_w_ptr,
-    const BFloat16* v_ptr,
+    const scalar_t* grad_w_ptr,
+    const scalar_t* v_ptr,
     int64_t size) {
-  using bVec = vec::Vectorized<BFloat16>;
+  using bVec = vec::Vectorized<scalar_t>;
   using fVec = vec::Vectorized<float>;
   int64_t d = 0;
   for (; d < size - (size % bVec::size()); d += bVec::size()) {
     bVec grad_w_bvec = bVec::loadu(grad_w_ptr + d);
-    auto [grad_w_fvec0, grad_w_fvec1] = convert_bfloat16_float(grad_w_bvec);
+    auto [grad_w_fvec0, grad_w_fvec1] = vec::convert_to_float<scalar_t>(grad_w_bvec);
     bVec v_bvec = bVec::loadu(v_ptr + d);
-    auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec);
+    auto [v_fvec0, v_fvec1] = vec::convert_to_float<scalar_t>(v_bvec);
     fVec out_fvec0 = fVec::loadu(out_ptr + d) + grad_w_fvec0 * v_fvec0;
     fVec out_fvec1 = fVec::loadu(out_ptr + d + fVec::size()) + grad_w_fvec1 * v_fvec1;
@@ -264,7 +273,8 @@ inline void sum_product_per_row(
 }
 
 template <typename scalar_t>
-inline void apply_per_row_backward(
+inline std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, void>
+apply_per_row_backward(
    scalar_t* grad_v_ptr,
    const scalar_t* grad_w_ptr,
    const scalar_t* v_ptr,
@@ -282,26 +292,28 @@ inline void apply_per_row_backward(
       size);
 }
 
-inline void apply_per_row_backward(
-    BFloat16* grad_v_ptr,
-    const BFloat16* grad_w_ptr,
-    const BFloat16* v_ptr,
+template <typename scalar_t>
+inline std::enable_if_t<is_reduced_floating_point_v<scalar_t>, void>
+apply_per_row_backward(
+    scalar_t* grad_v_ptr,
+    const scalar_t* grad_w_ptr,
+    const scalar_t* v_ptr,
     const float* a_ptr,
     const float* b_ptr,
     int64_t size) {
-  using bVec = vec::Vectorized<BFloat16>;
+  using bVec = vec::Vectorized<scalar_t>;
   using fVec = vec::Vectorized<float>;
   int64_t d = 0;
   for (; d < size - (size % bVec::size()); d += bVec::size()) {
     bVec grad_w_bvec = bVec::loadu(grad_w_ptr + d);
-    auto [grad_w_fvec0, grad_w_fvec1] = convert_bfloat16_float(grad_w_bvec);
+    auto [grad_w_fvec0, grad_w_fvec1] = vec::convert_to_float<scalar_t>(grad_w_bvec);
     bVec v_bvec = bVec::loadu(v_ptr + d);
-    auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec);
+    auto [v_fvec0, v_fvec1] = vec::convert_to_float<scalar_t>(v_bvec);
     fVec grad_v_fvec0 = fVec::loadu(a_ptr + d) * grad_w_fvec0 - fVec::loadu(b_ptr + d) * v_fvec0;
     fVec grad_v_fvec1 = fVec::loadu(a_ptr + d + fVec::size()) * grad_w_fvec1
         - fVec::loadu(b_ptr + d + fVec::size()) * v_fvec1;
-    bVec grad_v_bvec = convert_float_bfloat16(grad_v_fvec0, grad_v_fvec1);
+    bVec grad_v_bvec = vec::convert_from_float<scalar_t>(grad_v_fvec0, grad_v_fvec1);
     grad_v_bvec.store(grad_v_ptr + d);
   }
   for(; d < size; ++d) {
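For reference on what a_ptr and b_ptr hold here: with w = g * v / ||v|| computed per row, the standard weight-norm gradients are

  grad_v = (g / ||v||) * grad_w - (g * <grad_w, v> / ||v||^3) * v
  grad_g = <grad_w, v> / ||v||

which is consistent with the fused multiply-subtract above being grad_v = a * grad_w - b * v, with a = g / ||v|| and b = g * <grad_w, v> / ||v||^3 precomputed per row; the per-row dot product <grad_w, v> is what sum_product_per_row accumulates.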
@@ -395,7 +407,7 @@ void weight_norm_kernel(
     int64_t dim) {
   TORCH_INTERNAL_ASSERT(dim == 0 || dim == v.dim() - 1,
       "fused kernels can only be applied for first or last dim");
-  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, v.scalar_type(),
+  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, v.scalar_type(),
       "weight_norm_kernel", [&]() {
     using accscalar_t = at::opmath_type<scalar_t>;
     if (dim == 0) {
@@ -420,7 +432,7 @@ void weight_norm_backward_kernel(
     int64_t dim) {
   TORCH_INTERNAL_ASSERT(dim == 0 || dim == saved_v.dim() - 1,
       "fused kernels can only be applied for first or last dim");
-  AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, saved_v.scalar_type(),
+  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, saved_v.scalar_type(),
       "weight_norm_backward_kernel", [&]() {
     using accscalar_t = at::opmath_type<scalar_t>;
     if (dim == 0) {
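The dispatch macro is what actually instantiates the kernels for Half: AT_DISPATCH_FLOATING_TYPES_AND2 expands to a switch over the runtime dtype that runs the lambda with scalar_t bound per case, and at::opmath_type<scalar_t> maps both Half and BFloat16 to float, so accscalar_t stays float32. Roughly (simplified; the real macro also handles error reporting):

switch (v.scalar_type()) {
  case ScalarType::Float:    { using scalar_t = float;        /* lambda body */ break; }
  case ScalarType::Double:   { using scalar_t = double;       /* lambda body */ break; }
  case ScalarType::BFloat16: { using scalar_t = at::BFloat16; /* lambda body */ break; }
  case ScalarType::Half:     { using scalar_t = at::Half;     /* lambda body */ break; }
  default: TORCH_CHECK(false, "\"weight_norm_kernel\" not implemented for this dtype");
}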

test/test_nn.py

@@ -1812,7 +1812,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
     def test_weight_norm(self):
-        for dtype in [torch.float, torch.bfloat16]:
+        for dtype in [torch.float, torch.bfloat16, torch.float16]:
             input = torch.randn(3, 4, dtype=dtype)
             m = nn.Linear(4, 5).to(dtype=dtype)
             expected_output = m(input)