[CPU] optimize Lp norm for 1-dimensional vector (#122143)

Fixes https://github.com/pytorch/pytorch/issues/120229

- Optimize the vector norm by simplifying the norm formula when the reduction is over a 1-dimensional (single-element) vector.
- For a 1-dimensional vector, the vector norm simplifies to `abs(x)` for `ord != 0` (and to `x != 0` for `ord == 0`). See below for proof.
- As a next step, matrix norm (`torch.linalg.matrix_norm`) can be optimized in the same way for 1 x 1 matrices.
- Additionally, this avoids overflow in the power `abs(x) ** p` for large `p` or `x` when reducing over a 1-dimensional vector; see the sketch after this list.
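
As a quick, hedged illustration of the simplification (independent of this PR's implementation; any version of `torch.linalg.vector_norm` should satisfy these identities up to rounding):

```python
import torch

x = torch.randn(2**18, 1)

# Reducing over a size-1 dimension: the Lp norm of a single element is just |x| for p != 0.
for ord in [float('inf'), -float('inf'), 1, -1, 2, -2]:
    norm = torch.linalg.vector_norm(x, ord, dim=-1)
    torch.testing.assert_close(norm, x.abs().squeeze(-1))

# ord == 0 counts non-zero entries, which for a single element is (x != 0).
norm0 = torch.linalg.vector_norm(x, 0, dim=-1)
torch.testing.assert_close(norm0, (x != 0).to(x.dtype).squeeze(-1))
```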

### Performance
Average latency (ms) of `torch.norm` and `torch.linalg.vector_norm` for
`torch.norm(torch.randn(2**18, 1), ord, -1)`
`torch.linalg.vector_norm(torch.randn(2**18, 1), ord, -1)`
Measured on Skylake, 1 socket, 28 physical cores per socket.

| **op** | **input shape** | **dim** | **ord** | **baseline (master) avg latency (ms)** | **optimized (7102f1ef372b248414d36cbd0c51a546b6b6a41a) avg latency (ms)** | **speedup ratio (baseline/optimized)** |
|---|---|---|---|---|---|---|
| torch.norm | (2**18, 1) | -1 | fro | 34.3755531 | 0.0125408 | 2741.094 |
| | | | inf | 34.0952635 | 0.0122237 | 2789.271 |
| | | | -inf | 34.3674493 | 0.0120759 | 2845.953 |
| | | | 0 | 34.1004515 | 0.0175261 | 1945.69 |
| | | | 1 | 34.1688442 | 0.0121593 | 2810.089 |
| | | | -1 | 33.949492 | 0.0120282 | 2822.487 |
| | | | 2 | 34.3669581 | 0.0120401 | 2854.366 |
| | | | -2 | 33.9252067 | 0.0121069 | 2802.139 |
| torch.linalg.vector_norm | (2**18, 1) | -1 | inf | 34.090879 | 0.0095105 | 3584.545 |
| | | | -inf | 34.3708754 | 0.0099111 | 3467.931 |
| | | | 0 | 34.0880775 | 0.0141716 | 2405.38 |
| | | | 1 | 34.1392851 | 0.0093174 | 3664.036 |
| | | | -1 | 33.925395 | 0.0092483 | 3668.302 |
| | | | 2 | 34.3854165 | 0.0092459 | 3719.002 |
| | | | -2 | 33.932972 | 0.0093007 | 3648.429 |
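
The numbers above can be reproduced with a simple wall-clock harness along the lines of the sketch below (warmup/iteration counts and thread pinning, e.g. `OMP_NUM_THREADS=28` plus `numactl` for a single socket, are assumptions; the exact harness is not part of this PR):

```python
import time
import torch

def avg_latency_ms(fn, warmup=10, iters=100):
    # Warm up, then report the average wall-clock time per call in milliseconds.
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters * 1e3

x = torch.randn(2**18, 1)
for ord in [float('inf'), -float('inf'), 0, 1, -1, 2, -2]:
    ms = avg_latency_ms(lambda: torch.linalg.vector_norm(x, ord, dim=-1))
    print(f"ord={ord}: {ms:.6f} ms")
```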

### Proof
<details>
<summary>For those interested :)</summary>

<img width="382" alt="1_dim_vector_norm_proof1" src="https://github.com/pytorch/pytorch/assets/93151422/59b1e00b-8fcd-47cb-877d-d31403b5195b">
<img width="432" alt="1_dim_vector_norm_proof2" src="https://github.com/pytorch/pytorch/assets/93151422/236bea15-2dd5-480b-9871-58b2e3b24322">

</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122143
Approved by: https://github.com/lezcano
Commit: 057892f4be (parent aa74a8b9e5)
Author: min-jean-cho, 2024-03-20 23:20:20 +00:00, committed by PyTorch MergeBot
3 changed files with 56 additions and 1 deletion


@@ -40,6 +40,7 @@
#include <ATen/ops/_unsafe_view.h>
#include <ATen/ops/_weight_int4pack_mm_native.h>
#include <ATen/ops/_weight_int8pack_mm_native.h>
#include <ATen/ops/abs.h>
#include <ATen/ops/addbmm_native.h>
#include <ATen/ops/addmm_native.h>
#include <ATen/ops/addr.h>
@@ -120,6 +121,7 @@
#include <ATen/ops/mul.h>
#include <ATen/ops/mv.h>
#include <ATen/ops/narrow.h>
#include <ATen/ops/ne.h>
#include <ATen/ops/norm.h>
#include <ATen/ops/nuclear_norm_native.h>
#include <ATen/ops/ones.h>
@@ -2815,6 +2817,36 @@ TORCH_IMPL_FUNC(linalg_vector_norm_out)(const Tensor& self, const Scalar& scalar
  // values larger than 10^53 (same for negative numbers), so that's fine.
  auto ord = scalar_ord.toDouble();
  auto dim = opt_dim.value_or(IntArrayRef{});
  auto size = self.sizes();
  auto ndim = self.dim();

  auto opt_dim_ = dim.vec();
  maybe_wrap_dims(opt_dim_, ndim);

  using Int = IntArrayRef::value_type;
  std::vector<Int> all_dim(ndim);
  std::iota(all_dim.begin(), all_dim.end(), 0);

  bool is_all_reduce = !opt_dim.has_value() || opt_dim.value().empty();
  auto reduce_dim = is_all_reduce ? all_dim : opt_dim_;

  bool is_reduce_over_1D_vector = true;
  for (auto i : reduce_dim) {
    if (size[i] != 1) {
      is_reduce_over_1D_vector = false;
      break;
    }
  }

  if (is_reduce_over_1D_vector) {
    if (ord != 0.0) {
      keepdim ? at::abs_outf(self, const_cast<Tensor&>(result)) : at::abs_outf(self.squeeze(reduce_dim), const_cast<Tensor&>(result));
    } else {
      keepdim ? at::ne_outf(self, 0, const_cast<Tensor&>(result)) : at::ne_outf(self.squeeze(reduce_dim), 0, const_cast<Tensor&>(result));
    }
    return;
  }

  // No need to handle opt_dtype explicitly as it is already encoded in the dtype of result
  // https://github.com/pytorch/pytorch/issues/52648
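
In Python terms, the new fast path is roughly the following (a simplified sketch of the C++ logic above; the function name, tuple-only `dim`, and the `None` fallback convention are illustrative, not PR code):

```python
import torch

def vector_norm_1d_fast_path(self, ord, dim, keepdim):
    # Resolve the reduced dimensions (an empty/None dim means reduce over everything).
    ndim = self.dim()
    reduce_dim = list(range(ndim)) if not dim else [d % ndim for d in dim]
    # The fast path only applies when every reduced dimension has size 1.
    if any(self.size(d) != 1 for d in reduce_dim):
        return None  # fall back to the regular reduction kernel
    src = self if keepdim else self.squeeze(tuple(reduce_dim))
    # ord != 0 reduces to |x|; ord == 0 counts non-zeros, i.e. (x != 0).
    return src.abs() if ord != 0 else (src != 0).to(src.dtype)

# Hypothetical usage, checked against the reference implementation:
x = torch.randn(3, 1, 2, 1)
out = vector_norm_1d_fast_path(x, 2, (1, 3), keepdim=False)
ref = torch.linalg.vector_norm(x, 2, (1, 3), keepdim=False)
torch.testing.assert_close(out, ref)
```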


@@ -631,7 +631,7 @@ class TestForeach(TestCase):
        scaler = torch.tensor([max_value]).sqrt().to(device=device, dtype=dtype)
        inputs = ([
            t * scaler for t in next(iter(op.sample_inputs(device, dtype, requries_grad=True, num_input_tensors=[N], low=1))).input
-        ],)
+        ][:-1],)
        # make sure that the min. of squared L2 norm value per tensor is greater than the max value of `dtype`.
        self.assertTrue(scaler * scaler * N > max_value)
        fn, ref_fn, *_ = self._get_funcs(op)


@@ -1631,6 +1631,29 @@ class TestLinalg(TestCase):
            result_n = np.linalg.norm(x_n, ord=ord)
            self.assertEqual(result, result_n, msg=msg)

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_vector_norm_reduce_over_1D_vector(self, device, dtype):
        input_sizes_and_dims = [
            ((6, 1), -1),
            ((3, 1, 2, 1), (1, 3)),
            ((1,), None),
        ]
        orders = [float('inf'), -float('inf'), 0, 1, -1, 2, -2]
        keepdims = [True, False]

        for input_size_and_dim, ord, keepdim in product(input_sizes_and_dims, orders, keepdims):
            input_size = input_size_and_dim[0]
            dim = input_size_and_dim[1]
            if type(dim) is tuple and ord == 0:
                # skip because np.linalg.norm raises 'ValueError: Invalid norm order for matrices.'
                continue
            input = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
            result = torch.linalg.vector_norm(input, ord, dim, keepdim)
            result_numpy = np.linalg.norm(input.cpu().numpy(), ord, dim, keepdim)
            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
            self.assertEqual(result, result_numpy, msg=msg)

    @skipCUDAIfNoMagmaAndNoCusolver
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double)