mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add Half support for weight_norm on CPU (#148878)
Fixes #148867. Pull Request resolved: https://github.com/pytorch/pytorch/pull/148878 Approved by: https://github.com/leslie-fang-intel, https://github.com/cyyever, https://github.com/albanD
This commit is contained in:
parent
5228986c39
commit
d7f3cd0ac3
|
|
@ -53,8 +53,8 @@ std::tuple<Tensor,Tensor> weight_norm_cpu(
|
|||
int64_t dim) {
|
||||
auto w = at::empty_like(v, at::MemoryFormat::Contiguous);
|
||||
|
||||
// align with cuda behavior, keep norm in 'Float' when g is 'BFloat16'
|
||||
const auto dtype = g.scalar_type() == at::ScalarType::BFloat16 ?
|
||||
// align with cuda behavior, keep norm in 'Float' when g is 'BFloat16'/'Half'
|
||||
const auto dtype = (g.scalar_type() == at::ScalarType::BFloat16 || g.scalar_type() == at::ScalarType::Half) ?
|
||||
at::ScalarType::Float : g.scalar_type();
|
||||
auto norm = at::empty_strided(g.sizes(), g.strides(), g.options().dtype(dtype));
|
||||
weight_norm_stub(kCPU, w, norm, v, g, dim);
|
||||
|
|
@ -93,10 +93,7 @@ Tensor _weight_norm
|
|||
auto v = v_in.contiguous();
|
||||
auto g = g_in.contiguous();
|
||||
|
||||
auto has_half_dtype = v.scalar_type() == at::ScalarType::Half
|
||||
|| g.scalar_type() == at::ScalarType::Half;
|
||||
|
||||
bool can_use_fused = !has_half_dtype && ((dim == 0) || (dim == v.dim() - 1));
|
||||
bool can_use_fused = (dim == 0) || (dim == v.dim() - 1);
|
||||
|
||||
if (can_use_fused) {
|
||||
// weight_norm does not have a derivative defined for it, so this will route back through
|
||||
|
|
|
|||
|
|
@ -48,7 +48,8 @@ void weight_norm_first_dim_kernel(
|
|||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void sum_norm_per_row(
|
||||
inline std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, void>
|
||||
sum_norm_per_row(
|
||||
scalar_t* out_ptr,
|
||||
const scalar_t* v_ptr,
|
||||
int64_t size) {
|
||||
|
|
@ -61,16 +62,18 @@ inline void sum_norm_per_row(
|
|||
size);
|
||||
}
|
||||
|
||||
inline void sum_norm_per_row(
|
||||
template <typename scalar_t>
|
||||
inline std::enable_if_t<is_reduced_floating_point_v<scalar_t>, void>
|
||||
sum_norm_per_row(
|
||||
float* out_ptr,
|
||||
const BFloat16* v_ptr,
|
||||
const scalar_t* v_ptr,
|
||||
int64_t size) {
|
||||
using bVec = vec::Vectorized<BFloat16>;
|
||||
using bVec = vec::Vectorized<scalar_t>;
|
||||
using fVec = vec::Vectorized<float>;
|
||||
int64_t d = 0;
|
||||
for (; d < size - (size % bVec::size()); d += bVec::size()) {
|
||||
bVec v_bvec = bVec::loadu(v_ptr + d);
|
||||
auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec);
|
||||
auto [v_fvec0, v_fvec1] = vec::convert_to_float<scalar_t>(v_bvec);
|
||||
|
||||
fVec out_fvec0 = fVec::loadu(out_ptr + d) + v_fvec0 * v_fvec0;
|
||||
fVec out_fvec1 = fVec::loadu(out_ptr + d + fVec::size()) + v_fvec1 * v_fvec1;
|
||||
|
|
@ -84,7 +87,8 @@ inline void sum_norm_per_row(
|
|||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void apply_norm_per_row(
|
||||
inline std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, void>
|
||||
apply_norm_per_row(
|
||||
scalar_t* w_ptr,
|
||||
const scalar_t* v_ptr,
|
||||
const scalar_t* a_ptr,
|
||||
|
|
@ -98,21 +102,23 @@ inline void apply_norm_per_row(
|
|||
size);
|
||||
}
|
||||
|
||||
inline void apply_norm_per_row(
|
||||
BFloat16* w_ptr,
|
||||
const BFloat16* v_ptr,
|
||||
template <typename scalar_t>
|
||||
inline std::enable_if_t<is_reduced_floating_point_v<scalar_t>, void>
|
||||
apply_norm_per_row(
|
||||
scalar_t* w_ptr,
|
||||
const scalar_t* v_ptr,
|
||||
const float* a_ptr,
|
||||
int64_t size) {
|
||||
using bVec = vec::Vectorized<BFloat16>;
|
||||
using bVec = vec::Vectorized<scalar_t>;
|
||||
using fVec = vec::Vectorized<float>;
|
||||
int64_t d = 0;
|
||||
for (; d < size - (size % bVec::size()); d += bVec::size()) {
|
||||
bVec v_bvec = bVec::loadu(v_ptr + d);
|
||||
auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec);
|
||||
auto [v_fvec0, v_fvec1] = vec::convert_to_float<scalar_t>(v_bvec);
|
||||
|
||||
fVec w_fvec0 = fVec::loadu(a_ptr + d) * v_fvec0;
|
||||
fVec w_fvec1 = fVec::loadu(a_ptr + d + fVec::size()) * v_fvec1;
|
||||
bVec w_bvec = convert_float_bfloat16(w_fvec0, w_fvec1);
|
||||
bVec w_bvec = vec::convert_from_float<scalar_t>(w_fvec0, w_fvec1);
|
||||
w_bvec.store(w_ptr + d);
|
||||
}
|
||||
for(; d < size; ++d) {
|
||||
|
|
@ -222,7 +228,8 @@ void weight_norm_backward_first_dim_kernel(
|
|||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void sum_product_per_row(
|
||||
inline std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, void>
|
||||
sum_product_per_row(
|
||||
scalar_t* out_ptr,
|
||||
const scalar_t* grad_w_ptr,
|
||||
const scalar_t* v_ptr,
|
||||
|
|
@ -237,19 +244,21 @@ inline void sum_product_per_row(
|
|||
size);
|
||||
}
|
||||
|
||||
inline void sum_product_per_row(
|
||||
template <typename scalar_t>
|
||||
inline std::enable_if_t<is_reduced_floating_point_v<scalar_t>, void>
|
||||
sum_product_per_row(
|
||||
float* out_ptr,
|
||||
const BFloat16* grad_w_ptr,
|
||||
const BFloat16* v_ptr,
|
||||
const scalar_t* grad_w_ptr,
|
||||
const scalar_t* v_ptr,
|
||||
int64_t size) {
|
||||
using bVec = vec::Vectorized<BFloat16>;
|
||||
using bVec = vec::Vectorized<scalar_t>;
|
||||
using fVec = vec::Vectorized<float>;
|
||||
int64_t d = 0;
|
||||
for (; d < size - (size % bVec::size()); d += bVec::size()) {
|
||||
bVec grad_w_bvec = bVec::loadu(grad_w_ptr + d);
|
||||
auto [grad_w_fvec0, grad_w_fvec1] = convert_bfloat16_float(grad_w_bvec);
|
||||
auto [grad_w_fvec0, grad_w_fvec1] = vec::convert_to_float<scalar_t>(grad_w_bvec);
|
||||
bVec v_bvec = bVec::loadu(v_ptr + d);
|
||||
auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec);
|
||||
auto [v_fvec0, v_fvec1] = vec::convert_to_float<scalar_t>(v_bvec);
|
||||
|
||||
fVec out_fvec0 = fVec::loadu(out_ptr + d) + grad_w_fvec0 * v_fvec0;
|
||||
fVec out_fvec1 = fVec::loadu(out_ptr + d + fVec::size()) + grad_w_fvec1 * v_fvec1;
|
||||
|
|
@ -264,7 +273,8 @@ inline void sum_product_per_row(
|
|||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void apply_per_row_backward(
|
||||
inline std::enable_if_t<!is_reduced_floating_point_v<scalar_t>, void>
|
||||
apply_per_row_backward(
|
||||
scalar_t* grad_v_ptr,
|
||||
const scalar_t* grad_w_ptr,
|
||||
const scalar_t* v_ptr,
|
||||
|
|
@ -282,26 +292,28 @@ inline void apply_per_row_backward(
|
|||
size);
|
||||
}
|
||||
|
||||
inline void apply_per_row_backward(
|
||||
BFloat16* grad_v_ptr,
|
||||
const BFloat16* grad_w_ptr,
|
||||
const BFloat16* v_ptr,
|
||||
template <typename scalar_t>
|
||||
inline std::enable_if_t<is_reduced_floating_point_v<scalar_t>, void>
|
||||
apply_per_row_backward(
|
||||
scalar_t* grad_v_ptr,
|
||||
const scalar_t* grad_w_ptr,
|
||||
const scalar_t* v_ptr,
|
||||
const float* a_ptr,
|
||||
const float* b_ptr,
|
||||
int64_t size) {
|
||||
using bVec = vec::Vectorized<BFloat16>;
|
||||
using bVec = vec::Vectorized<scalar_t>;
|
||||
using fVec = vec::Vectorized<float>;
|
||||
int64_t d = 0;
|
||||
for (; d < size - (size % bVec::size()); d += bVec::size()) {
|
||||
bVec grad_w_bvec = bVec::loadu(grad_w_ptr + d);
|
||||
auto [grad_w_fvec0, grad_w_fvec1] = convert_bfloat16_float(grad_w_bvec);
|
||||
auto [grad_w_fvec0, grad_w_fvec1] = vec::convert_to_float<scalar_t>(grad_w_bvec);
|
||||
bVec v_bvec = bVec::loadu(v_ptr + d);
|
||||
auto [v_fvec0, v_fvec1] = convert_bfloat16_float(v_bvec);
|
||||
auto [v_fvec0, v_fvec1] = vec::convert_to_float<scalar_t>(v_bvec);
|
||||
|
||||
fVec grad_v_fvec0 = fVec::loadu(a_ptr + d) * grad_w_fvec0 - fVec::loadu(b_ptr + d) * v_fvec0;
|
||||
fVec grad_v_fvec1 = fVec::loadu(a_ptr + d + fVec::size()) * grad_w_fvec1
|
||||
- fVec::loadu(b_ptr + d + fVec::size()) * v_fvec1;
|
||||
bVec grad_v_bvec = convert_float_bfloat16(grad_v_fvec0, grad_v_fvec1);
|
||||
bVec grad_v_bvec = vec::convert_from_float<scalar_t>(grad_v_fvec0, grad_v_fvec1);
|
||||
grad_v_bvec.store(grad_v_ptr + d);
|
||||
}
|
||||
for(; d < size; ++d) {
|
||||
|
|
@ -395,7 +407,7 @@ void weight_norm_kernel(
|
|||
int64_t dim) {
|
||||
TORCH_INTERNAL_ASSERT(dim == 0 || dim == v.dim() - 1,
|
||||
"fused kernels can only be applied for first or last dim");
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, v.scalar_type(),
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, v.scalar_type(),
|
||||
"weight_norm_kernel", [&]() {
|
||||
using accscalar_t = at::opmath_type<scalar_t>;
|
||||
if (dim == 0) {
|
||||
|
|
@ -420,7 +432,7 @@ void weight_norm_backward_kernel(
|
|||
int64_t dim) {
|
||||
TORCH_INTERNAL_ASSERT(dim == 0 || dim == saved_v.dim() - 1,
|
||||
"fused kernels can only be applied for first or last dim");
|
||||
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, saved_v.scalar_type(),
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, saved_v.scalar_type(),
|
||||
"weight_norm_backward_kernel", [&]() {
|
||||
using accscalar_t = at::opmath_type<scalar_t>;
|
||||
if (dim == 0) {
|
||||
|
|
|
|||
|
|
@ -1812,7 +1812,7 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
|
|||
|
||||
|
||||
def test_weight_norm(self):
|
||||
for dtype in [torch.float, torch.bfloat16]:
|
||||
for dtype in [torch.float, torch.bfloat16, torch.float16]:
|
||||
input = torch.randn(3, 4, dtype=dtype)
|
||||
m = nn.Linear(4, 5).to(dtype=dtype)
|
||||
expected_output = m(input)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user