[CPU] optimize Lp norm for 1-dimensional vector (#122143)
Fixes https://github.com/pytorch/pytorch/issues/120229

- Optimize vector norm by simplifying the vector norm formula for a 1-dimensional vector.
- The vector norm formula for a 1-dimensional vector simplifies to `abs(x)`. See below for a proof.
- Next step: we can similarly optimize matrix norm (`torch.linalg.matrix_norm`) for a 1 x 1 matrix.
- Additionally, this avoids overflow in the power `abs(x) ** p` for large `p` or `x` in the 1-dimensional-vector case.

### Performance

Avg latency (ms) of `torch.norm` and `torch.linalg.vector_norm` for `torch.norm(torch.randn(2**18, 1), ord, -1)` and `torch.linalg.vector_norm(torch.randn(2**18, 1), ord, -1)`.

Tested on 28 physical cores/socket, 1 socket on Skylake.

| op | input shape | dim | ord | baseline (master), avg latency (ms) | optimized (7102f1ef372b248414d36cbd0c51a546b6b6a41a), avg latency (ms) | speedup ratio (baseline/optimized) |
|----|-------------|-----|-----|-------------------------------------|------------------------------------------------------------------------|-------------------------------------|
| torch.norm | (2**18, 1) | -1 | fro | 34.3755531 | 0.0125408 | 2741.094 |
| | | | inf | 34.0952635 | 0.0122237 | 2789.271 |
| | | | -inf | 34.3674493 | 0.0120759 | 2845.953 |
| | | | 0 | 34.1004515 | 0.0175261 | 1945.69 |
| | | | 1 | 34.1688442 | 0.0121593 | 2810.089 |
| | | | -1 | 33.949492 | 0.0120282 | 2822.487 |
| | | | 2 | 34.3669581 | 0.0120401 | 2854.366 |
| | | | -2 | 33.9252067 | 0.0121069 | 2802.139 |
| torch.linalg.vector_norm | (2**18, 1) | -1 | inf | 34.090879 | 0.0095105 | 3584.545 |
| | | | -inf | 34.3708754 | 0.0099111 | 3467.931 |
| | | | 0 | 34.0880775 | 0.0141716 | 2405.38 |
| | | | 1 | 34.1392851 | 0.0093174 | 3664.036 |
| | | | -1 | 33.925395 | 0.0092483 | 3668.302 |
| | | | 2 | 34.3854165 | 0.0092459 | 3719.002 |
| | | | -2 | 33.932972 | 0.0093007 | 3648.429 |

### Proof

<details>
<summary>For those interested :)</summary>
<img width="382" alt="1_dim_vector_norm_proof1" src="https://github.com/pytorch/pytorch/assets/93151422/59b1e00b-8fcd-47cb-877d-d31403b5195b">
<img width="432" alt="1_dim_vector_norm_proof2" src="https://github.com/pytorch/pytorch/assets/93151422/236bea15-2dd5-480b-9871-58b2e3b24322">
</details>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122143
Approved by: https://github.com/lezcano
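As a quick illustration of the simplification described above (this snippet is an editorial sketch, not part of the PR): for a vector with a single element `x`, every nonzero-order norm reduces to `abs(x)`, while `ord=0` counts nonzero entries, i.e. `(x != 0)`. The last two lines also show why the shortcut sidesteps overflow: the naive `(abs(x) ** p).sum() ** (1/p)` can overflow to `inf` even though the exact norm is just `abs(x)`.

```python
import torch

# Each row of x is a 1-dimensional (single-element) vector.
x = torch.randn(8, 1, dtype=torch.float64)

for ord in [float('inf'), -float('inf'), 1, -1, 2, -2]:
    # For a one-element vector, ||x||_ord == |x| for every nonzero ord.
    torch.testing.assert_close(
        torch.linalg.vector_norm(x, ord, dim=-1), x.abs().squeeze(-1))

# ord=0 counts the nonzero entries, i.e. (x != 0) cast to the result dtype.
torch.testing.assert_close(
    torch.linalg.vector_norm(x, 0, dim=-1), (x != 0).to(x.dtype).squeeze(-1))

# Overflow illustration: the naive power formula blows up, abs(x) does not.
big = torch.tensor([1e200], dtype=torch.float64)
naive = (big.abs() ** 3).sum() ** (1 / 3)   # inf: 1e600 overflows float64
print(naive, torch.linalg.vector_norm(big, 3))  # inf vs. 1e+200 with the fast path from this PR
```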
This commit is contained in:
parent
aa74a8b9e5
commit
057892f4be
@@ -40,6 +40,7 @@
 #include <ATen/ops/_unsafe_view.h>
 #include <ATen/ops/_weight_int4pack_mm_native.h>
 #include <ATen/ops/_weight_int8pack_mm_native.h>
+#include <ATen/ops/abs.h>
 #include <ATen/ops/addbmm_native.h>
 #include <ATen/ops/addmm_native.h>
 #include <ATen/ops/addr.h>
@@ -120,6 +121,7 @@
 #include <ATen/ops/mul.h>
 #include <ATen/ops/mv.h>
 #include <ATen/ops/narrow.h>
+#include <ATen/ops/ne.h>
 #include <ATen/ops/norm.h>
 #include <ATen/ops/nuclear_norm_native.h>
 #include <ATen/ops/ones.h>
@@ -2815,6 +2817,36 @@ TORCH_IMPL_FUNC(linalg_vector_norm_out)(const Tensor& self, const Scalar& scalar
   // values larger than 10^53 (same for negative numbers), so that's fine.
   auto ord = scalar_ord.toDouble();
   auto dim = opt_dim.value_or(IntArrayRef{});
+  auto size = self.sizes();
+  auto ndim = self.dim();
+
+  auto opt_dim_ = dim.vec();
+  maybe_wrap_dims(opt_dim_, ndim);
+
+  using Int = IntArrayRef::value_type;
+  std::vector<Int> all_dim(ndim);
+  std::iota(all_dim.begin(), all_dim.end(), 0);
+
+  bool is_all_reduce = !opt_dim.has_value() || opt_dim.value().empty();
+  auto reduce_dim = is_all_reduce ? all_dim : opt_dim_;
+
+  bool is_reduce_over_1D_vector = true;
+  for (auto i : reduce_dim) {
+    if (size[i] != 1){
+      is_reduce_over_1D_vector = false;
+      break;
+    }
+  }
+
+  if (is_reduce_over_1D_vector) {
+    if (ord != 0.0) {
+      keepdim ? at::abs_outf(self, const_cast<Tensor&>(result)) : at::abs_outf(self.squeeze(reduce_dim), const_cast<Tensor&>(result));
+    } else {
+      keepdim ? at::ne_outf(self, 0, const_cast<Tensor&>(result)) : at::ne_outf(self.squeeze(reduce_dim), 0, const_cast<Tensor&>(result));
+    }
+    return;
+  }
+
   // No need to handle opt_dtype explicitly as it is already encoded in the dtype of result

   // https://github.com/pytorch/pytorch/issues/52648
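For readers who prefer Python, here is a rough sketch of what the added C++ fast path does (an editorial paraphrase; the helper name and the pure-Python dim handling are illustrative, not PyTorch API): when every reduced dimension has size 1, the reduction degenerates to `abs` (or `!= 0` for `ord == 0`), optionally squeezing the reduced dimensions when `keepdim=False`.

```python
import torch

def vector_norm_1d_fast_path(x, ord, dim=None, keepdim=False):
    # Illustrative Python rendering of the C++ fast path above; the name
    # and the dim normalization below are not part of PyTorch's API.
    ndim = x.dim()
    if dim is None or (not isinstance(dim, int) and len(dim) == 0):
        reduce_dim = list(range(ndim))        # empty/None dim => all-reduce
    elif isinstance(dim, int):
        reduce_dim = [dim % ndim]
    else:
        reduce_dim = [d % ndim for d in dim]  # wrap negative dims

    if all(x.size(d) == 1 for d in reduce_dim):
        src = x if keepdim else x.squeeze(tuple(reduce_dim))
        # Nonzero ord: the norm of a single element is its absolute value.
        # ord == 0: the norm counts nonzero entries, i.e. (x != 0).
        return src.abs() if ord != 0 else (src != 0).to(x.dtype)

    # Not a 1-dimensional-vector reduction: fall back to the regular kernel.
    return torch.linalg.vector_norm(x, ord, dim, keepdim)

# Example: reducing dims 1 and 3 of a (3, 1, 2, 1) tensor hits the fast path.
t = torch.randn(3, 1, 2, 1)
torch.testing.assert_close(
    vector_norm_1d_fast_path(t, 2, dim=(1, 3)),
    torch.linalg.vector_norm(t, 2, dim=(1, 3)))
```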
@@ -631,7 +631,7 @@ class TestForeach(TestCase):
         scaler = torch.tensor([max_value]).sqrt().to(device=device, dtype=dtype)
         inputs = ([
             t * scaler for t in next(iter(op.sample_inputs(device, dtype, requries_grad=True, num_input_tensors=[N], low=1))).input
-        ],)
+        ][:-1],)
         # make sure that the min. of squared L2 norm value per tensor is greater than the max value of `dtype`.
         self.assertTrue(scaler * scaler * N > max_value)
         fn, ref_fn, *_ = self._get_funcs(op)
@@ -1631,6 +1631,29 @@ class TestLinalg(TestCase):
             result_n = np.linalg.norm(x_n, ord=ord)
             self.assertEqual(result, result_n, msg=msg)

+    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
+    def test_vector_norm_reduce_over_1D_vector(self, device, dtype):
+        input_sizes_and_dims = [
+            ((6, 1), -1),
+            ((3, 1, 2, 1), (1, 3)),
+            ((1,), None),
+        ]
+        orders = [float('inf'), -float('inf'), 0, 1, -1, 2, -2]
+        keepdims = [True, False]
+
+        for input_size_and_dim, ord, keepdim in product(input_sizes_and_dims, orders, keepdims):
+            input_size = input_size_and_dim[0]
+            dim = input_size_and_dim[1]
+            if type(dim) is tuple and ord == 0:
+                # skip because np.linalg.norm raises 'ValueError: Invalid norm order for matrices.'
+                continue
+            input = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9)
+            result = torch.linalg.vector_norm(input, ord, dim, keepdim)
+            result_numpy = np.linalg.norm(input.cpu().numpy(), ord, dim, keepdim)
+
+            msg = f'input.size()={input.size()}, ord={ord}, dim={dim}, keepdim={keepdim}, dtype={dtype}'
+            self.assertEqual(result, result_numpy, msg=msg)
+
     @skipCUDAIfNoMagmaAndNoCusolver
     @skipCPUIfNoLapack
     @dtypes(torch.float, torch.double)