Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Vectorize on output for reduction kernels (#37206)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37206

Benchmark on P100: https://github.com/zasdfgbnm/things/blob/master/2020Q2/reduction-benchmark-vectorize-output.ipynb

```python
import torch
print(torch.__version__)
print()

for i in range(1000):
    torch.arange(10000, device='cuda')

def benchmark(dtype, i):
    size0 = 2 ** (i // 2)
    size1 = 2 ** ((i + 1) // 2)
    a = torch.zeros(size0, size1, device='cuda', dtype=dtype)
    torch.cuda.synchronize()
    %timeit a.sum(dtype=dtype, dim=0); torch.cuda.synchronize()

for dtype in [torch.int8, torch.half, torch.float, torch.double]:
    print(dtype)
    for i in range(18, 30):
        benchmark(dtype, i)
    print()
```

Before

```
1.5.0a0+3bbb36e

torch.int8
24.5 µs ± 111 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
24.1 µs ± 216 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
26.1 µs ± 133 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
30.9 µs ± 132 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
39 µs ± 504 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
59.6 µs ± 244 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
111 µs ± 112 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
186 µs ± 300 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
397 µs ± 791 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
665 µs ± 1.06 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.45 ms ± 837 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.03 ms ± 2.79 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

torch.float16
24.2 µs ± 66.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
24.6 µs ± 255 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
27.2 µs ± 53.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
32 µs ± 91 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
48.1 µs ± 89.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
66.9 µs ± 66.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
121 µs ± 102 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
218 µs ± 384 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
431 µs ± 554 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
854 µs ± 1.23 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.75 ms ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.63 ms ± 849 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)

torch.float32
24.2 µs ± 117 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
24.4 µs ± 237 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
29.3 µs ± 34.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
40.5 µs ± 36.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
57.4 µs ± 44.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
85.5 µs ± 41.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
158 µs ± 106 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
288 µs ± 181 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
557 µs ± 904 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1e+03 µs ± 1.27 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.98 ms ± 533 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.8 ms ± 1.98 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

torch.float64
25 µs ± 54.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
26.9 µs ± 320 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
37.1 µs ± 51.8 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
54.3 µs ± 45.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
84.9 µs ± 65.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
139 µs ± 68.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
275 µs ± 235 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
504 µs ± 702 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
987 µs ± 613 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.84 ms ± 1.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.64 ms ± 2.44 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.19 ms ± 1.19 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

After

```
1.5.0a0+3bbb36e

torch.int8
29.8 µs ± 213 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
30.7 µs ± 1.41 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
33.4 µs ± 4.48 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
32.5 µs ± 110 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
40.6 µs ± 94.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
53.7 µs ± 66.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
68 µs ± 69.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
98.2 µs ± 88.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
158 µs ± 116 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
283 µs ± 120 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
522 µs ± 563 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
967 µs ± 495 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

torch.float16
29.4 µs ± 68.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
29.2 µs ± 45.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
30.8 µs ± 41 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
35.3 µs ± 20.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
50.1 µs ± 133 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
70.4 µs ± 67.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
101 µs ± 325 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
157 µs ± 179 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
275 µs ± 791 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
486 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
936 µs ± 211 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.85 ms ± 124 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

torch.float32
29.9 µs ± 36.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
29.5 µs ± 108 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
33 µs ± 93.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
46 µs ± 37.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
64 µs ± 73.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
99.4 µs ± 82.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
157 µs ± 74.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
265 µs ± 68.8 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
490 µs ± 319 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
960 µs ± 669 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.84 ms ± 632 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.6 ms ± 1.63 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

torch.float64
33.1 µs ± 74.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
36.7 µs ± 86.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
46.7 µs ± 39.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
61.6 µs ± 196 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
100 µs ± 23.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
158 µs ± 202 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
270 µs ± 332 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
491 µs ± 445 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
939 µs ± 339 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1.88 ms ± 1.09 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
3.65 ms ± 5.18 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.3 ms ± 7.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

Test Plan: Imported from OSS

Differential Revision: D21233255

Pulled By: ngimel

fbshipit-source-id: d468fddbb228c0c13146dfc6344c470513f9e374
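To make the intent of the diff easier to follow, here is a minimal CPU-side sketch of the "vectorize along output" access pattern this PR exploits. It is an illustration only, not the CUDA kernel; `kOutputVecSize` and the loop structure below are assumptions for the sketch.

```cpp
// Minimal CPU sketch of "vectorize along output" (illustrative, not the CUDA
// kernel): reducing dim 0 of a row-major [M, N] tensor means adjacent input
// elements belong to different outputs, so one worker keeps kOutputVecSize
// accumulators and feeds them from a single contiguous load per row.
#include <array>
#include <cstdio>
#include <vector>

constexpr int kOutputVecSize = 4;  // assumed vector width for the sketch

int main() {
  const int M = 8, N = 8;                 // illustrative sizes, N % 4 == 0
  std::vector<float> input(M * N, 1.0f);  // row-major [M, N]
  std::vector<float> output(N, 0.0f);     // a.sum(dim=0) has N outputs

  for (int col = 0; col < N; col += kOutputVecSize) {
    std::array<float, kOutputVecSize> acc{};  // one accumulator per output
    for (int row = 0; row < M; ++row) {
      // one contiguous load of kOutputVecSize inputs, each feeding a different output
      for (int v = 0; v < kOutputVecSize; ++v) {
        acc[v] += input[row * N + col + v];
      }
    }
    for (int v = 0; v < kOutputVecSize; ++v) {
      output[col + v] = acc[v];
    }
  }
  std::printf("output[0] = %.1f (expected %d)\n", output[0], M);
}
```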
This commit is contained in:
parent a92231b70e
commit b10c53e9b8
@@ -8,7 +8,7 @@
 namespace at { namespace detail {
 
 template <typename T, int size>
-struct alignas(16) Array {
+struct Array {
 T data[size];
 
 C10_HOST_DEVICE T operator[](int i) const {
@@ -39,6 +39,8 @@ std::ostream& operator<<(std::ostream& out, const ReduceConfig& config) {
 out << config.output_mult[i];
 }
 out << "], ";
+out << "vectorize_input=" << config.vectorize_input << ", ";
+out << "output_vec_size=" << config.output_vec_size << ", ";
 out << "block_width=" << config.block_width << ", ";
 out << "block_height=" << config.block_height << ", ";
 out << "num_threads=" << config.num_threads << ", ";
@@ -72,7 +72,7 @@ struct ReduceConfig {
 static constexpr int CTA = 2;
 
 static constexpr int MAX_NUM_THREADS = 512;
-static constexpr int vec_size = 4;
+static constexpr int input_vec_size = 4;
 
 ReduceConfig(int element_size_bytes, int num_outputs, int num_inputs)
 : element_size_bytes(element_size_bytes)
@@ -92,14 +92,16 @@ struct ReduceConfig {
 int block_height;
 int num_threads;
 
-bool vectorize = false;
+bool vectorize_input = false;
+int output_vec_size = 1;
 
 void set_block_dimension(int64_t dim0, int64_t dim1) {
-int dim0_pow2 = dim0 < MAX_NUM_THREADS ? static_cast<int>(last_pow2(dim0)) : MAX_NUM_THREADS;
-int dim1_pow2 = dim1 < MAX_NUM_THREADS ? static_cast<int>(last_pow2(dim1)) : MAX_NUM_THREADS;
+const int max_num_threads = MAX_NUM_THREADS / output_vec_size;
+int dim0_pow2 = dim0 < max_num_threads ? static_cast<int>(last_pow2(dim0)) : max_num_threads;
+int dim1_pow2 = dim1 < max_num_threads ? static_cast<int>(last_pow2(dim1)) : max_num_threads;
 block_width = std::min(dim0_pow2, int(at::cuda::warp_size()));
-block_height = std::min(dim1_pow2, int(MAX_NUM_THREADS / block_width));
-block_width = std::min(dim0_pow2, int(MAX_NUM_THREADS / block_height));
+block_height = std::min(dim1_pow2, int(max_num_threads / block_width));
+block_width = std::min(dim0_pow2, int(max_num_threads / block_height));
 num_threads = block_width * block_height;
 }
 
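A small illustrative check, not from the PR, of why `set_block_dimension` now divides the thread budget: each thread handles `output_vec_size` outputs and needs proportionally more shared memory, so the per-block thread cap shrinks accordingly.

```cpp
// Illustrative only (not from the PR): the per-block thread budget shrinks as
// output_vec_size grows, since each thread now owns that many outputs and the
// corresponding share of shared memory.
#include <cstdio>
#include <initializer_list>

int main() {
  const int MAX_NUM_THREADS = 512;  // same constant as ReduceConfig
  for (int output_vec_size : {1, 2, 4}) {
    std::printf("output_vec_size=%d -> max threads per block = %d\n",
                output_vec_size, MAX_NUM_THREADS / output_vec_size);
  }
}
```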
@@ -120,7 +122,7 @@ struct ReduceConfig {
 }
 
 dim3 grid() const {
-return dim3(div_up(num_outputs, step_output), ctas_per_output);
+return dim3(div_up(num_outputs / output_vec_size, step_output), ctas_per_output);
 }
 
 C10_HOST_DEVICE bool should_block_x_reduce() const {
@@ -155,13 +157,14 @@ struct ReduceConfig {
 cta2 * input_mult[CTA]);
 }
 
+template <int output_vec_size>
 C10_HOST_DEVICE int output_idx() const {
 int lane = threadIdx.x;
 int warp = threadIdx.y;
 int cta1 = blockIdx.x;
 return (lane * output_mult[BLOCK_X] +
 warp * output_mult[BLOCK_Y] +
-cta1 * step_output);
+cta1 * step_output) * output_vec_size;
 }
 
 C10_DEVICE int shared_memory_offset(int offset) const {
@@ -182,7 +185,7 @@ struct ReduceConfig {
 block_width <= at::cuda::warp_size())) {
 return 0;
 }
-return element_size_bytes * num_threads;
+return element_size_bytes * num_threads * output_vec_size;
 }
 
 int64_t global_memory_size() const {
@@ -191,7 +194,7 @@ struct ReduceConfig {
 }
 auto size = (int64_t)element_size_bytes * num_outputs * ctas_per_output;
 if (!should_block_x_reduce()) {
-size *= block().x;
+size *= block().x * output_vec_size;
 }
 return size;
 }
@@ -210,10 +213,10 @@ struct ReduceConfig {
 
 std::ostream& operator<<(std::ostream& out, const ReduceConfig& config);
 
-template<int nt, typename R>
+template<int nt, int output_vec_size, typename R>
 C10_LAUNCH_BOUNDS_2(nt, 4)
 __global__ void reduce_kernel(R reduction) {
-reduction.run();
+reduction.template run<output_vec_size>();
 }
 
 template <typename index_t>
@@ -285,7 +288,7 @@ struct ReduceOp {
 
 static constexpr float acc_buffer_multiplier = (float)sizeof(arg_t) / sizeof(out_scalar_t);
 
-static constexpr int vec_size = ReduceConfig::vec_size;
+static constexpr int input_vec_size = ReduceConfig::input_vec_size;
 
 ops_t ops;
 arg_t ident;
@ -336,57 +339,78 @@ struct ReduceOp {
|
|||
}
|
||||
}
|
||||
|
||||
template <int output_vec_size>
|
||||
C10_DEVICE void run() const {
|
||||
extern __shared__ char shared_memory[];
|
||||
index_t output_idx = config.output_idx();
|
||||
index_t output_idx = config.output_idx<output_vec_size>();
|
||||
index_t input_idx = config.input_idx();
|
||||
auto base_offsets = output_calc.get(output_idx);
|
||||
auto base_offsets1 = output_calc.get(output_idx)[1];
|
||||
|
||||
arg_t value = ident;
|
||||
using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
|
||||
arg_vec_t value;
|
||||
|
||||
if (output_idx < config.num_outputs && input_idx < config.num_inputs) {
|
||||
const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets[1]);
|
||||
value = thread_reduce(input_slice);
|
||||
const scalar_t* input_slice = (const scalar_t*)((const char*)src + base_offsets1);
|
||||
value = thread_reduce<output_vec_size>(input_slice);
|
||||
}
|
||||
|
||||
if (config.should_block_y_reduce()) {
|
||||
value = block_y_reduce(value, shared_memory);
|
||||
value = block_y_reduce<output_vec_size>(value, shared_memory);
|
||||
}
|
||||
if (config.should_block_x_reduce()) {
|
||||
value = block_x_reduce(value, shared_memory);
|
||||
value = block_x_reduce<output_vec_size>(value, shared_memory);
|
||||
}
|
||||
|
||||
auto out = (out_scalar_t*)((char*)dst[0] + base_offsets[0]);
|
||||
arg_t* acc = nullptr;
|
||||
using out_ptr_vec_t = at::detail::Array<out_scalar_t*, output_vec_size>;
|
||||
using offset_vec_t = at::detail::Array<index_t, output_vec_size>;
|
||||
offset_vec_t base_offsets;
|
||||
out_ptr_vec_t out;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
base_offsets[i] = output_calc.get(output_idx + i)[0];
|
||||
out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
|
||||
}
|
||||
|
||||
arg_vec_t* acc = nullptr;
|
||||
if (acc_buf != nullptr) {
|
||||
size_t numerator = sizeof(arg_t);
|
||||
size_t denominator = sizeof(out_scalar_t);
|
||||
reduce_fraction(numerator, denominator);
|
||||
acc = (arg_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator));
|
||||
acc = (arg_vec_t*)((char*)acc_buf + (base_offsets[0] * numerator / denominator));
|
||||
}
|
||||
|
||||
if (config.should_global_reduce()) {
|
||||
value = global_reduce(value, acc, shared_memory);
|
||||
value = global_reduce<output_vec_size>(value, acc, shared_memory);
|
||||
} else if (config.should_store(output_idx)) {
|
||||
if (accumulate) {
|
||||
value = ops.translate_idx(value, base_idx);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
value[i] = ops.translate_idx(value[i], base_idx);
|
||||
}
|
||||
}
|
||||
|
||||
if (acc == nullptr) {
|
||||
if (accumulate) {
|
||||
value = accumulate_in_output<can_accumulate_in_output>(out, value);
|
||||
value = accumulate_in_output<output_vec_size, can_accumulate_in_output>(out, value);
|
||||
}
|
||||
if (final_output) {
|
||||
set_results_to_output(value, base_offsets[0]);
|
||||
set_results_to_output<output_vec_size>(value, base_offsets);
|
||||
} else {
|
||||
*out = get_accumulated_output<can_accumulate_in_output>(out, value);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
*(out[i]) = get_accumulated_output<can_accumulate_in_output>(out[i], value[i]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (accumulate) {
|
||||
value = ops.combine(*acc, value);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
value[i] = ops.combine((*acc)[i], value[i]);
|
||||
}
|
||||
}
|
||||
if (final_output) {
|
||||
set_results_to_output(value, base_offsets[0]);
|
||||
set_results_to_output<output_vec_size>(value, base_offsets);
|
||||
} else {
|
||||
*acc = value;
|
||||
}
|
||||
|
|
@ -394,30 +418,32 @@ struct ReduceOp {
|
|||
}
|
||||
}
|
||||
|
||||
C10_DEVICE arg_t thread_reduce(const scalar_t* data) const {
|
||||
if (config.vectorize) {
|
||||
template <int output_vec_size>
|
||||
C10_DEVICE at::detail::Array<arg_t, output_vec_size> thread_reduce(const scalar_t* data) const {
|
||||
if (config.vectorize_input) {
|
||||
assert(output_vec_size == 1);
|
||||
// reduce at the header of input_slice where memory is not aligned,
|
||||
// so that thread_reduce will have an aligned memory to work on.
|
||||
return vectorized_thread_reduce_impl(data);
|
||||
return {input_vectorized_thread_reduce_impl(data)};
|
||||
} else {
|
||||
index_t element_stride = input_calc.strides_[0][0] / sizeof(scalar_t);
|
||||
bool is_contiguous = (input_calc.dims == 1 && element_stride == 1);
|
||||
if (is_contiguous) {
|
||||
return thread_reduce_impl(data, [](index_t idx) { return idx; });
|
||||
return thread_reduce_impl<output_vec_size>(data, [](index_t idx) { return idx; });
|
||||
} else if (input_calc.dims == 1) {
|
||||
return thread_reduce_impl(data, [&](index_t idx) { return idx * element_stride; });
|
||||
return thread_reduce_impl<output_vec_size>(data, [&](index_t idx) { return idx * element_stride; });
|
||||
} else {
|
||||
return thread_reduce_impl(data, [&](index_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); });
|
||||
return thread_reduce_impl<output_vec_size>(data, [&](index_t idx) { return input_calc.get(idx)[0] / sizeof(scalar_t); });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
C10_DEVICE arg_t vectorized_thread_reduce_impl(const scalar_t* data) const {
|
||||
C10_DEVICE arg_t input_vectorized_thread_reduce_impl(const scalar_t* data) const {
|
||||
index_t end = config.num_inputs;
|
||||
|
||||
// Handle the head of input slice where data is not aligned
|
||||
arg_t value = ident;
|
||||
constexpr int align_bytes = alignof(at::native::memory::aligned_vector<scalar_t, vec_size>);
|
||||
constexpr int align_bytes = alignof(at::native::memory::aligned_vector<scalar_t, input_vec_size>);
|
||||
constexpr int align_elements = align_bytes / sizeof(scalar_t);
|
||||
int shift = ((uint64_t)data) % align_bytes / sizeof(scalar_t);
|
||||
if (shift > 0) {
|
||||
|
|
@ -432,41 +458,41 @@ struct ReduceOp {
|
|||
}
|
||||
|
||||
// Do the vectorized reduction
|
||||
using load_t = at::native::memory::aligned_vector<scalar_t, vec_size>;
|
||||
using load_t = at::native::memory::aligned_vector<scalar_t, input_vec_size>;
|
||||
|
||||
index_t idx = config.input_idx();
|
||||
const index_t stride = config.step_input;
|
||||
|
||||
// Multiple accumulators to remove dependency between unrolled loops.
|
||||
#ifndef __HIP_PLATFORM_HCC__
|
||||
arg_t value_list[vec_size];
|
||||
arg_t value_list[input_vec_size];
|
||||
#else
|
||||
ROCm_Bug<arg_t, vec_size> value_list;
|
||||
ROCm_Bug<arg_t, input_vec_size> value_list;
|
||||
#endif
|
||||
value_list[0] = value;
|
||||
#pragma unroll
|
||||
for (int i = 1; i < vec_size; i++) {
|
||||
for (int i = 1; i < input_vec_size; i++) {
|
||||
value_list[i] = ident;
|
||||
}
|
||||
|
||||
#ifndef __HIP_PLATFORM_HCC__
|
||||
scalar_t values[vec_size];
|
||||
scalar_t values[input_vec_size];
|
||||
#else
|
||||
ROCm_Bug<scalar_t, vec_size> values;
|
||||
ROCm_Bug<scalar_t, input_vec_size> values;
|
||||
#endif
|
||||
load_t *values_vector = reinterpret_cast<load_t*>(&values[0]);
|
||||
|
||||
while (idx * vec_size + vec_size - 1 < end) {
|
||||
while (idx * input_vec_size + input_vec_size - 1 < end) {
|
||||
*values_vector = reinterpret_cast<const load_t*>(data)[idx];
|
||||
#pragma unroll
|
||||
for (index_t i = 0; i < vec_size; i++) {
|
||||
value_list[i] = ops.reduce(value_list[i], values[i], shift + idx * vec_size + i);
|
||||
for (index_t i = 0; i < input_vec_size; i++) {
|
||||
value_list[i] = ops.reduce(value_list[i], values[i], shift + idx * input_vec_size + i);
|
||||
}
|
||||
idx += stride;
|
||||
}
|
||||
|
||||
// tail
|
||||
index_t tail_start = end - end % vec_size;
|
||||
index_t tail_start = end - end % input_vec_size;
|
||||
if (config.should_reduce_tail()) {
|
||||
int idx = tail_start + threadIdx.x;
|
||||
if (idx < end) {
|
||||
|
|
@ -476,45 +502,54 @@ struct ReduceOp {
|
|||
|
||||
// combine accumulators
|
||||
#pragma unroll
|
||||
for (int i = 1; i < vec_size; i++) {
|
||||
for (int i = 1; i < input_vec_size; i++) {
|
||||
value_list[0] = ops.combine(value_list[0], value_list[i]);
|
||||
}
|
||||
return value_list[0];
|
||||
}
|
||||
|
||||
template<typename offset_calc_t>
|
||||
C10_DEVICE arg_t thread_reduce_impl(const scalar_t* data, offset_calc_t calc) const {
|
||||
template <int output_vec_size, typename offset_calc_t>
|
||||
C10_DEVICE at::detail::Array<arg_t, output_vec_size> thread_reduce_impl(const scalar_t* data_, offset_calc_t calc) const {
|
||||
index_t idx = config.input_idx();
|
||||
const index_t end = config.num_inputs;
|
||||
const index_t stride = config.step_input;
|
||||
|
||||
// Multiple accumulators to remove dependency between unrolled loops.
|
||||
using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
|
||||
using load_t = at::native::memory::aligned_vector<scalar_t, output_vec_size>;
|
||||
const load_t* data = reinterpret_cast<const load_t*>(data_);
|
||||
|
||||
// Multiple accumulators to remove dependency between unrolled loops.
|
||||
#ifndef __HIP_PLATFORM_HCC__
|
||||
arg_t value_list[vt0];
|
||||
arg_vec_t value_list[vt0];
|
||||
#else
|
||||
ROCm_Bug<arg_t, vt0> value_list;
|
||||
ROCm_Bug<arg_vec_t, vt0> value_list;
|
||||
#endif
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < vt0; i++) {
|
||||
value_list[i] = ident;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < output_vec_size; j++) {
|
||||
value_list[i][j] = ident;
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef __HIP_PLATFORM_HCC__
|
||||
scalar_t values[vt0];
|
||||
load_t values[vt0];
|
||||
#else
|
||||
ROCm_Bug<scalar_t, vt0> values;
|
||||
ROCm_Bug<load_t, vt0> values;
|
||||
#endif
|
||||
|
||||
while (idx + (vt0 - 1) * stride < end) {
|
||||
#pragma unroll
|
||||
for (index_t i = 0; i < vt0; i++) {
|
||||
values[i] = data[calc(idx + i * stride)];
|
||||
values[i] = data[calc(idx + i * stride) / output_vec_size];
|
||||
}
|
||||
#pragma unroll
|
||||
for (index_t i = 0; i < vt0; i++) {
|
||||
value_list[i] = ops.reduce(value_list[i], values[i], idx + i * stride);
|
||||
#pragma unroll
|
||||
for (index_t j = 0; j < output_vec_size; j++) {
|
||||
value_list[i][j] = ops.reduce(value_list[i][j], values[i].val[j], idx + i * stride);
|
||||
}
|
||||
}
|
||||
idx += stride * vt0;
|
||||
}
|
||||
|
|
@@ -526,7 +561,7 @@ struct ReduceOp {
 if (idx >= end) {
 break;
 }
-values[i] = data[calc(idx)];
+values[i] = data[calc(idx) / output_vec_size];
 idx += stride;
 }
 idx = idx_;
@ -535,29 +570,40 @@ struct ReduceOp {
|
|||
if (idx >= end) {
|
||||
break;
|
||||
}
|
||||
value_list[i] = ops.reduce(value_list[i], values[i], idx);
|
||||
#pragma unroll
|
||||
for (index_t j = 0; j < output_vec_size; j++) {
|
||||
value_list[i][j] = ops.reduce(value_list[i][j], values[i].val[j], idx);
|
||||
}
|
||||
idx += stride;
|
||||
}
|
||||
|
||||
// combine accumulators
|
||||
#pragma unroll
|
||||
for (int i = 1; i < vt0; i++) {
|
||||
value_list[0] = ops.combine(value_list[0], value_list[i]);
|
||||
#pragma unroll
|
||||
for (index_t j = 0; j < output_vec_size; j++) {
|
||||
value_list[0][j] = ops.combine(value_list[0][j], value_list[i][j]);
|
||||
}
|
||||
}
|
||||
return value_list[0];
|
||||
}
|
||||
|
||||
C10_DEVICE arg_t block_x_reduce(arg_t value, char* shared_memory) const {
|
||||
template <int output_vec_size>
|
||||
C10_DEVICE at::detail::Array<arg_t, output_vec_size> block_x_reduce(at::detail::Array<arg_t, output_vec_size> value, char* shared_memory) const {
|
||||
using args_vec_t = at::detail::Array<arg_t, output_vec_size>;
|
||||
int dim_x = blockDim.x;
|
||||
arg_t* shared = (arg_t*)shared_memory;
|
||||
args_vec_t* shared = (args_vec_t*)shared_memory;
|
||||
if (dim_x > warpSize) {
|
||||
int address_base = threadIdx.x + threadIdx.y*blockDim.x;
|
||||
shared[address_base] = value;
|
||||
for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) {
|
||||
__syncthreads();
|
||||
if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) {
|
||||
arg_t other = shared[address_base + offset];
|
||||
value = ops.combine(value, other);
|
||||
args_vec_t other = shared[address_base + offset];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
value[i] = ops.combine(value[i], other[i]);
|
||||
}
|
||||
shared[address_base] = value;
|
||||
}
|
||||
}
|
||||
|
|
@ -567,20 +613,28 @@ struct ReduceOp {
|
|||
__syncthreads();
|
||||
|
||||
for (int offset = 1; offset < dim_x; offset <<= 1) {
|
||||
arg_t other = ops.warp_shfl_down(value, offset);
|
||||
value = ops.combine(value, other);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
arg_t other = ops.warp_shfl_down(value[i], offset);
|
||||
value[i] = ops.combine(value[i], other);
|
||||
}
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
C10_DEVICE arg_t block_y_reduce(arg_t value, char* shared_memory) const {
|
||||
arg_t* shared = (arg_t*)shared_memory;
|
||||
template <int output_vec_size>
|
||||
C10_DEVICE at::detail::Array<arg_t, output_vec_size> block_y_reduce(at::detail::Array<arg_t, output_vec_size> value, char* shared_memory) const {
|
||||
using args_vec_t = at::detail::Array<arg_t, output_vec_size>;
|
||||
args_vec_t* shared = (args_vec_t*)shared_memory;
|
||||
shared[config.shared_memory_offset(0)] = value;
|
||||
for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
|
||||
__syncthreads();
|
||||
if (threadIdx.y < offset && threadIdx.y + offset < blockDim.y) {
|
||||
arg_t other = shared[config.shared_memory_offset(offset)];
|
||||
value = ops.combine(value, other);
|
||||
args_vec_t other = shared[config.shared_memory_offset(offset)];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
value[i] = ops.combine(value[i], other[i]);
|
||||
}
|
||||
shared[config.shared_memory_offset(0)] = value;
|
||||
}
|
||||
}
|
||||
|
|
@@ -601,12 +655,18 @@ struct ReduceOp {
 return is_last_block_done_shared;
 }
 
-template <bool can_acc>
-C10_DEVICE arg_t accumulate_in_output(
-out_scalar_t* out, arg_t value,
+template <int output_vec_size, bool can_acc>
+C10_DEVICE at::detail::Array<arg_t, output_vec_size> accumulate_in_output(
+at::detail::Array<out_scalar_t*, output_vec_size> out,
+at::detail::Array<arg_t, output_vec_size> value,
 typename std::enable_if<can_acc>::type* = nullptr
 ) const {
-return ops.combine(*out, value);
+at::detail::Array<arg_t, output_vec_size> ret;
+#pragma unroll
+for (int i = 0; i < output_vec_size; i++) {
+ret[i] = ops.combine(*(out[i]), value[i]);
+}
+return ret;
 }
 
 template <bool can_acc>
@@ -621,9 +681,10 @@ struct ReduceOp {
 // This function should never be called --
 // it's the version of `accumulate_in_output`
 // when accumulation in the output is not possible.
-template <bool can_acc>
-C10_DEVICE arg_t accumulate_in_output(
-out_scalar_t*, arg_t,
+template <int output_vec_size, bool can_acc>
+C10_DEVICE at::detail::Array<arg_t, output_vec_size> accumulate_in_output(
+at::detail::Array<out_scalar_t*, output_vec_size>,
+at::detail::Array<arg_t, output_vec_size>,
 typename std::enable_if<!can_acc>::type* = nullptr
 ) const {
 assert(false); // can't use AT_ASSERT in Cuda.
@ -664,18 +725,33 @@ struct ReduceOp {
|
|||
}
|
||||
}
|
||||
|
||||
C10_DEVICE void set_results_to_output(arg_t value, index_t base_offset) const {
|
||||
template <int output_vec_size>
|
||||
C10_DEVICE void set_results_to_output(at::detail::Array<arg_t, output_vec_size> value, at::detail::Array<index_t, output_vec_size> base_offset) const {
|
||||
assert(final_output);
|
||||
set_results(ops.project(value), base_offset);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
set_results(ops.project(value[i]), base_offset[i]);
|
||||
}
|
||||
}
|
||||
|
||||
C10_DEVICE arg_t global_reduce(arg_t value, arg_t* acc, char* shared_memory) const {
|
||||
arg_t* reduce_buffer = (arg_t*)cta_buf;
|
||||
index_t output_idx = config.output_idx();
|
||||
auto base_offsets = output_calc.get(output_idx);
|
||||
auto out = (out_scalar_t*)((char*)dst[0] + base_offsets[0]);
|
||||
template <int output_vec_size>
|
||||
C10_DEVICE at::detail::Array<arg_t, output_vec_size> global_reduce(at::detail::Array<arg_t, output_vec_size> value, at::detail::Array<arg_t, output_vec_size> *acc, char* shared_memory) const {
|
||||
using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
|
||||
using out_ptr_vec_t = at::detail::Array<out_scalar_t*, output_vec_size>;
|
||||
using offset_vec_t = at::detail::Array<index_t, output_vec_size>;
|
||||
|
||||
bool should_store = config.should_store(config.output_idx());
|
||||
arg_vec_t* reduce_buffer = (arg_vec_t*)cta_buf;
|
||||
index_t output_idx = config.output_idx<output_vec_size>();
|
||||
offset_vec_t base_offsets;
|
||||
out_ptr_vec_t out;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
base_offsets[i] = output_calc.get(output_idx + i)[0];
|
||||
out[i] = (out_scalar_t*)((char*)dst[0] + base_offsets[i]);
|
||||
}
|
||||
|
||||
bool should_store = config.should_store(output_idx);
|
||||
if (should_store) {
|
||||
index_t offset = config.staging_memory_offset(blockIdx.y);
|
||||
reduce_buffer[offset] = value;
|
||||
|
|
@ -692,42 +768,57 @@ struct ReduceOp {
|
|||
index_t step = blockDim.x * blockDim.y;
|
||||
for (; input_offset < config.ctas_per_output; input_offset += step) {
|
||||
index_t idx = config.staging_memory_offset(input_offset);
|
||||
arg_t next = reduce_buffer[idx];
|
||||
value = ops.combine(value, next);
|
||||
arg_vec_t next = reduce_buffer[idx];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
value[i] = ops.combine(value[i], next[i]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
index_t input_offset = threadIdx.y;
|
||||
index_t step = blockDim.y;
|
||||
for (; input_offset < config.ctas_per_output; input_offset += step) {
|
||||
index_t idx = config.staging_memory_offset(input_offset);
|
||||
arg_t next = reduce_buffer[idx];
|
||||
value = ops.combine(value, next);
|
||||
arg_vec_t next = reduce_buffer[idx];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
value[i] = ops.combine(value[i], next[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
value = block_y_reduce(value, shared_memory);
|
||||
if (config.should_block_x_reduce()) {
|
||||
value = block_x_reduce(value, shared_memory);
|
||||
value = block_x_reduce<output_vec_size>(value, shared_memory);
|
||||
}
|
||||
if (should_store) {
|
||||
if (accumulate) {
|
||||
value = ops.translate_idx(value, base_idx);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
value[i] = ops.translate_idx(value[i], base_idx);
|
||||
}
|
||||
}
|
||||
|
||||
if (acc == nullptr) {
|
||||
if (accumulate) {
|
||||
value = accumulate_in_output<can_accumulate_in_output>(out, value);
|
||||
value = accumulate_in_output<output_vec_size, can_accumulate_in_output>(out, value);
|
||||
}
|
||||
if (final_output) {
|
||||
set_results_to_output(value, base_offsets[0]);
|
||||
set_results_to_output<output_vec_size>(value, base_offsets);
|
||||
} else {
|
||||
*out = get_accumulated_output<can_accumulate_in_output>(out, value);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
*(out[i]) = get_accumulated_output<can_accumulate_in_output>(out[i], value[i]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (accumulate) {
|
||||
value = ops.combine(*acc, value);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
value[i] = ops.combine((*acc)[i], value[i]);
|
||||
}
|
||||
}
|
||||
if (final_output) {
|
||||
set_results_to_output(value, base_offsets[0]);
|
||||
set_results_to_output<output_vec_size>(value, base_offsets);
|
||||
} else {
|
||||
*acc = value;
|
||||
}
|
||||
|
|
@@ -739,14 +830,25 @@ struct ReduceOp {
 }
 };
 
-template<int nt, typename R>
+template<int max_threads, typename R>
 static void launch_reduce_kernel(const ReduceConfig& config, const R& reduction) {
 dim3 block = config.block();
 dim3 grid = config.grid();
 
 auto stream = at::cuda::getCurrentCUDAStream();
 int shared_memory = config.shared_memory_size();
-reduce_kernel<nt, R><<<grid, block, shared_memory, stream>>>(reduction);
+
+switch(config.output_vec_size) {
+case 4:
+reduce_kernel<max_threads / 4, 4, R><<<grid, block, shared_memory, stream>>>(reduction);
+break;
+case 2:
+reduce_kernel<max_threads / 2, 2, R><<<grid, block, shared_memory, stream>>>(reduction);
+break;
+default:
+reduce_kernel<max_threads / 1, 1, R><<<grid, block, shared_memory, stream>>>(reduction);
+}
+
 AT_CUDA_CHECK(cudaGetLastError());
 }
 
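The launcher above turns the runtime `config.output_vec_size` into a compile-time template argument. A generic, hedged sketch of that runtime-to-compile-time dispatch pattern (names are illustrative, not PyTorch APIs):

```cpp
// Hedged sketch of the dispatch pattern used by launch_reduce_kernel: a
// runtime value selects one of a fixed set of template instantiations.
// run_kernel is a stand-in, not a PyTorch function.
#include <cstdio>

template <int kVecSize>
void run_kernel() {
  std::printf("instantiated for output vector size %d\n", kVecSize);
}

void launch(int output_vec_size) {
  switch (output_vec_size) {
    case 4: run_kernel<4>(); break;
    case 2: run_kernel<2>(); break;
    default: run_kernel<1>(); break;
  }
}

int main() { launch(2); }
```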
@@ -786,6 +888,31 @@ class AccumulationBuffer {
 at::DataPtr buffer_;
 };
 
+template <typename scalar_t>
+int get_output_vec_size(TensorIterator &iter) {
+int vec_size = 4;
+auto update_vec_size = [&vec_size](uint64_t n) {
+while(n % vec_size != 0) {
+vec_size /= 2;
+}
+};
+
+uint64_t base_address = reinterpret_cast<uint64_t>(iter.data_ptr(iter.noutputs())) / sizeof(scalar_t);
+update_vec_size(base_address);
+
+const int output_index = iter.num_reduce_dims();
+update_vec_size(iter.shape()[output_index]);
+
+int j = 0;
+for(auto i : iter.strides(iter.noutputs())) {
+if (j != output_index) {
+update_vec_size(i / sizeof(scalar_t));
+}
+j++;
+}
+return vec_size;
+}
+
 template <typename scalar_t, typename out_scalar_t, int vt0=4, typename ops_t, typename ident_t=double>
 inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t ident=0,
 AccumulationBuffer* acc_buf_ptr=nullptr, int64_t base_idx=0) {
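A simplified, standalone sketch of the divisibility rule used by `get_output_vec_size` (the real helper also folds in the output base address and the non-output strides via `TensorIterator`; the example extents below come from the new test cases):

```cpp
// Simplified, standalone version of the divisibility rule above: halve the
// candidate vector size until it divides the quantity being checked. The real
// helper applies this to the output base address, the output extent, and every
// non-output stride obtained from TensorIterator.
#include <cstdio>
#include <initializer_list>

int pick_vec_size(int n) {
  int vec_size = 4;
  while (n % vec_size != 0) {
    vec_size /= 2;
  }
  return vec_size;
}

int main() {
  // Output extents taken from the new test_reduction_vectorize_along_output cases.
  for (int n : {64, 62, 61, 2, 1}) {
    std::printf("%d outputs -> output_vec_size %d\n", n, pick_vec_size(n));
  }
}
```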
@ -852,46 +979,70 @@ inline void gpu_reduce_kernel(TensorIterator& iter, const ops_t& ops, ident_t id
|
|||
|
||||
int64_t dim0;
|
||||
int64_t dim1;
|
||||
|
||||
// Adjust block size to map block width to fastest changing dimension of input
|
||||
// tensor. This grants the best possible memory accessing pattern, given that
|
||||
// for non-contiguous tensor with space in between, we cannot have perfect
|
||||
// memory coalescing.
|
||||
int64_t fastest_moving_stride;
|
||||
bool reduction_on_fastest_striding_dimension =
|
||||
(iter.num_reduce_dims() == iter.ndim()) ||
|
||||
(iter.strides(/*arg=*/input_index)[0] <
|
||||
iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()]);
|
||||
// Notice that dim0 & dim1 does NOT guarantee any launch configuration here!
|
||||
// dim0 & dim1 are more like the upper bound of the block dimension. The
|
||||
// actual launch config and reduction scheme is determined by setting values
|
||||
// to `config.input_mult` and `config.output_mult`.
|
||||
// We try to max out dim1 so that we have enough threads per CTA to deliver
|
||||
// performance for larger problem size.
|
||||
if (reduction_on_fastest_striding_dimension) {
|
||||
// Map block.x to the fastest reducing dimension. It implies:
|
||||
// 1. block_x_reduce is required.
|
||||
// 2. block.y now max out to num_outputs.
|
||||
dim0 = iter.shape()[0];
|
||||
dim1 = num_outputs;
|
||||
fastest_moving_stride = iter.strides(/*arg=*/input_index)[0];
|
||||
bool reduction_on_fastest_striding_dimension;
|
||||
|
||||
if (iter.ndim() > 0) {
|
||||
// Adjust block size to map block width to fastest changing dimension of input
|
||||
// tensor. This grants the best possible memory accessing pattern, given that
|
||||
// for non-contiguous tensor with space in between, we cannot have perfect
|
||||
// memory coalescing.
|
||||
reduction_on_fastest_striding_dimension =
|
||||
(iter.num_reduce_dims() == iter.ndim()) ||
|
||||
(iter.strides(/*arg=*/input_index)[0] <
|
||||
iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()]);
|
||||
// Notice that dim0 & dim1 does NOT guarantee any launch configuration here!
|
||||
// dim0 & dim1 are more like the upper bound of the block dimension. The
|
||||
// actual launch config and reduction scheme is determined by setting values
|
||||
// to `config.input_mult` and `config.output_mult`.
|
||||
// We try to max out dim1 so that we have enough threads per CTA to deliver
|
||||
// performance for larger problem size.
|
||||
if (reduction_on_fastest_striding_dimension) {
|
||||
// Map block.x to the fastest reducing dimension. It implies:
|
||||
// 1. block_x_reduce is required.
|
||||
// 2. block.y now max out to num_outputs.
|
||||
dim0 = iter.shape()[0];
|
||||
dim1 = num_outputs;
|
||||
fastest_moving_stride = iter.strides(/*arg=*/input_index)[0];
|
||||
} else {
|
||||
// Map block.x to the fastest non reducing dimension. It implies:
|
||||
// 1. block_x_reduce is turned off.
|
||||
// 2. block.y now max out to inputs_per_output.
|
||||
dim0 = iter.shape()[iter.num_reduce_dims()];
|
||||
dim1 = inputs_per_output;
|
||||
fastest_moving_stride = iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()];
|
||||
}
|
||||
} else {
|
||||
// Map block.x to the fastest non reducing dimension. It implies:
|
||||
// 1. block_x_reduce is turned off.
|
||||
// 2. block.y now max out to inputs_per_output.
|
||||
dim0 = iter.shape()[iter.num_reduce_dims()];
|
||||
dim1 = inputs_per_output;
|
||||
fastest_moving_stride = iter.strides(/*arg=*/input_index)[iter.num_reduce_dims()];
|
||||
reduction_on_fastest_striding_dimension = true;
|
||||
fastest_moving_stride = sizeof(scalar_t);
|
||||
dim0 = 1;
|
||||
dim1 = 1;
|
||||
}
|
||||
|
||||
// if the fastest moving dimension is contiguous and large enough, we do vectorized
|
||||
// load for better performance. Note that if vt0 < ReduceConfig::vec_size, then this
|
||||
// means the register pressure could be high, in such case, we should avoid vectorization.
|
||||
// We only vectorize 1D reduction
|
||||
if (fastest_moving_stride == sizeof(scalar_t) && dim0 > 128 && vt0 >= ReduceConfig::vec_size) {
|
||||
// TODO: vectorization on output is not supported yet
|
||||
if (reduction_on_fastest_striding_dimension && iter.num_reduce_dims() == 1) {
|
||||
config.vectorize = true;
|
||||
// We do vectorization to gain better memory access, there are two cases which we call
|
||||
// "vectorize along input" and "vectorize along output". Note that the "input/output"
|
||||
// here does not mean we are vectorizing load/store instructions. We always only vectorize
|
||||
// load instructions.
|
||||
//
|
||||
// Case 1: "vectorize along input"
|
||||
// This case happens when we are reducing along fastest moving dimesion. In such case, threads
|
||||
// with the same threadIdx.y works on the same reduction cooperatively and will produce results
|
||||
// for the same ouput. In such case, values in each loaded vector always correspond to the same ouput.
|
||||
//
|
||||
// Case 2: "vectorize along output"
|
||||
// This case happens when the fastest moving dimesion is not the dimension of reduction. In such case,
|
||||
// threads with different threadIdx.x are independent and will produce results for different outputs.
|
||||
// In such case, values in each loaded vector always correspond to different outputs.
|
||||
if (fastest_moving_stride == sizeof(scalar_t)) {
|
||||
if (reduction_on_fastest_striding_dimension && dim0 > 128 && iter.num_reduce_dims() == 1 && vt0 >= ReduceConfig::input_vec_size) {
|
||||
// Case 1: "vectorize along input"
|
||||
// Note that if vt0 < ReduceConfig::vec_size, then this means the register pressure could be high, in such case,
|
||||
// we should avoid vectorization.
|
||||
config.vectorize_input = true;
|
||||
} else if (!reduction_on_fastest_striding_dimension) {
|
||||
// Case 2: "vectorize along output"
|
||||
config.output_vec_size = get_output_vec_size<scalar_t>(iter);
|
||||
dim0 /= config.output_vec_size;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@@ -9432,7 +9432,7 @@ class TestTorchDeviceType(TestCase):
 
 @onlyCUDA
 @dtypes(torch.half, torch.float, torch.double)
-def test_reduction_vectorized_corner(self, device, dtype):
+def test_reduction_vectorize_along_input_corner(self, device, dtype):
 # 1D case: sum
 size = 1024 * 1024 * 64 + 3
 shift = 1
@@ -9528,6 +9528,29 @@ class TestTorchDeviceType(TestCase):
 self.assertEqual(xs1[j].item(), size[1] - i)
 self.assertEqual(xs2[j].item(), size[1] - i)
 
+@onlyCUDA
+@dtypes(torch.half, torch.float, torch.double)
+def test_reduction_vectorize_along_output(self, device, dtype):
+def run_test(input_):
+M, N = input_.shape
+input_.zero_()
+for i in range(min(M, N)):
+input_[i][i] = 1
+output1 = input_.argmax(dim=0)
+output2 = input_.sum(dim=0)
+for i in range(min(M, N)):
+self.assertEqual(output1[i], i)
+self.assertEqual(output2[i], 1)
+# vec 4
+run_test(torch.zeros(64, 64, dtype=dtype, device=device))
+# vec 2
+run_test(torch.zeros(64 * 64 + 2, dtype=dtype, device=device)[2:].view(64, 64))
+run_test(torch.zeros(64, 62, dtype=dtype, device=device))
+run_test(torch.zeros(64, 2, dtype=dtype, device=device))
+# vec 1
+run_test(torch.zeros(64 * 64 + 1, dtype=dtype, device=device)[1:].view(64, 64))
+run_test(torch.zeros(64, 61, dtype=dtype, device=device))
+run_test(torch.zeros(64, 1, dtype=dtype, device=device))
 
 @slowTest
 def test_argminmax_large_axis(self, device):