Speed up an integer to the power of a positive integer on CPU (#26020)
Summary:
Integer scalar exponents are currently always cast to double. This commit avoids that cast when the tensor is also integral and the scalar exponent is positive, which speeds up the common integer-power case.
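For illustration, a hedged sketch of the calls this affects (semantics per the kernel comments in the diff below; only the internal computation changes, not results or dtypes):
```python
import torch

a = torch.arange(10, dtype=torch.int32)

# Integral tensor with a positive integer exponent: now computed in the
# tensor's own integer type instead of round-tripping through double.
print(a.pow(3))

# Integral tensor with a float exponent: still computed via double, so
# e.g. torch.tensor([4]).pow(0.5) remains tensor([2]) with an integer dtype.
print(torch.tensor([4]).pow(0.5))
```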
Benchmark (Debian Buster, g++ 8, Intel(R) Xeon(R) E-2136 CPU @ 3.30GHz, Debug build, Turbo turned off):
```python
import timeit
for n, t in [(1000, 13000),
             (10_000, 1300)]:
    for e in (2, 3, 4):
        for dtype in ('torch.int16', 'torch.int32', 'torch.int64'):
            print(f'a.pow({e}) (a.numel() == {n}) for {t} times')
            print(f'dtype {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(
                f'a.pow({e})',
                setup=f'import torch; a = torch.arange({n}, device="cpu", dtype={dtype})',
                number=t))
```
Before:
```
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times 1.6958350749996498
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times 0.7989626339999631
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times 0.7973162800003593
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times 1.8660746679997828
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times 0.8101709959996697
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times 0.8135280149999744
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times 5.010833072999958
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times 4.801007671999741
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times 3.963344578000033
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times 1.6216251330001796
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times 0.5672429639998882
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times 0.5544572270000572
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times 1.656308512999658
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times 1.502670819999821
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times 0.5757876879997639
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times 4.775718216999849
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times 4.754745475000163
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times 3.737249878000057
```
After:
```
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times 1.1006453190002503
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times 1.0849009019998448
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times 1.093259106000005
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times 1.0859826279997833
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times 1.1076840900000207
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times 1.0755480369998622
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times 1.918211066999902
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times 1.9183043200000611
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times 1.930021430999659
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times 0.7271483560002707
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times 0.7289002070001516
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times 0.7267536800000016
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times 0.7301799359997858
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times 0.7289195180001116
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times 0.7270008230002531
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times 1.5354506029998447
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times 1.528263066999898
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times 1.5369428439998956
```
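A quick ratio check on the timings above (numbers copied from the int64, numel == 1000 rows; this is just arithmetic on the quoted results):
```python
# before/after seconds for a.pow(e), dtype=torch.int64, numel == 1000, 13000 runs
timings = {2: (0.7973, 1.0933), 3: (0.8135, 1.0755), 4: (3.9633, 1.9300)}
for e, (before, after) in timings.items():
    print(f'pow({e}): {before / after:.2f}x')
# pow(4) is ~2x faster; pow(2)/pow(3) on int64 are somewhat slower in this
# Debug build, while the int16 rows improve across the board.
```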
---
Best viewed with whitespace changes turned off
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26020
Differential Revision: D17485400
Pulled By: VitalyFedyunin
fbshipit-source-id: 3a16b074825a5aab0f7e7af3d8100f9e4b7011a3
This commit is contained in:
parent
66d27504e3
commit
ae0732cde3
```diff
@@ -35,10 +35,8 @@ void pow_tensor_tensor_kernel(TensorIterator& iter) {
 }
 
 void pow_tensor_scalar_kernel(TensorIterator& iter, Scalar exp_scalar) {
-  // Casting exponent to double(not tensor.dtype) allows powering integral
-  // tensors to float exponent e.g. tensor([4]).pow(0.5) will be tensor([2])
-  const auto exp = exp_scalar.to<double>();
   if (isFloatingType(iter.dtype())) {
+    const auto exp = exp_scalar.to<double>();
     // Floating types allow AVX2 vector optimizations for pow/sqrt/rsqrt:
     AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "pow", [&]() {
       using Vec = Vec256<scalar_t>;
```
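This first hunk moves the exponent-to-double cast into the floating-point branch, so the integral branch (next hunk) can re-read the scalar in the representation it actually wants. A hedged Python model of the resulting decision (function name invented for illustration; in ATen, `Scalar::isIntegral(true)` counts bools as integral, much like `isinstance(True, int)` in Python):
```python
def exponent_path(exp_scalar, tensor_is_floating: bool):
    """Sketch of which representation pow_tensor_scalar_kernel reads."""
    if tensor_is_floating:
        return ('double', float(exp_scalar))  # vectorized floating kernels
    if isinstance(exp_scalar, int) and exp_scalar >= 0:
        return ('int64', int(exp_scalar))     # new integer fast path
    return ('double', float(exp_scalar))      # e.g. pow(0.5) on an int tensor

assert exponent_path(3, tensor_is_floating=False) == ('int64', 3)
assert exponent_path(0.5, tensor_is_floating=False) == ('double', 0.5)
assert exponent_path(3, tensor_is_floating=True) == ('double', 3.0)
```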
```diff
@@ -98,55 +96,73 @@ void pow_tensor_scalar_kernel(TensorIterator& iter, Scalar exp_scalar) {
     // Trying to implement pow/sqrt/rsqrt as loop in vec256_int.h does not allow
     // powering integral tensor to float exponent. That's why we need this code
     // duplication:
-    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "pow", [&]() {
-      if (exp == 0.5) {
-        cpu_kernel(iter,
-          [](scalar_t base) -> scalar_t {
-            return std::sqrt(static_cast<long double>(base));
-          }
-        );
-      } else if (exp == 2) {
-        cpu_kernel(iter,
-          [](scalar_t base) -> scalar_t {
-            const auto ld_base = static_cast<long double>(base);
-            return ld_base * ld_base;
-          }
-        );
-      } else if (exp == 3) {
-        cpu_kernel(iter,
-          [](scalar_t base) -> scalar_t {
-            const auto ld_base = static_cast<long double>(base);
-            return ld_base * ld_base * ld_base;
-          }
-        );
-      } else if (exp == -0.5) {
-        cpu_kernel(iter,
-          [](scalar_t base) -> scalar_t {
-            return 1.0 / std::sqrt(static_cast<long double>(base));
-          }
-        );
-      } else if (exp == -1) {
-        cpu_kernel(iter,
-          [](scalar_t base) -> scalar_t {
-            return 1.0 / static_cast<long double>(base);
-          }
-        );
-      } else if (exp == -2) {
-        cpu_kernel(iter,
-          [](scalar_t base) -> scalar_t {
-            const auto ld_base = static_cast<long double>(base);
-            return 1.0 / (ld_base * ld_base);
-          }
-        );
-      } else {
-        cpu_kernel(iter,
-          [=](scalar_t base) -> scalar_t {
-            return std::pow(static_cast<long double>(base),
-                            static_cast<long double>(exp));
-          }
-        );
-      }
-    });
+    if (exp_scalar.isIntegral(true) && exp_scalar.to<int64_t>() >= 0) {
+      // Specifically deal with an integer to the power of a positive integer for better efficiency.
+      const auto exp = exp_scalar.to<int64_t>();
+
+      AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "pow", [&]() {
+        switch (exp) {
+          case 2:
+            cpu_kernel(iter,
+              [](scalar_t base) -> scalar_t {
+                return base * base;
+              }
+            );
+            break;
+          case 3:
+            cpu_kernel(iter,
+              [](scalar_t base) -> scalar_t {
+                return base * base * base;
+              }
+            );
+            break;
+          default:
+            cpu_kernel(iter,
+              [=](scalar_t base) -> scalar_t {
+                return std::pow(base, exp);
+              }
+            );
+        }
+      });
+    } else {
+      // Casting exponent to double(not tensor.dtype) allows powering integral
+      // tensors to float exponent e.g. tensor([4]).pow(0.5) will be tensor([2])
+      const auto exp = exp_scalar.to<double>();
+      AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "pow", [&]() {
+        if (exp == 0.5) {
+          cpu_kernel(iter,
+            [](scalar_t base) -> scalar_t {
+              return std::sqrt(static_cast<long double>(base));
+            }
+          );
+        } else if (exp == -0.5) {
+          cpu_kernel(iter,
+            [](scalar_t base) -> scalar_t {
+              return 1.0 / std::sqrt(static_cast<long double>(base));
+            }
+          );
+        } else if (exp == -1) {
+          cpu_kernel(iter,
+            [](scalar_t base) -> scalar_t {
+              return 1.0 / static_cast<long double>(base);
+            }
+          );
+        } else if (exp == -2) {
+          cpu_kernel(iter,
+            [](scalar_t base) -> scalar_t {
+              return 1.0 / (base * base);
+            }
+          );
+        } else {
+          cpu_kernel(iter,
+            [=](scalar_t base) -> scalar_t {
+              return std::pow(static_cast<long double>(base), exp);
+            }
+          );
+        }
+      });
+    }
   }
 }
 
```
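The integral fast path above boils down to a tiny dispatch on the exponent. A minimal Python sketch of the same logic (illustrative only; the real kernel dispatches a cpu_kernel lambda per dtype, and its default case goes through std::pow):
```python
def int_pow(base: int, exp: int) -> int:
    """Sketch of the new integral fast path's switch on the exponent."""
    assert exp >= 0, "negative integer exponents take the double fallback"
    if exp == 2:
        return base * base          # case 2: one multiply
    if exp == 3:
        return base * base * base   # case 3: two multiplies
    return base ** exp              # default: std::pow(base, exp) in the kernel

print([int_pow(b, 4) for b in range(5)])  # [0, 1, 16, 81, 256]
```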
```diff
@@ -1347,51 +1347,6 @@ class _TestTorchMixin(object):
         res6 = torch.baddbmm(.1, res2, .5, b1, b2)
         self.assertEqual(res6, res2 * .1 + res * .5)
 
-    def test_pow(self):
-        # [res] torch.pow([res,] x)
-
-        # pow has dedicated implementation for different exponents
-        for exponent in [-2, -1, -0.5, 0.5, 1, 2, 3, 4]:
-            # base - tensor, exponent - number
-            # contiguous
-            m1 = torch.rand(100, 100) + 0.5
-            res1 = torch.pow(m1[4], exponent)
-            res2 = res1.clone().zero_()
-            for i in range(res2.size(0)):
-                res2[i] = math.pow(m1[4][i], exponent)
-            self.assertEqual(res1, res2)
-
-            # non-contiguous
-            m1 = torch.rand(100, 100) + 0.5
-            res1 = torch.pow(m1[:, 4], exponent)
-            res2 = res1.clone().zero_()
-            for i in range(res2.size(0)):
-                res2[i] = math.pow(m1[i, 4], exponent)
-            self.assertEqual(res1, res2)
-
-            # base - number, exponent - tensor
-            # contiguous
-            m1 = torch.randn(100, 100)
-            res1 = torch.pow(3, m1[4])
-            res2 = res1.clone().zero_()
-            for i in range(res2.size(0)):
-                res2[i] = math.pow(3, m1[4, i])
-            self.assertEqual(res1, res2)
-
-            # non-contiguous
-            m1 = torch.randn(100, 100)
-            res1 = torch.pow(3, m1[:, 4])
-            res2 = res1.clone().zero_()
-            for i in range(res2.size(0)):
-                res2[i] = math.pow(3, m1[i][4])
-            self.assertEqual(res1, res2)
-
-        # resize behavior for exp == 1
-        m1 = torch.randn(2, 2)
-        out = torch.randn([0])
-        torch.pow(m1, 1, out=out)
-        self.assertEqual(out, m1)
-
     def _test_cop(self, torchfn, mathfn):
         def reference_implementation(res2):
             for i, j in iter_indices(sm1):
```
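The deleted test is re-added below under TestTorchDeviceType, which runs per device and now sweeps dtypes via torch.testing.get_all_math_dtypes. A quick hedged illustration of that helper (the exact list depends on the device and PyTorch version):
```python
import torch.testing

# For a CPU device this typically yields the integer dtypes plus the float
# dtypes; torch.half shows up for CUDA and is explicitly skipped by the test.
print(torch.testing.get_all_math_dtypes('cpu'))
```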
```diff
@@ -7022,6 +6977,66 @@ class TestTorchDeviceType(TestCase):
         expected = torch.diag(x, 17)
         self.assertEqual(result, expected)
 
+    def test_pow(self, device):
+        # [res] torch.pow([res,] x)
+
+        # pow has dedicated implementation for different exponents
+        for dtype in torch.testing.get_all_math_dtypes(device):
+
+            # This test won't work on torch.half because math.pow will generate a much more accurate result. We skip it
+            # for now.
+            if dtype == torch.half:
+                continue
+
+            m1 = torch.empty(0, dtype=dtype, device=device)
+            if m1.is_floating_point():
+                m1 = torch.rand(100, 100, dtype=dtype, device=device) + 0.5
+            else:
+                # math.pow will overflow and throw exceptions for large integers
+                range_high = 4 if dtype in (torch.int8, torch.uint8) else 10
+                m1 = torch.randint(1, range_high, (100, 100), dtype=dtype, device=device)
+
+            for num in [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3]:
+                if isinstance(num, int) and num < 0 and not m1.is_floating_point():
+                    with self.assertRaisesRegex(RuntimeError,
+                                                r'Integers to negative integer powers are not allowed\.'):
+                        torch.pow(m1[4], num)
+                else:
+                    # base - tensor, exponent - number
+                    # contiguous
+                    res1 = torch.pow(m1[4], num)
+                    res2 = res1.clone().zero_()
+                    for i in range(res2.size(0)):
+                        res2[i] = math.pow(m1[4][i], num)
+                    self.assertEqual(res1, res2)
+
+                    # non-contiguous
+                    res1 = torch.pow(m1[:, 4], num)
+                    res2 = res1.clone().zero_()
+                    for i in range(res2.size(0)):
+                        res2[i] = math.pow(m1[i, 4], num)
+                    self.assertEqual(res1, res2)
+
+                    # base - number, exponent - tensor
+                    # contiguous
+                    res1 = torch.pow(3, m1[4])
+                    res2 = res1.clone().zero_()
+                    for i in range(res2.size(0)):
+                        res2[i] = math.pow(3, m1[4, i])
+                    self.assertEqual(res1, res2)
+
+                    # non-contiguous
+                    res1 = torch.pow(3, m1[:, 4])
+                    res2 = res1.clone().zero_()
+                    for i in range(res2.size(0)):
+                        res2[i] = math.pow(3, m1[i][4])
+                    self.assertEqual(res1, res2)
+
+            # resize behavior for exp == 1
+            out = torch.zeros(1, dtype=dtype, device=device)
+            torch.pow(m1, 1, out=out)
+            self.assertEqual(out, m1)
+
     def test_neg(self, device):
         int_types = [torch.int, torch.short, torch.int8, torch.uint8]
         float_types = [torch.float, torch.double, torch.long]
```
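One behavior the new test pins down is the error path for integral tensors raised to negative integer exponents; for example (error text per the regex in the test above):
```python
import torch

a = torch.arange(1, 5)  # int64 tensor
try:
    a.pow(-1)
except RuntimeError as e:
    print(e)  # Integers to negative integer powers are not allowed.
```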