Vectorize arange (#38697)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/38697

Benchmark (gcc 8.3, Debian Buster, turbo off, Release build, Intel(R) Xeon(R) E-2136, parallelization using OpenMP):

```python
import timeit
for dtype in ('torch.double', 'torch.float', 'torch.uint8', 'torch.int8', 'torch.int16', 'torch.int32', 'torch.int64'):
    for n, t in [(40_000, 50000), (400_000, 5000)]:
        print(f'torch.arange(0, {n}, dtype={dtype}) for {t} times')
        print(timeit.timeit(f'torch.arange(0, {n}, dtype={dtype})', setup='import torch', number=t))
```

Before:

```
torch.arange(0, 40000, dtype=torch.double) for 50000 times
1.587841397995362
torch.arange(0, 400000, dtype=torch.double) for 5000 times
0.47885190199303906
torch.arange(0, 40000, dtype=torch.float) for 50000 times
1.5519152240012772
torch.arange(0, 400000, dtype=torch.float) for 5000 times
0.4733216500026174
torch.arange(0, 40000, dtype=torch.uint8) for 50000 times
1.426058754004771
torch.arange(0, 400000, dtype=torch.uint8) for 5000 times
0.43596178699226584
torch.arange(0, 40000, dtype=torch.int8) for 50000 times
1.4289699140063021
torch.arange(0, 400000, dtype=torch.int8) for 5000 times
0.43451592899509706
torch.arange(0, 40000, dtype=torch.int16) for 50000 times
0.5714442400058033
torch.arange(0, 400000, dtype=torch.int16) for 5000 times
0.14837959500437137
torch.arange(0, 40000, dtype=torch.int32) for 50000 times
0.5964003179979045
torch.arange(0, 400000, dtype=torch.int32) for 5000 times
0.15676555599202402
torch.arange(0, 40000, dtype=torch.int64) for 50000 times
0.8390555799996946
torch.arange(0, 400000, dtype=torch.int64) for 5000 times
0.23184613398916554
```

After:

```
torch.arange(0, 40000, dtype=torch.double) for 50000 times
0.6895066159922862
torch.arange(0, 400000, dtype=torch.double) for 5000 times
0.16820953000569716
torch.arange(0, 40000, dtype=torch.float) for 50000 times
1.3640095089940587
torch.arange(0, 400000, dtype=torch.float) for 5000 times
0.39255041000433266
torch.arange(0, 40000, dtype=torch.uint8) for 50000 times
0.3422072059911443
torch.arange(0, 400000, dtype=torch.uint8) for 5000 times
0.0605111670010956
torch.arange(0, 40000, dtype=torch.int8) for 50000 times
0.3449254590086639
torch.arange(0, 400000, dtype=torch.int8) for 5000 times
0.06115841199061833
torch.arange(0, 40000, dtype=torch.int16) for 50000 times
0.7745441729930462
torch.arange(0, 400000, dtype=torch.int16) for 5000 times
0.22106765500211623
torch.arange(0, 40000, dtype=torch.int32) for 50000 times
0.720475220005028
torch.arange(0, 400000, dtype=torch.int32) for 5000 times
0.20230313099455088
torch.arange(0, 40000, dtype=torch.int64) for 50000 times
0.8144655400101328
torch.arange(0, 400000, dtype=torch.int64) for 5000 times
0.23762561299372464
```

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D22291236

Pulled By: VitalyFedyunin

fbshipit-source-id: 134dd08b77b11e631d914b5500ee4285b5d0591e
This commit is contained in:
parent fa6e900e8c
commit 34025eb826
```diff
--- a/aten/src/ATen/native/RangeFactories.cpp
+++ b/aten/src/ATen/native/RangeFactories.cpp
@@ -9,6 +9,7 @@
 namespace at { namespace native {
 
+DECLARE_DISPATCH(void(*)(TensorIterator&, Scalar, Scalar, Scalar), arange_stub);
 DECLARE_DISPATCH(void(*)(TensorIterator&, Scalar, Scalar, int64_t), linspace_stub);
 
 Tensor& linspace_cpu_out(Tensor& result, Scalar start, Scalar end, int64_t steps) {
@@ -172,14 +173,8 @@ Tensor& arange_cpu_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
     }
 
     Tensor r = result.is_contiguous() ? result : result.contiguous();
-    scalar_t *data_ptr = r.data_ptr<scalar_t>();
-
-    at::parallel_for(0, size, internal::GRAIN_SIZE, [&](int64_t p_begin, int64_t p_end) {
-      scalar_t is = p_begin;
-      for (int64_t i = p_begin; i < p_end; ++i, ++is) {
-        data_ptr[i] = xstart + is * xstep;
-      }
-    });
+    auto iter = TensorIterator::nullary_op(r, /*check_mem_overlap=*/true);
+    arange_stub(iter.device_type(), iter, start, size, step);
     if (!result.is_contiguous()) {
       result.copy_(r);
     }
@@ -188,6 +183,7 @@ Tensor& arange_cpu_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
   return result;
 }
 
+DEFINE_DISPATCH(arange_stub);
DEFINE_DISPATCH(linspace_stub);
 
 }} // namespace at::native
```
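The stub indirection above is what lets the dtype-generic `arange_cpu_out` hand the actual loop to a separately compiled, vectorized kernel. Below is a minimal sketch of the idea behind the `DECLARE_DISPATCH` / `DEFINE_DISPATCH` / `REGISTER_DISPATCH` trio; it is illustrative only, not ATen's real `DispatchStub` (which also selects among CPU capabilities like AVX2 at runtime), and the names `Stub` and `Registerer` are invented for the sketch:

```cpp
#include <cstdint>
#include <cstdio>

enum class DeviceType { CPU, CUDA };

// "DECLARE_DISPATCH": a typed stub holding one function pointer per backend.
template <typename FnPtr>
struct Stub {
  FnPtr cpu_impl = nullptr;
  template <typename... Args>
  void operator()(DeviceType dev, Args... args) {
    // Dispatch through the registered pointer; real DispatchStub picks
    // the best available implementation here.
    if (dev == DeviceType::CPU && cpu_impl) cpu_impl(args...);
  }
};

using ArangeFn = void (*)(double start, double step, int64_t n);
Stub<ArangeFn> arange_stub;  // "DEFINE_DISPATCH(arange_stub)"

// In the kernel translation unit: the implementation...
static void arange_kernel(double start, double step, int64_t n) {
  for (int64_t i = 0; i < n; ++i) std::printf("%g ", start + step * i);
  std::printf("\n");
}
// ...registered into the stub at static-initialization time,
// playing the role of "REGISTER_DISPATCH(arange_stub, &arange_kernel)".
static struct Registerer {
  Registerer() { arange_stub.cpu_impl = &arange_kernel; }
} register_arange_kernel;

int main() {
  // The caller dispatches through the stub without seeing the kernel.
  arange_stub(DeviceType::CPU, 0.0, 2.5, 4);  // prints: 0 2.5 5 7.5
}
```

The payoff of this indirection is that the kernel file can be compiled multiple times with different vectorization flags while the calling code stays generic.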
```diff
--- a/aten/src/ATen/native/cpu/RangeFactoriesKernel.cpp
+++ b/aten/src/ATen/native/cpu/RangeFactoriesKernel.cpp
@@ -2,6 +2,7 @@
 #include <ATen/Config.h>
 #include <ATen/Dispatch.h>
 
+#include <ATen/AccumulateType.h>
 #include <ATen/cpu/vec256/vec256.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
```
```diff
@@ -12,6 +13,30 @@ namespace {
 
 using namespace vec256;
 
+static void arange_kernel(TensorIterator& iter, Scalar scalar_start, Scalar scalar_steps, Scalar scalar_step) {
+  AT_DISPATCH_ALL_TYPES(iter.dtype(), "arange_cpu", [&]() {
+    using accscalar_t = at::acc_type<scalar_t, false>;
+    auto start = scalar_start.to<accscalar_t>();
+    auto steps = scalar_steps.to<accscalar_t>();
+    auto step = scalar_step.to<accscalar_t>();
+    at::parallel_for(0, steps, internal::GRAIN_SIZE, [&](int64_t p_begin, int64_t p_end) {
+      int64_t idx(p_begin);
+      TensorIterator it(iter);
+      cpu_serial_kernel_vec(
+          it,
+          [start, step, &idx]() -> scalar_t {
+            return start + step * (idx++);
+          },
+          [start, step, &idx]() -> Vec256<scalar_t> {
+            Vec256<scalar_t> res;
+            res = Vec256<scalar_t>::arange(start + step * idx, step);
+            idx += Vec256<scalar_t>::size();
+            return res;
+          }, {p_begin, p_end});
+    });
+  });
+}
+
 static void linspace_kernel(TensorIterator& iter, Scalar scalar_start, Scalar scalar_end, int64_t steps) {
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX(iter.dtype(), "linspace_cpu", [&]() {
     // step should be of double type for all integral types
@@ -50,6 +75,7 @@ static void linspace_kernel(TensorIterator& iter, Scalar scalar_start, Scalar scalar_end, int64_t steps) {
 
 } // anonymous namespace
 
+REGISTER_DISPATCH(arange_stub, &arange_kernel);
 REGISTER_DISPATCH(linspace_stub, &linspace_kernel);
 
 }} // namespace at::native
```
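The kernel above drives a nullary op through two lambdas: the scalar one produces one element per call, the `Vec256` one produces a full vector lane via `Vec256<scalar_t>::arange(base, step)`, and both advance a shared running index so either path lands each value at the correct output offset. Here is a minimal self-contained sketch of that scalar/vector split, with a toy 4-wide `Vec4` standing in for `Vec256<scalar_t>` and `serial_kernel_vec` as an invented stand-in for ATen's `cpu_serial_kernel_vec`:

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

struct Vec4 {
  std::array<float, 4> v;
  static constexpr int64_t size() { return 4; }
  // Vec256<T>::arange-style factory: base, base+step, base+2*step, ...
  static Vec4 arange(float base, float step) {
    return {{base, base + step, base + 2 * step, base + 3 * step}};
  }
  void store(float* p) const { for (int i = 0; i < 4; ++i) p[i] = v[i]; }
};

// Run the vector op over full lanes, then the scalar op over the tail.
template <typename ScalarOp, typename VecOp>
void serial_kernel_vec(float* out, int64_t begin, int64_t end,
                       ScalarOp scalar_op, VecOp vec_op) {
  int64_t i = begin;
  for (; i + Vec4::size() <= end; i += Vec4::size()) vec_op().store(out + i);
  for (; i < end; ++i) out[i] = scalar_op();
}

int main() {
  const float start = 0.f, step = 0.5f;
  float out[10];
  int64_t idx = 0;  // shared running index, as in arange_kernel above
  serial_kernel_vec(out, 0, 10,
      [&]() -> float { return start + step * (idx++); },
      [&]() -> Vec4 {
        Vec4 r = Vec4::arange(start + step * idx, step);
        idx += Vec4::size();
        return r;
      });
  for (float x : out) std::printf("%g ", x);  // 0 0.5 1 ... 4.5
}
```

The scalar lambda is not dead code: it handles ranges shorter than one vector width and the remainder after the last full lane, which is why both lambdas must agree on the same index state.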
```diff
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -1386,10 +1386,18 @@ class AbstractTestCases:
         self.assertEqual(len(w), 1)
 
     def test_arange(self):
-        res1 = torch.arange(0, 1)
+        res = torch.tensor(range(10000))
+        res1 = torch.arange(0, 10000)  # Use a larger number so vectorized code can be triggered
         res2 = torch.tensor([], dtype=torch.int64)
-        torch.arange(0, 1, out=res2)
-        self.assertEqual(res1, res2, atol=0, rtol=0)
+        torch.arange(0, 10000, out=res2)
+        self.assertEqual(res, res1, atol=0, rtol=0)
+        self.assertEqual(res, res2, atol=0, rtol=0)
+
+        # Vectorization on non-contiguous tensors
+        res = torch.rand(3, 3, 300000).to(torch.int64)
+        res = res.permute(2, 0, 1)
+        torch.arange(0, 300000 * 3 * 3, out=res)
+        self.assertEqual(res.flatten(), torch.arange(0, 300000 * 3 * 3))
 
         # Check arange with only one argument
        res1 = torch.arange(10)
```
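The new non-contiguous test exercises the staging path in `arange_cpu_out`: values are produced into a contiguous buffer (`result.contiguous()`), where the vectorized kernel can run at full speed, and then copied back through the output's strides (`result.copy_(r)`). A rough sketch of that pattern on a plain strided buffer, with all names invented for illustration:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // A "tensor" of 5 elements viewed with stride 3 inside a larger buffer.
  std::vector<float> storage(15, -1.f);
  const int64_t n = 5, stride = 3;

  // Stage 1: run the (vectorizable) kernel on a contiguous scratch buffer,
  // mirroring `Tensor r = result.contiguous()` + arange_stub on r.
  std::vector<float> r(n);
  const float start = 0.f, step = 2.f;
  for (int64_t i = 0; i < n; ++i) r[i] = start + step * i;

  // Stage 2: copy back through the output's strides (`result.copy_(r)`).
  for (int64_t i = 0; i < n; ++i) storage[i * stride] = r[i];

  for (float x : storage) std::printf("%g ", x);
  // 0 -1 -1 2 -1 -1 4 -1 -1 6 -1 -1 8 -1 -1
}
```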