Vectorize on output for reduction kernels (#37206)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37206

Benchmark on P100: https://github.com/zasdfgbnm/things/blob/master/2020Q2/reduction-benchmark-vectorize-output.ipynb

The script below sweeps progressively larger 2D shapes and times `sum` along `dim=0`, i.e. the case where consecutive outputs are contiguous in memory and the reduction does not run along the fastest-moving dimension.

```python
import torch
print(torch.__version__)
print()

# warm up CUDA (context init, caching allocator) before timing
for i in range(1000):
    torch.arange(10000, device='cuda')

def benchmark(dtype, i):
    # roughly square shapes; total element count doubles with each step of i
    size0 = 2 ** (i // 2)
    size1 = 2 ** ((i + 1) // 2)
    a = torch.zeros(size0, size1, device='cuda', dtype=dtype)
    torch.cuda.synchronize()
    # reduce over dim=0: consecutive outputs are contiguous in memory,
    # i.e. the "vectorize along output" case this PR targets
    %timeit a.sum(dtype=dtype, dim=0); torch.cuda.synchronize()

for dtype in [torch.int8, torch.half, torch.float, torch.double]:
    print(dtype)
    for i in range(18, 30):
        benchmark(dtype, i)
    print()
```
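Since `%timeit` is an IPython magic, here is a plain-Python variant of the same timing loop (a sketch, not part of this PR or the notebook) that uses CUDA events instead:
```python
import torch

def time_sum_dim0(size0, size1, dtype, iters=100):
    a = torch.zeros(size0, size1, device='cuda', dtype=dtype)
    for _ in range(10):                       # warm up the kernel
        a.sum(dtype=dtype, dim=0)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        a.sum(dtype=dtype, dim=0)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters    # milliseconds per call

for i in range(18, 30):
    print(time_sum_dim0(2 ** (i // 2), 2 ** ((i + 1) // 2), torch.float))
```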
Before
```
1.5.0a0+3bbb36e

torch.int8
24.5 µs ± 111 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
24.1 µs ± 216 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
26.1 µs ± 133 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
30.9 µs ± 132 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
39 µs ± 504 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
59.6 µs ± 244 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
111 µs ± 112 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
186 µs ± 300 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
397 µs ± 791 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
665 µs ± 1.06 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.45 ms ± 837 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.03 ms ± 2.79 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

torch.float16
24.2 µs ± 66.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
24.6 µs ± 255 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
27.2 µs ± 53.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
32 µs ± 91 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
48.1 µs ± 89.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
66.9 µs ± 66.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
121 µs ± 102 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
218 µs ± 384 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
431 µs ± 554 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
854 µs ± 1.23 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.75 ms ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.63 ms ± 849 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)

torch.float32
24.2 µs ± 117 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
24.4 µs ± 237 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
29.3 µs ± 34.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
40.5 µs ± 36.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
57.4 µs ± 44.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
85.5 µs ± 41.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
158 µs ± 106 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
288 µs ± 181 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
557 µs ± 904 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1e+03 µs ± 1.27 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.98 ms ± 533 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.8 ms ± 1.98 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

torch.float64
25 µs ± 54.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
26.9 µs ± 320 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
37.1 µs ± 51.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
54.3 µs ± 45.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
84.9 µs ± 65.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
139 µs ± 68.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
275 µs ± 235 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
504 µs ± 702 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
987 µs ± 613 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.84 ms ± 1.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.64 ms ± 2.44 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.19 ms ± 1.19 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
After
```
1.5.0a0+3bbb36e

torch.int8
29.8 µs ± 213 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
30.7 µs ± 1.41 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
33.4 µs ± 4.48 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
32.5 µs ± 110 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
40.6 µs ± 94.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
53.7 µs ± 66.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
68 µs ± 69.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
98.2 µs ± 88.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
158 µs ± 116 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
283 µs ± 120 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
522 µs ± 563 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
967 µs ± 495 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

torch.float16
29.4 µs ± 68.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
29.2 µs ± 45.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
30.8 µs ± 41 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
35.3 µs ± 20.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
50.1 µs ± 133 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
70.4 µs ± 67.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
101 µs ± 325 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
157 µs ± 179 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
275 µs ± 791 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
486 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
936 µs ± 211 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.85 ms ± 124 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

torch.float32
29.9 µs ± 36.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
29.5 µs ± 108 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
33 µs ± 93.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
46 µs ± 37.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
64 µs ± 73.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
99.4 µs ± 82.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
157 µs ± 74.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
265 µs ± 68.8 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
490 µs ± 319 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
960 µs ± 669 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.84 ms ± 632 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.6 ms ± 1.63 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

torch.float64
33.1 µs ± 74.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
36.7 µs ± 86.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
46.7 µs ± 39.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
61.6 µs ± 196 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
100 µs ± 23.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
158 µs ± 202 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
270 µs ± 332 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
491 µs ± 445 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
939 µs ± 339 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.88 ms ± 1.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.65 ms ± 5.18 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.3 ms ± 7.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
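For a quick correctness sanity check of the output-vectorized path, the following sketch mirrors the `test_reduction_vectorize_along_output` test added below; the offset views are what force output vector sizes of 4, 2 and 1:
```python
import torch

def check(m):
    # identity-like square matrix: argmax and sum along dim=0 have known answers
    M, N = m.shape
    m.zero_()
    for i in range(min(M, N)):
        m[i][i] = 1
    assert torch.equal(m.argmax(dim=0).cpu(), torch.arange(N))
    assert torch.equal(m.sum(dim=0).cpu(), torch.ones(N, dtype=m.dtype))

for dtype in (torch.half, torch.float, torch.double):
    check(torch.zeros(64, 64, dtype=dtype, device='cuda'))                        # vec 4
    check(torch.zeros(64 * 64 + 2, dtype=dtype, device='cuda')[2:].view(64, 64))  # vec 2
    check(torch.zeros(64 * 64 + 1, dtype=dtype, device='cuda')[1:].view(64, 64))  # vec 1
```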

Test Plan: Imported from OSS

Differential Revision: D21233255

Pulled By: ngimel

fbshipit-source-id: d468fddbb228c0c13146dfc6344c470513f9e374
Author: Xiang Gao (committed by Facebook GitHub Bot)
Date: 2020-06-11 19:41:55 -07:00
parent a92231b70e
commit b10c53e9b8
4 changed files with 314 additions and 138 deletions

View File

@@ -8,7 +8,7 @@
namespace at { namespace detail {
template <typename T, int size>
struct alignas(16) Array {
struct Array {
T data[size];
C10_HOST_DEVICE T operator[](int i) const {

View File

@@ -39,6 +39,8 @@ std::ostream& operator<<(std::ostream& out, const ReduceConfig& config) {
out << config.output_mult[i];
}
out << "], ";
out << "vectorize_input=" << config.vectorize_input << ", ";
out << "output_vec_size=" << config.output_vec_size << ", ";
out << "block_width=" << config.block_width << ", ";
out << "block_height=" << config.block_height << ", ";
out << "num_threads=" << config.num_threads << ", ";

View File

@@ -72,7 +72,7 @@ struct ReduceConfig {
static constexpr int CTA = 2;
static constexpr int MAX_NUM_THREADS = 512;
static constexpr int vec_size = 4;
static constexpr int input_vec_size = 4;
ReduceConfig(int element_size_bytes, int num_outputs, int num_inputs)
: element_size_bytes(element_size_bytes)
@@ -92,14 +92,16 @@ struct ReduceConfig {
int block_height;
int num_threads;
bool vectorize = false;
bool vectorize_input = false;
int output_vec_size = 1;
void set_block_dimension(int64_t dim0, int64_t dim1) {
int dim0_pow2 = dim0 < MAX_NUM_THREADS ? static_cast<int>(last_pow2(dim0)) : MAX_NUM_THREADS;
int dim1_pow2 = dim1 < MAX_NUM_THREADS ? static_cast<int>(last_pow2(dim1)) : MAX_NUM_THREADS;
const int max_num_threads = MAX_NUM_THREADS / output_vec_size;
int dim0_pow2 = dim0 < max_num_threads ? static_cast<int>(last_pow2(dim0)) : max_num_threads;
int dim1_pow2 = dim1 < max_num_threads ? static_cast<int>(last_pow2(dim1)) : max_num_threads;
block_width = std::min(dim0_pow2, int(at::cuda::warp_size()));
block_height = std::min(dim1_pow2, int(MAX_NUM_THREADS / block_width));
block_width = std::min(dim0_pow2, int(MAX_NUM_THREADS / block_height));
block_height = std::min(dim1_pow2, int(max_num_threads / block_width));
block_width = std::min(dim0_pow2, int(max_num_threads / block_height));
num_threads = block_width * block_height;
}
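Read in Python terms, the adjusted block-size computation above is roughly the following (a sketch with assumed constants `MAX_NUM_THREADS = 512` and warp size 32; `last_pow2` stands in for the existing helper that rounds down to a power of two):
```python
def last_pow2(n):
    # round n down to the nearest power of two (assumes n >= 1)
    p = 1
    while p * 2 <= n:
        p *= 2
    return p

def set_block_dimension(dim0, dim1, output_vec_size, warp_size=32, max_num_threads=512):
    # each thread now produces output_vec_size outputs, so the per-block
    # thread budget shrinks by that factor
    budget = max_num_threads // output_vec_size
    dim0_pow2 = last_pow2(dim0) if dim0 < budget else budget
    dim1_pow2 = last_pow2(dim1) if dim1 < budget else budget
    block_width = min(dim0_pow2, warp_size)
    block_height = min(dim1_pow2, budget // block_width)
    block_width = min(dim0_pow2, budget // block_height)
    return block_width, block_height, block_width * block_height
```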
@@ -120,7 +122,7 @@
}
dim3 grid() const {
return dim3(div_up(num_outputs, step_output), ctas_per_output);
return dim3(div_up(num_outputs / output_vec_size, step_output), ctas_per_output);
}
C10_HOST_DEVICE bool should_block_x_reduce() const {
@@ -155,13 +157,14 @@ struct ReduceConfig {
cta2 * input_mult[CTA]);
}
template <int output_vec_size>
C10_HOST_DEVICE int output_idx() const {
int lane = threadIdx.x;
int warp = threadIdx.y;
int cta1 = blockIdx.x;
return (lane * output_mult[BLOCK_X] +
warp * output_mult[BLOCK_Y] +
cta1 * step_output);
cta1 * step_output) * output_vec_size;
}
C10_DEVICE int shared_memory_offset(int offset) const {
@@ -182,7 +185,7 @@ struct ReduceConfig {
block_width <= at::cuda::warp_size())) {
return 0;
}
return element_size_bytes * num_threads;
return element_size_bytes * num_threads * output_vec_size;
}
int64_t global_memory_size() const {
@@ -191,7 +194,7 @@
}
auto size = (int64_t)element_size_bytes * num_outputs * ctas_per_output;
if (!should_block_x_reduce()) {
size *= block().x;
size *= block().x * output_vec_size;
}
return size;
}
@@ -210,10 +213,10 @@ struct ReduceConfig {
std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);
template<int nt, typename R>
template<int nt, int output_vec_size, typename R>
C10_LAUNCH_BOUNDS_2(nt, 4)
__global__ void reduce_kernel(R reduction) {
reduction.run();
reduction.template run<output_vec_size>();
}
template <typename index_t>
@@ -285,7 +288,7 @@ struct ReduceOp {
static constexpr float acc_buffer_multiplier = (float)sizeof(arg_t) / sizeof(out_scalar_t);
static constexpr int vec_size = ReduceConfig::vec_size;
static constexpr int input_vec_size = ReduceConfig::input_vec_size;
ops_t ops;
arg_t ident;
@@ -336,57 +339,78 @@ struct ReduceOp {
}
}
template <int output_vec_size>
C10_DEVICE void run() const {
extern __shared__ char shared_memory[];
index_t output_idx = config.output_idx();
index_t output_idx = config.output_idx<output_vec_size>();
index_t input_idx = config.input_idx();
auto base_offsets = output_calc.get(output_idx);
auto base_offsets1 = output_calc.get(output_idx)[1];
arg_t value = ident;
using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
arg_vec_t value;
if (output_idx < config.num_outputs && input_idx < config.num_inputs) {
const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets[1]);
value = thread_reduce(input_slice);
const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1);
value = thread_reduce<output_vec_size>(input_slice);
}
if (config.should_block_y_reduce()) {
value = block_y_reduce(value, shared_memory);
value = block_y_reduce<output_vec_size>(value, shared_memory);
}
if (config.should_block_x_reduce()) {
value = block_x_reduce(value, shared_memory);
value = block_x_reduce<output_vec_size>(value, shared_memory);
}
auto out = (out_scalar_t*)((char*)dst[0] + base_offsets[0]);
arg_t* acc = nullptr;
using out_ptr_vec_t = at::detail::Array<out_scalar_t*, output_vec_size>;
using offset_vec_t = at::detail::Array<index_t, output_vec_size>;
offset_vec_t base_offsets;
out_ptr_vec_t out;
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
base_offsets[i] = output_calc.get(output_idx + i)[0];
out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
}
arg_vec_t* acc = nullptr;
if (acc_buf != nullptr) {
size_t numerator = sizeof(arg_t);
size_t denominator = sizeof(out_scalar_t);
reduce_fraction(numerator, denominator);
acc = (arg_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator));
acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator));
}
if (config.should_global_reduce()) {
value = global_reduce(value, acc, shared_memory);
value = global_reduce<output_vec_size>(value, acc, shared_memory);
} else if (config.should_store(output_idx)) {
if (accumulate) {
value = ops.translate_idx(value, base_idx);
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
value[i] = ops.translate_idx(value[i], base_idx);
}
}
if (acc == nullptr) {
if (accumulate) {
value = accumulate_in_output<can_accumulate_in_output>(out, value);
value = accumulate_in_output<output_vec_size, can_accumulate_in_output>(out, value);
}
if (final_output) {
set_results_to_output(value, base_offsets[0]);
set_results_to_output<output_vec_size>(value, base_offsets);
} else {
*out = get_accumulated_output<can_accumulate_in_output>(out, value);
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
*(out[i]) = get_accumulated_output<can_accumulate_in_output>(out[i], value[i]);
}
}
} else {
if (accumulate) {
value = ops.combine(*acc, value);
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
value[i] = ops.combine((*acc)[i], value[i]);
}
}
if (final_output) {
set_results_to_output(value, base_offsets[0]);
set_results_to_output<output_vec_size>(value, base_offsets);
} else {
*acc = value;
}
@@ -394,30 +418,32 @@ struct ReduceOp {
}
}
C10_DEVICE arg_t thread_reduce(const scalar_t* data) const {
if (config.vectorize) {
template <int output_vec_size>
C10_DEVICE at::detail::Array<arg_t, output_vec_size> thread_reduce(const scalar_t* data) const {
if (config.vectorize_input) {
assert(output_vec_size == 1);
// reduce at the header of input_slice where memory is not aligned,
// so that thread_reduce will have an aligned memory to work on.
return vectorized_thread_reduce_impl(data);
return {input_vectorized_thread_reduce_impl(data)};
} else {
index_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t);
bool is_contiguous = (input_calc.dims == 1 && element_stride == 1);
if (is_contiguous) {
return thread_reduce_impl(data, [](index_t idx) { return idx; });
return thread_reduce_impl<output_vec_size>(data, [](index_t idx) { return idx; });
} else if (input_calc.dims == 1) {
return thread_reduce_impl(data, [&](index_t idx) { return idx * element_stride; });
return thread_reduce_impl<output_vec_size>(data, [&](index_t idx) { return idx * element_stride; });
} else {
return thread_reduce_impl(data, [&](index_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); });
return thread_reduce_impl<output_vec_size>(data, [&](index_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); });
}
}
}
C10_DEVICE arg_t vectorized_thread_reduce_impl(const scalar_t* data) const {
C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const {
index_t end = config.num_inputs;
// Handle the head of input slice where data is not aligned
arg_t value = ident;
constexpr int align_bytes = alignof(at::native::memory::aligned_vector<scalar_t, vec_size>);
constexpr int align_bytes = alignof(at::native::memory::aligned_vector<scalar_t, input_vec_size>);
constexpr int align_elements = align_bytes / sizeof(scalar_t);
int shift = ((uint64_t)data) % align_bytes / sizeof(scalar_t);
if (shift > 0) {
@@ -432,41 +458,41 @@ struct ReduceOp {
}
// Do the vectorized reduction
using load_t = at::native::memory::aligned_vector<scalar_t, vec_size>;
using load_t = at::native::memory::aligned_vector<scalar_t, input_vec_size>;
index_t idx = config.input_idx();
const index_t stride = config.step_input;
// Multiple accumulators to remove dependency between unrolled loops.
#ifndef __HIP_PLATFORM_HCC__
arg_t value_list[vec_size];
arg_t value_list[input_vec_size];
#else
ROCm_Bug<arg_t, vec_size> value_list;
ROCm_Bug<arg_t, input_vec_size> value_list;
#endif
value_list[0] = value;
#pragma unroll
for (int i = 1; i < vec_size; i++) {
for (int i = 1; i < input_vec_size; i++) {
value_list[i] = ident;
}
#ifndef __HIP_PLATFORM_HCC__
scalar_t values[vec_size];
scalar_t values[input_vec_size];
#else
ROCm_Bug<scalar_t, vec_size> values;
ROCm_Bug<scalar_t, input_vec_size> values;
#endif
load_t *values_vector = reinterpret_cast<load_t*>(&values[0]);
while (idx * vec_size + vec_size - 1 < end) {
while (idx * input_vec_size + input_vec_size - 1 < end) {
*values_vector = reinterpret_cast<const load_t*>(data)[idx];
#pragma unroll
for (index_t i = 0; i < vec_size; i++) {
value_list[i] = ops.reduce(value_list[i], values[i], shift + idx * vec_size + i);
for (index_t i = 0; i < input_vec_size; i++) {
value_list[i] = ops.reduce(value_list[i], values[i], shift + idx * input_vec_size + i);
}
idx += stride;
}
// tail
index_t tail_start = end - end % vec_size;
index_t tail_start = end - end % input_vec_size;
if (config.should_reduce_tail()) {
int idx = tail_start + threadIdx.x;
if (idx < end) {
@@ -476,45 +502,54 @@ struct ReduceOp {
// combine accumulators
#pragma unroll
for (int i = 1; i < vec_size; i++) {
for (int i = 1; i < input_vec_size; i++) {
value_list[0] = ops.combine(value_list[0], value_list[i]);
}
return value_list[0];
}
template<typename offset_calc_t>
C10_DEVICE arg_t thread_reduce_impl(const scalar_t* data, offset_calc_t calc) const {
template <int output_vec_size, typename offset_calc_t>
C10_DEVICE at::detail::Array<arg_t, output_vec_size> thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const {
index_t idx = config.input_idx();
const index_t end = config.num_inputs;
const index_t stride = config.step_input;
// Multiple accumulators to remove dependency between unrolled loops.
using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
using load_t = at::native::memory::aligned_vector<scalar_t, output_vec_size>;
const load_t* data = reinterpret_cast<const load_t*>(data_);
// Multiple accumulators to remove dependency between unrolled loops.
#ifndef __HIP_PLATFORM_HCC__
arg_t value_list[vt0];
arg_vec_t value_list[vt0];
#else
ROCm_Bug<arg_t, vt0> value_list;
ROCm_Bug<arg_vec_t, vt0> value_list;
#endif
#pragma unroll
for (int i = 0; i < vt0; i++) {
value_list[i] = ident;
#pragma unroll
for (int j = 0; j < output_vec_size; j++) {
value_list[i][j] = ident;
}
}
#ifndef __HIP_PLATFORM_HCC__
scalar_t values[vt0];
load_t values[vt0];
#else
ROCm_Bug<scalar_t, vt0> values;
ROCm_Bug<load_t, vt0> values;
#endif
while (idx + (vt0 - 1) * stride < end) {
#pragma unroll
for (index_t i = 0; i < vt0; i++) {
values[i] = data[calc(idx + i * stride)];
values[i] = data[calc(idx + i * stride) / output_vec_size];
}
#pragma unroll
for (index_t i = 0; i < vt0; i++) {
value_list[i] = ops.reduce(value_list[i], values[i], idx + i * stride);
#pragma unroll
for (index_t j = 0; j < output_vec_size; j++) {
value_list[i][j] = ops.reduce(value_list[i][j], values[i].val[j], idx + i * stride);
}
}
idx += stride * vt0;
}
@@ -526,7 +561,7 @@ struct ReduceOp {
if (idx >= end) {
break;
}
values[i] = data[calc(idx)];
values[i] = data[calc(idx) / output_vec_size];
idx += stride;
}
idx = idx_;
@@ -535,29 +570,40 @@ struct ReduceOp {
if (idx >= end) {
break;
}
value_list[i] = ops.reduce(value_list[i], values[i], idx);
#pragma unroll
for (index_t j = 0; j < output_vec_size; j++) {
value_list[i][j] = ops.reduce(value_list[i][j], values[i].val[j], idx);
}
idx += stride;
}
// combine accumulators
#pragma unroll
for (int i = 1; i < vt0; i++) {
value_list[0] = ops.combine(value_list[0], value_list[i]);
#pragma unroll
for (index_t j = 0; j < output_vec_size; j++) {
value_list[0][j] = ops.combine(value_list[0][j], value_list[i][j]);
}
}
return value_list[0];
}
C10_DEVICE arg_t block_x_reduce(arg_t value, char* shared_memory) const {
template <int output_vec_size>
C10_DEVICE at::detail::Array<arg_t, output_vec_size> block_x_reduce(at::detail::Array<arg_t, output_vec_size> value, char* shared_memory) const {
using args_vec_t = at::detail::Array<arg_t, output_vec_size>;
int dim_x = blockDim.x;
arg_t* shared = (arg_t*)shared_memory;
args_vec_t* shared = (args_vec_t*)shared_memory;
if (dim_x > warpSize) {
int address_base = threadIdx.x + threadIdx.y*blockDim.x;
shared[address_base] = value;
for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) {
__syncthreads();
if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
arg_t other = shared[address_base + offset];
value = ops.combine(value, other);
args_vec_t other = shared[address_base + offset];
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
value[i] = ops.combine(value[i], other[i]);
}
shared[address_base] = value;
}
}
@@ -567,20 +613,28 @@ struct ReduceOp {
__syncthreads();
for (int offset = 1; offset < dim_x; offset <<= 1) {
arg_t other = ops.warp_shfl_down(value, offset);
value = ops.combine(value, other);
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
arg_t other = ops.warp_shfl_down(value[i], offset);
value[i] = ops.combine(value[i], other);
}
}
return value;
}
C10_DEVICE arg_t block_y_reduce(arg_t value, char* shared_memory) const {
arg_t* shared = (arg_t*)shared_memory;
template <int output_vec_size>
C10_DEVICE at::detail::Array<arg_t, output_vec_size> block_y_reduce(at::detail::Array<arg_t, output_vec_size> value, char* shared_memory) const {
using args_vec_t = at::detail::Array<arg_t, output_vec_size>;
args_vec_t* shared = (args_vec_t*)shared_memory;
shared[config.shared_memory_offset(0)] = value;
for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
__syncthreads();
if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
arg_t other = shared[config.shared_memory_offset(offset)];
value = ops.combine(value, other);
args_vec_t other = shared[config.shared_memory_offset(offset)];
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
value[i] = ops.combine(value[i], other[i]);
}
shared[config.shared_memory_offset(0)] = value;
}
}
@@ -601,12 +655,18 @@ struct ReduceOp {
return is_last_block_done_shared;
}
template <bool can_acc>
C10_DEVICE arg_t accumulate_in_output(
out_scalar_t* out, arg_t value,
template <int output_vec_size, bool can_acc>
C10_DEVICE at::detail::Array<arg_t, output_vec_size> accumulate_in_output(
at::detail::Array<out_scalar_t*, output_vec_size> out,
at::detail::Array<arg_t, output_vec_size> value,
typename std::enable_if<can_acc>::type* = nullptr
) const {
return ops.combine(*out, value);
at::detail::Array<arg_t, output_vec_size> ret;
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
ret[i] = ops.combine(*(out[i]), value[i]);
}
return ret;
}
template <bool can_acc>
@@ -621,9 +681,10 @@ struct ReduceOp {
// This function should never be called --
// it's the version of `accumulate_in_output`
// when accumulation in the output is not possible.
template <bool can_acc>
C10_DEVICE arg_t accumulate_in_output(
out_scalar_t*, arg_t,
template <int output_vec_size, bool can_acc>
C10_DEVICE at::detail::Array<arg_t, output_vec_size> accumulate_in_output(
at::detail::Array<out_scalar_t*, output_vec_size>,
at::detail::Array<arg_t, output_vec_size>,
typename std::enable_if<!can_acc>::type* = nullptr
) const {
assert(false); // can't use AT_ASSERT in Cuda.
@@ -664,18 +725,33 @@ struct ReduceOp {
}
}
C10_DEVICE void set_results_to_output(arg_t value, index_t base_offset) const {
template <int output_vec_size>
C10_DEVICE void set_results_to_output(at::detail::Array<arg_t, output_vec_size> value, at::detail::Array<index_t, output_vec_size> base_offset) const {
assert(final_output);
set_results(ops.project(value), base_offset);
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
set_results(ops.project(value[i]), base_offset[i]);
}
}
C10_DEVICE arg_t global_reduce(arg_t value, arg_t* acc, char* shared_memory) const {
arg_t* reduce_buffer = (arg_t*)cta_buf;
index_t output_idx = config.output_idx();
auto base_offsets = output_calc.get(output_idx);
auto out = (out_scalar_t*)((char*)dst[0] + base_offsets[0]);
template <int output_vec_size>
C10_DEVICE at::detail::Array<arg_t, output_vec_size> global_reduce(at::detail::Array<arg_t, output_vec_size> value, at::detail::Array<arg_t, output_vec_size> *acc, char* shared_memory) const {
using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
using out_ptr_vec_t = at::detail::Array<out_scalar_t*, output_vec_size>;
using offset_vec_t = at::detail::Array<index_t, output_vec_size>;
bool should_store = config.should_store(config.output_idx());
arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf;
index_t output_idx = config.output_idx<output_vec_size>();
offset_vec_t base_offsets;
out_ptr_vec_t out;
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
base_offsets[i] = output_calc.get(output_idx + i)[0];
out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
}
bool should_store = config.should_store(output_idx);
if (should_store) {
index_t offset = config.staging_memory_offset(blockIdx.y);
reduce_buffer[offset] = value;
@@ -692,42 +768,57 @@ struct ReduceOp {
index_t step = blockDim.x * blockDim.y;
for (; input_offset < config.ctas_per_output; input_offset += step) {
index_t idx = config.staging_memory_offset(input_offset);
arg_t next = reduce_buffer[idx];
value = ops.combine(value, next);
arg_vec_t next = reduce_buffer[idx];
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
value[i] = ops.combine(value[i], next[i]);
}
}
} else {
index_t input_offset = threadIdx.y;
index_t step = blockDim.y;
for (; input_offset < config.ctas_per_output; input_offset += step) {
index_t idx = config.staging_memory_offset(input_offset);
arg_t next = reduce_buffer[idx];
value = ops.combine(value, next);
arg_vec_t next = reduce_buffer[idx];
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
value[i] = ops.combine(value[i], next[i]);
}
}
}
value = block_y_reduce(value, shared_memory);
if (config.should_block_x_reduce()) {
value = block_x_reduce(value, shared_memory);
value = block_x_reduce<output_vec_size>(value, shared_memory);
}
if (should_store) {
if (accumulate) {
value = ops.translate_idx(value, base_idx);
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
value[i] = ops.translate_idx(value[i], base_idx);
}
}
if (acc == nullptr) {
if (accumulate) {
value = accumulate_in_output<can_accumulate_in_output>(out, value);
value = accumulate_in_output<output_vec_size, can_accumulate_in_output>(out, value);
}
if (final_output) {
set_results_to_output(value, base_offsets[0]);
set_results_to_output<output_vec_size>(value, base_offsets);
} else {
*out = get_accumulated_output<can_accumulate_in_output>(out, value);
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
*(out[i]) = get_accumulated_output<can_accumulate_in_output>(out[i], value[i]);
}
}
} else {
if (accumulate) {
value = ops.combine(*acc, value);
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
value[i] = ops.combine((*acc)[i], value[i]);
}
}
if (final_output) {
set_results_to_output(value, base_offsets[0]);
set_results_to_output<output_vec_size>(value, base_offsets);
} else {
*acc = value;
}
@@ -739,14 +830,25 @@ struct ReduceOp {
}
};
template<int nt, typename R>
template<int max_threads, typename R>
static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction) {
dim3 block = config.block();
dim3 grid = config.grid();
auto stream = at::cuda::getCurrentCUDAStream();
int shared_memory = config.shared_memory_size();
reduce_kernel<nt, R><<<grid, block, shared_memory, stream>>>(reduction);
switch(config.output_vec_size) {
case 4:
reduce_kernel<max_threads / 4, 4, R><<<grid, block, shared_memory, stream>>>(reduction);
break;
case 2:
reduce_kernel<max_threads / 2, 2, R><<<grid, block, shared_memory, stream>>>(reduction);
break;
default:
reduce_kernel<max_threads / 1, 1, R><<<grid, block, shared_memory, stream>>>(reduction);
}
AT_CUDA_CHECK(cudaGetLastError());
}
@@ -786,6 +888,31 @@ class AccumulationBuffer {
at::DataPtr buffer_;
};
template <typename scalar_t>
int get_output_vec_size(TensorIterator &iter) {
int vec_size = 4;
auto update_vec_size = [&vec_size](uint64_t n) {
while(n % vec_size != 0) {
vec_size /= 2;
}
};
uint64_t base_address = reinterpret_cast<uint64_t>(iter.data_ptr(iter.noutputs())) / sizeof(scalar_t);
update_vec_size(base_address);
const int output_index = iter.num_reduce_dims();
update_vec_size(iter.shape()[output_index]);
int j = 0;
for(auto i : iter.strides(iter.noutputs())) {
if (j != output_index) {
update_vec_size(i / sizeof(scalar_t));
}
j++;
}
return vec_size;
}
template <typename scalar_t, typename out_scalar_t, int vt0=4, typename ops_t, typename ident_t=double>
inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t ident=0,
AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) {
@@ -852,46 +979,70 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t id
int64_t dim0;
int64_t dim1;
// Adjust block size to map block width to fastest changing dimension of input
// tensor. This grants the best possible memory accessing pattern, given that
// for non-contiguous tensor with space in between, we cannot have perfect
// memory coalescing.
int64_t fastest_moving_stride;
bool reduction_on_fastest_striding_dimension =
(iter.num_reduce_dims() == iter.ndim()) ||
(iter.strides(/*arg=*/input_index)[0] <
iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()]);
// Notice that dim0 & dim1 does NOT guarantee any launch configuration here!
// dim0 & dim1 are more like the upper bound of the block dimension. The
// actual launch config and reduction scheme is determined by setting values
// to `config.input_mult` and `config.output_mult`.
// We try to max out dim1 so that we have enough threads per CTA to deliver
// performance for larger problem size.
if (reduction_on_fastest_striding_dimension) {
// Map block.x to the fastest reducing dimension. It implies:
// 1. block_x_reduce is required.
// 2. block.y now max out to num_outputs.
dim0 = iter.shape()[0];
dim1 = num_outputs;
fastest_moving_stride = iter.strides(/*arg=*/input_index)[0];
bool reduction_on_fastest_striding_dimension;
if (iter.ndim() > 0) {
// Adjust block size to map block width to fastest changing dimension of input
// tensor. This grants the best possible memory accessing pattern, given that
// for non-contiguous tensor with space in between, we cannot have perfect
// memory coalescing.
reduction_on_fastest_striding_dimension =
(iter.num_reduce_dims() == iter.ndim()) ||
(iter.strides(/*arg=*/input_index)[0] <
iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()]);
// Notice that dim0 & dim1 does NOT guarantee any launch configuration here!
// dim0 & dim1 are more like the upper bound of the block dimension. The
// actual launch config and reduction scheme is determined by setting values
// to `config.input_mult` and `config.output_mult`.
// We try to max out dim1 so that we have enough threads per CTA to deliver
// performance for larger problem size.
if (reduction_on_fastest_striding_dimension) {
// Map block.x to the fastest reducing dimension. It implies:
// 1. block_x_reduce is required.
// 2. block.y now max out to num_outputs.
dim0 = iter.shape()[0];
dim1 = num_outputs;
fastest_moving_stride = iter.strides(/*arg=*/input_index)[0];
} else {
// Map block.x to the fastest non reducing dimension. It implies:
// 1. block_x_reduce is turned off.
// 2. block.y now max out to inputs_per_output.
dim0 = iter.shape()[iter.num_reduce_dims()];
dim1 = inputs_per_output;
fastest_moving_stride = iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()];
}
} else {
// Map block.x to the fastest non reducing dimension. It implies:
// 1. block_x_reduce is turned off.
// 2. block.y now max out to inputs_per_output.
dim0 = iter.shape()[iter.num_reduce_dims()];
dim1 = inputs_per_output;
fastest_moving_stride = iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()];
reduction_on_fastest_striding_dimension = true;
fastest_moving_stride = sizeof(scalar_t);
dim0 = 1;
dim1 = 1;
}
// if the fastest moving dimension is contiguous and large enough, we do vectorized
// load for better performance. Note that if vt0 < ReduceConfig::vec_size, then this
// means the register pressure could be high, in such case, we should avoid vectorization.
// We only vectorize 1D reduction
if (fastest_moving_stride == sizeof(scalar_t) && dim0 > 128 && vt0 >= ReduceConfig::vec_size) {
// TODO: vectorization on output is not supported yet
if (reduction_on_fastest_striding_dimension && iter.num_reduce_dims() == 1) {
config.vectorize = true;
// We do vectorization to gain better memory access, there are two cases which we call
// "vectorize along input" and "vectorize along output". Note that the "input/output"
// here does not mean we are vectorizing load/store instructions. We always only vectorize
// load instructions.
//
// Case 1: "vectorize along input"
// This case happens when we are reducing along the fastest moving dimension. In such case, threads
// with the same threadIdx.y work on the same reduction cooperatively and will produce results
// for the same output. In such case, values in each loaded vector always correspond to the same output.
//
// Case 2: "vectorize along output"
// This case happens when the fastest moving dimension is not the dimension of reduction. In such case,
// threads with different threadIdx.x are independent and will produce results for different outputs.
// In such case, values in each loaded vector always correspond to different outputs.
if (fastest_moving_stride == sizeof(scalar_t)) {
if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= ReduceConfig::input_vec_size) {
// Case 1: "vectorize along input"
// Note that if vt0 < ReduceConfig::input_vec_size, then this means the register pressure could be high, in such case,
// we should avoid vectorization.
config.vectorize_input = true;
} else if (!reduction_on_fastest_striding_dimension) {
// Case 2: "vectorize along output"
config.output_vec_size = get_output_vec_size<scalar_t>(iter);
dim0 /= config.output_vec_size;
}
}
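Taken together, the selection logic in this file can be restated in Python roughly as follows (a sketch; `get_output_vec_size` mirrors the new helper above, halving the candidate size until it divides the input base address, the extent of the fastest-moving output dimension, and the remaining input strides, all measured in elements):
```python
def get_output_vec_size(input_base_addr_elems, fastest_output_dim_size, other_input_strides_elems):
    vec_size = 4
    def shrink(n):
        nonlocal vec_size
        while n % vec_size != 0:
            vec_size //= 2
    shrink(input_base_addr_elems)
    shrink(fastest_output_dim_size)
    for s in other_input_strides_elems:
        shrink(s)
    return vec_size

def choose_vectorization(fastest_stride_is_one_elem, reducing_fastest_dim,
                         dim0, num_reduce_dims, vt0, input_vec_size=4):
    # mirrors the decision at the end of gpu_reduce_kernel's config setup
    if not fastest_stride_is_one_elem:
        return "none"
    if reducing_fastest_dim and dim0 > 128 and num_reduce_dims == 1 and vt0 >= input_vec_size:
        return "vectorize along input"      # config.vectorize_input = True
    if not reducing_fastest_dim:
        return "vectorize along output"     # config.output_vec_size = get_output_vec_size(...)
    return "none"
```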

View File

@@ -9432,7 +9432,7 @@ class TestTorchDeviceType(TestCase):
@onlyCUDA
@dtypes(torch.half, torch.float, torch.double)
def test_reduction_vectorized_corner(self, device, dtype):
def test_reduction_vectorize_along_input_corner(self, device, dtype):
# 1D case: sum
size = 1024 * 1024 * 64 + 3
shift = 1
@@ -9528,6 +9528,29 @@ class TestTorchDeviceType(TestCase):
self.assertEqual(xs1[j].item(), size[1] - i)
self.assertEqual(xs2[j].item(), size[1] - i)
@onlyCUDA
@dtypes(torch.half, torch.float, torch.double)
def test_reduction_vectorize_along_output(self, device, dtype):
def run_test(input_):
M, N = input_.shape
input_.zero_()
for i in range(min(M, N)):
input_[i][i] = 1
output1 = input_.argmax(dim=0)
output2 = input_.sum(dim=0)
for i in range(min(M, N)):
self.assertEqual(output1[i], i)
self.assertEqual(output2[i], 1)
# vec 4
run_test(torch.zeros(64, 64, dtype=dtype, device=device))
# vec 2
run_test(torch.zeros(64 * 64 + 2, dtype=dtype, device=device)[2:].view(64, 64))
run_test(torch.zeros(64, 62, dtype=dtype, device=device))
run_test(torch.zeros(64, 2, dtype=dtype, device=device))
# vec 1
run_test(torch.zeros(64 * 64 + 1, dtype=dtype, device=device)[1:].view(64, 64))
run_test(torch.zeros(64, 61, dtype=dtype, device=device))
run_test(torch.zeros(64, 1, dtype=dtype, device=device))
@slowTest
def test_argminmax_large_axis(self, device):