Speed up an integer to the power of a positive integer on CPU (#26020)

Summary:
Currently, integer scalar exponents are always cast to double. This commit avoids that cast when the base tensor is also integral and the scalar exponent is a non-negative integer, which speeds up the computation.
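
For illustration, a minimal sketch of which calls fall on the new integer fast path and which keep the existing double-cast behavior:

```python
import torch

a = torch.arange(1, 5, dtype=torch.int32)

# Integral tensor with a non-negative integer exponent: handled in integer
# arithmetic (the fast path); exponents 2 and 3 get dedicated kernels.
print(a.pow(2))   # tensor([ 1,  4,  9, 16], dtype=torch.int32)
print(a.pow(3))   # tensor([ 1,  8, 27, 64], dtype=torch.int32)

# Float exponent: still computed through double so that integral tensors can be
# raised to fractional powers, e.g. torch.tensor([4]).pow(0.5) gives tensor([2]).
print(torch.tensor([4]).pow(0.5))

# Negative integer exponent on an integral tensor raises
# "Integers to negative integer powers are not allowed." (covered by the updated test).
```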

Benchmark (Debian Buster, g++ 8, Intel(R) Xeon(R) E-2136 CPU @ 3.30GHz, Debug build, Turbo turned off):

```python
import timeit

for n, t in [(1000, 13000),
             (10_000, 1300)]:
    for e in (2, 3, 4):
        for dtype in ('torch.int16', 'torch.int32', 'torch.int64'):
            print(f'a.pow({e}) (a.numel() == {n}) for {t} times')
            print(f'dtype {dtype}, {t} times', end='\t\t')
            print(timeit.timeit(f'a.pow({e})',
                                setup=f'import torch; a = torch.arange({n}, device="cpu", dtype={dtype})',
                                number=t))
```

Before:

```
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times		1.6958350749996498
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times		0.7989626339999631
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times		0.7973162800003593
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times		1.8660746679997828
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times		0.8101709959996697
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times		0.8135280149999744
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times		5.010833072999958
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times		4.801007671999741
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times		3.963344578000033
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times		1.6216251330001796
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times		0.5672429639998882
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times		0.5544572270000572
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times		1.656308512999658
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times		1.502670819999821
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times		0.5757876879997639
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times		4.775718216999849
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times		4.754745475000163
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times		3.737249878000057
```

After:

```
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times		1.1006453190002503
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times		1.0849009019998448
a.pow(2) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times		1.093259106000005
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times		1.0859826279997833
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times		1.1076840900000207
a.pow(3) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times		1.0755480369998622
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int16, 13000 times		1.918211066999902
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int32, 13000 times		1.9183043200000611
a.pow(4) (a.numel() == 1000) for 13000 times
dtype torch.int64, 13000 times		1.930021430999659
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times		0.7271483560002707
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times		0.7289002070001516
a.pow(2) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times		0.7267536800000016
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times		0.7301799359997858
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times		0.7289195180001116
a.pow(3) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times		0.7270008230002531
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int16, 1300 times		1.5354506029998447
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int32, 1300 times		1.528263066999898
a.pow(4) (a.numel() == 10000) for 1300 times
dtype torch.int64, 1300 times		1.5369428439998956
```

---

Best viewed with whitespace changes turned off
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26020

Differential Revision: D17485400

Pulled By: VitalyFedyunin

fbshipit-source-id: 3a16b074825a5aab0f7e7af3d8100f9e4b7011a3
Author: Hong Xu (2019-09-24 09:15:10 -07:00), committed by Facebook Github Bot
Parent: 66d27504e3
Commit: ae0732cde3
2 changed files with 128 additions and 97 deletions


```diff
@@ -35,10 +35,8 @@ void pow_tensor_tensor_kernel(TensorIterator& iter) {
 }
 void pow_tensor_scalar_kernel(TensorIterator& iter, Scalar exp_scalar) {
-  // Casting exponent to double(not tensor.dtype) allows powering integral
-  // tensors to float exponent e.g. tensor([4]).pow(0.5) will be tensor([2])
-  const auto exp = exp_scalar.to<double>();
   if (isFloatingType(iter.dtype())) {
+    const auto exp = exp_scalar.to<double>();
     // Floating types allow AVX2 vector optimizations for pow/sqrt/rsqrt:
     AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "pow", [&]() {
       using Vec = Vec256<scalar_t>;
@@ -98,55 +96,73 @@ void pow_tensor_scalar_kernel(TensorIterator& iter, Scalar exp_scalar) {
     // Trying to implement pow/sqrt/rsqrt as loop in vec256_int.h does not allow
     // powering integral tensor to float exponent. That's why we need this code
     // duplication:
-    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "pow", [&]() {
-      if (exp == 0.5) {
-        cpu_kernel(iter,
-          [](scalar_t base) -> scalar_t {
-            return std::sqrt(static_cast<long double>(base));
-          }
-        );
-      } else if (exp == 2) {
-        cpu_kernel(iter,
-          [](scalar_t base) -> scalar_t {
-            const auto ld_base = static_cast<long double>(base);
-            return ld_base * ld_base;
-          }
-        );
-      } else if (exp == 3) {
-        cpu_kernel(iter,
-          [](scalar_t base) -> scalar_t {
-            const auto ld_base = static_cast<long double>(base);
-            return ld_base * ld_base * ld_base;
-          }
-        );
-      } else if (exp == -0.5) {
-        cpu_kernel(iter,
-          [](scalar_t base) -> scalar_t {
-            return 1.0 / std::sqrt(static_cast<long double>(base));
-          }
-        );
-      } else if (exp == -1) {
-        cpu_kernel(iter,
-          [](scalar_t base) -> scalar_t {
-            return 1.0 / static_cast<long double>(base);
-          }
-        );
-      } else if (exp == -2) {
-        cpu_kernel(iter,
-          [](scalar_t base) -> scalar_t {
-            const auto ld_base = static_cast<long double>(base);
-            return 1.0 / (ld_base * ld_base);
-          }
-        );
-      } else {
-        cpu_kernel(iter,
-          [=](scalar_t base) -> scalar_t {
-            return std::pow(static_cast<long double>(base),
-              static_cast<long double>(exp));
-          }
-        );
-      }
-    });
+    if (exp_scalar.isIntegral(true) && exp_scalar.to<int64_t>() >= 0) {
+      // Specifically deal with an integer to the power of a positive integer for better efficiency.
+      const auto exp = exp_scalar.to<int64_t>();
+      AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "pow", [&]() {
+        switch (exp) {
+          case 2:
+            cpu_kernel(iter,
+              [](scalar_t base) -> scalar_t {
+                return base * base;
+              }
+            );
+            break;
+          case 3:
+            cpu_kernel(iter,
+              [](scalar_t base) -> scalar_t {
+                return base * base * base;
+              }
+            );
+            break;
+          default:
+            cpu_kernel(iter,
+              [=](scalar_t base) -> scalar_t {
+                return std::pow(base, exp);
+              }
+            );
+        }
+      });
+    } else {
+      // Casting exponent to double(not tensor.dtype) allows powering integral
+      // tensors to float exponent e.g. tensor([4]).pow(0.5) will be tensor([2])
+      const auto exp = exp_scalar.to<double>();
+      AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "pow", [&]() {
+        if (exp == 0.5) {
+          cpu_kernel(iter,
+            [](scalar_t base) -> scalar_t {
+              return std::sqrt(static_cast<long double>(base));
+            }
+          );
+        } else if (exp == -0.5) {
+          cpu_kernel(iter,
+            [](scalar_t base) -> scalar_t {
+              return 1.0 / std::sqrt(static_cast<long double>(base));
+            }
+          );
+        } else if (exp == -1) {
+          cpu_kernel(iter,
+            [](scalar_t base) -> scalar_t {
+              return 1.0 / static_cast<long double>(base);
+            }
+          );
+        } else if (exp == -2) {
+          cpu_kernel(iter,
+            [](scalar_t base) -> scalar_t {
+              return 1.0 / (base * base);
+            }
+          );
+        } else {
+          cpu_kernel(iter,
+            [=](scalar_t base) -> scalar_t {
+              return std::pow(static_cast<long double>(base), exp);
+            }
+          );
+        }
+      });
+    }
   }
 }
```
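
For a quick correctness check of the integer fast path, here is a small sketch in the spirit of the updated test below, comparing against math.pow:

```python
import math
import torch

a = torch.randint(1, 10, (100,), dtype=torch.int64)
for exp in (0, 1, 2, 3, 4):  # 2 and 3 hit the dedicated branches; the rest fall back to std::pow
    expected = torch.tensor([int(math.pow(int(v), exp)) for v in a], dtype=torch.int64)
    assert torch.equal(a.pow(exp), expected), f"mismatch for exponent {exp}"
```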


```diff
@@ -1347,51 +1347,6 @@ class _TestTorchMixin(object):
         res6 = torch.baddbmm(.1, res2, .5, b1, b2)
         self.assertEqual(res6, res2 * .1 + res * .5)
-    def test_pow(self):
-        # [res] torch.pow([res,] x)
-        # pow has dedicated implementation for different exponents
-        for exponent in [-2, -1, -0.5, 0.5, 1, 2, 3, 4]:
-            # base - tensor, exponent - number
-            # contiguous
-            m1 = torch.rand(100, 100) + 0.5
-            res1 = torch.pow(m1[4], exponent)
-            res2 = res1.clone().zero_()
-            for i in range(res2.size(0)):
-                res2[i] = math.pow(m1[4][i], exponent)
-            self.assertEqual(res1, res2)
-            # non-contiguous
-            m1 = torch.rand(100, 100) + 0.5
-            res1 = torch.pow(m1[:, 4], exponent)
-            res2 = res1.clone().zero_()
-            for i in range(res2.size(0)):
-                res2[i] = math.pow(m1[i, 4], exponent)
-            self.assertEqual(res1, res2)
-            # base - number, exponent - tensor
-            # contiguous
-            m1 = torch.randn(100, 100)
-            res1 = torch.pow(3, m1[4])
-            res2 = res1.clone().zero_()
-            for i in range(res2.size(0)):
-                res2[i] = math.pow(3, m1[4, i])
-            self.assertEqual(res1, res2)
-            # non-contiguous
-            m1 = torch.randn(100, 100)
-            res1 = torch.pow(3, m1[:, 4])
-            res2 = res1.clone().zero_()
-            for i in range(res2.size(0)):
-                res2[i] = math.pow(3, m1[i][4])
-            self.assertEqual(res1, res2)
-        # resize behavior for exp == 1
-        m1 = torch.randn(2, 2)
-        out = torch.randn([0])
-        torch.pow(m1, 1, out=out)
-        self.assertEqual(out, m1)
     def _test_cop(self, torchfn, mathfn):
         def reference_implementation(res2):
             for i, j in iter_indices(sm1):
@@ -7022,6 +6977,66 @@ class TestTorchDeviceType(TestCase):
         expected = torch.diag(x, 17)
         self.assertEqual(result, expected)
+    def test_pow(self, device):
+        # [res] torch.pow([res,] x)
+        # pow has dedicated implementation for different exponents
+        for dtype in torch.testing.get_all_math_dtypes(device):
+            # This test won't work on torch.half because math.pow will generate a much more accurate result. We skip it
+            # for now.
+            if dtype == torch.half:
+                continue
+            m1 = torch.empty(0, dtype=dtype, device=device)
+            if m1.is_floating_point():
+                m1 = torch.rand(100, 100, dtype=dtype, device=device) + 0.5
+            else:
+                # math.pow will overflow and throw exceptions for large integers
+                range_high = 4 if dtype in (torch.int8, torch.uint8) else 10
+                m1 = torch.randint(1, range_high, (100, 100), dtype=dtype, device=device)
+            for num in [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3]:
+                if isinstance(num, int) and num < 0 and not m1.is_floating_point():
+                    with self.assertRaisesRegex(RuntimeError,
+                                                r'Integers to negative integer powers are not allowed\.'):
+                        torch.pow(m1[4], num)
+                else:
+                    # base - tensor, exponent - number
+                    # contiguous
+                    res1 = torch.pow(m1[4], num)
+                    res2 = res1.clone().zero_()
+                    for i in range(res2.size(0)):
+                        res2[i] = math.pow(m1[4][i], num)
+                    self.assertEqual(res1, res2)
+                    # non-contiguous
+                    res1 = torch.pow(m1[:, 4], num)
+                    res2 = res1.clone().zero_()
+                    for i in range(res2.size(0)):
+                        res2[i] = math.pow(m1[i, 4], num)
+                    self.assertEqual(res1, res2)
+                    # base - number, exponent - tensor
+                    # contiguous
+                    res1 = torch.pow(3, m1[4])
+                    res2 = res1.clone().zero_()
+                    for i in range(res2.size(0)):
+                        res2[i] = math.pow(3, m1[4, i])
+                    self.assertEqual(res1, res2)
+                    # non-contiguous
+                    res1 = torch.pow(3, m1[:, 4])
+                    res2 = res1.clone().zero_()
+                    for i in range(res2.size(0)):
+                        res2[i] = math.pow(3, m1[i][4])
+                    self.assertEqual(res1, res2)
+            # resize behavior for exp == 1
+            out = torch.zeros(1, dtype=dtype, device=device)
+            torch.pow(m1, 1, out=out)
+            self.assertEqual(out, m1)
     def test_neg(self, device):
         int_types = [torch.int, torch.short, torch.int8, torch.uint8]
         float_types = [torch.float, torch.double, torch.long]
```