mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[pytorch][vulkan] cumsum dim <= 1 (#117580)
Summary: Following the implementation of Softmax, striding over the texture differently based on the desired dimension. Softmax performs a similar operation as cumsum (generally called "scan") iterating over all items in a dimension, but cumsum only needs to iterate once to collate the sum, compared to softmax which needs to iterate multiple times to collect the max and denominator for the final calculation. Similar to the softmax implmentation there's likely opportunities to optimize, but this gets all dims < 4 functional first. Test Plan: `LD_LIBRARY_PATH=third-party/swiftshader/lib/linux-x64/ buck2 run fbcode/mode/dev-nosan //xplat/caffe2:pt_vulkan_api_test_bin -- --gtest_filter="*cumsum*"`: ``` Running main() from third-party/googletest/1.14.0/googletest/googletest/src/gtest_main.cc Note: Google Test filter = *cumsum* [==========] Running 4 tests from 1 test suite. [----------] Global test environment set-up. [----------] 4 tests from VulkanAPITest [ RUN ] VulkanAPITest.cumsum_1d [ OK ] VulkanAPITest.cumsum_1d (93 ms) [ RUN ] VulkanAPITest.cumsum_2d [ OK ] VulkanAPITest.cumsum_2d (74 ms) [ RUN ] VulkanAPITest.cumsum_3d [ OK ] VulkanAPITest.cumsum_3d (105 ms) [ RUN ] VulkanAPITest.cumsum_4d [ OK ] VulkanAPITest.cumsum_4d (73 ms) [----------] 4 tests from VulkanAPITest (346 ms total) [----------] Global test environment tear-down [==========] 4 tests from 1 test suite ran. (346 ms total) [ PASSED ] 4 tests. ``` Differential Revision: D52814000 Pull Request resolved: https://github.com/pytorch/pytorch/pull/117580 Approved by: https://github.com/yipjustin
This commit is contained in:
parent
dd6c0f6844
commit
29f899ef87
|
|
@ -1,27 +0,0 @@
|
|||
#version 450 core
|
||||
#define PRECISION ${PRECISION}
|
||||
#define FORMAT ${FORMAT}
|
||||
|
||||
layout(std430) buffer;
|
||||
|
||||
/* Qualifiers: layout - storage - precision - memory */
|
||||
|
||||
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
|
||||
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
|
||||
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
|
||||
int axis;
|
||||
} uBlock;
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
|
||||
|
||||
void main() {
|
||||
const ivec3 pos = ivec3(gl_GlobalInvocationID);
|
||||
|
||||
ivec3 spos = pos;
|
||||
vec4 sum = vec4(0);
|
||||
for(spos[uBlock.axis] = 0; spos!=pos; ++spos[uBlock.axis]) {
|
||||
sum += texelFetch(uInput, spos, 0);
|
||||
}
|
||||
sum += texelFetch(uInput, spos, 0);
|
||||
imageStore(uOutput, pos, sum);
|
||||
}
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
#version 450 core
|
||||
#define PRECISION ${PRECISION}
|
||||
#define FORMAT ${FORMAT}
|
||||
|
||||
layout(std430) buffer;
|
||||
|
||||
/* Qualifiers: layout - storage - precision - memory */
|
||||
|
||||
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
|
||||
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
|
||||
|
||||
/*
|
||||
* Params Buffer
|
||||
* input_shader_extents is the dimensions of the Vulkan 3D texture XYZ
|
||||
* with a zero pad at W.
|
||||
* input_tensor_dims is the dimensions of the NCHW PyTorch Tensor.
|
||||
* input_dim_stride is the stride to include elements along the scan
|
||||
* dimension calculation. early_exit is the global workgroup position-based
|
||||
* condition for unnecessary invocations to exit.
|
||||
*/
|
||||
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
|
||||
ivec4 input_shader_extents;
|
||||
ivec4 input_tensor_dims;
|
||||
ivec4 input_dim_stride;
|
||||
ivec4 early_exit;
|
||||
} uBlock;
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
|
||||
|
||||
/*
|
||||
* This shader can compute cumsum along batch, height, and width.
|
||||
*/
|
||||
void main() {
|
||||
const ivec3 pos = ivec3(gl_GlobalInvocationID);
|
||||
if (!all(lessThan(pos, uBlock.early_exit.xyz))) {
|
||||
return;
|
||||
}
|
||||
ivec3 cand_pos = pos;
|
||||
vec4 sum = vec4(0, 0, 0, 0);
|
||||
while (all(lessThan(cand_pos, uBlock.input_shader_extents.xyz))) {
|
||||
sum += texelFetch(uInput, cand_pos, 0);
|
||||
imageStore(uOutput, cand_pos, sum);
|
||||
cand_pos += uBlock.input_dim_stride.xyz;
|
||||
}
|
||||
}
|
||||
68
aten/src/ATen/native/vulkan/glsl/cumsum_channel.glsl
Normal file
68
aten/src/ATen/native/vulkan/glsl/cumsum_channel.glsl
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
#version 450 core
|
||||
#define PRECISION ${PRECISION}
|
||||
#define FORMAT ${FORMAT}
|
||||
|
||||
layout(std430) buffer;
|
||||
|
||||
/* Qualifiers: layout - storage - precision - memory */
|
||||
|
||||
layout(set = 0, binding = 0, FORMAT) uniform PRECISION restrict writeonly image3D uOutput;
|
||||
layout(set = 0, binding = 1) uniform PRECISION sampler3D uInput;
|
||||
|
||||
/*
|
||||
* Params Buffer
|
||||
* input_shader_extents is the dimensions of the Vulkan 3D texture XYZ
|
||||
* with a zero pad at W.
|
||||
* input_tensor_dims is the dimensions of the NCHW PyTorch Tensor.
|
||||
* input_dim_stride is the stride to include elements along the scan
|
||||
* dimension calculation. early_exit is the global workgroup position-based
|
||||
* condition for unnecessary invocations to exit.
|
||||
*/
|
||||
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
|
||||
ivec4 input_shader_extents;
|
||||
ivec4 input_tensor_dims;
|
||||
ivec4 input_dim_stride;
|
||||
ivec4 early_exit;
|
||||
} uBlock;
|
||||
|
||||
/*
|
||||
* Local Work Group Size
|
||||
*/
|
||||
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
|
||||
|
||||
void main() {
|
||||
const ivec3 pos = ivec3(gl_GlobalInvocationID);
|
||||
// how "wide" a batch is in terms of z. Only have one invocation per batch,
|
||||
// as one batch width has elements from every channel in-memory.
|
||||
if (!all(lessThan(pos, uBlock.early_exit.xyz))) {
|
||||
return;
|
||||
}
|
||||
const int b_stride = int(ceil(uBlock.input_tensor_dims.y / 4.0));
|
||||
const ivec3 src_pos = ivec3(pos.x, pos.y, pos.z * b_stride);
|
||||
// tail case, padded zeros in memory if tensor's channel dim % 4 != 0
|
||||
uint tail_case_size = uBlock.input_tensor_dims.y % 4;
|
||||
if (tail_case_size == 0) {
|
||||
tail_case_size = 4;
|
||||
}
|
||||
|
||||
float sum = 0;
|
||||
for (int c = 0; c < b_stride - 1; c++) {
|
||||
const ivec3 dst_pos = ivec3(src_pos.x, src_pos.y, src_pos.z + c);
|
||||
const vec4 c_texel =
|
||||
texelFetch(uInput, ivec3(src_pos.x, src_pos.y, src_pos.z + c), 0);
|
||||
vec4 out_texel = vec4(0, 0, 0, 0);
|
||||
for (int t = 0; t < 4; t++) {
|
||||
sum += c_texel[t];
|
||||
out_texel[t] = sum;
|
||||
}
|
||||
imageStore(uOutput, dst_pos, out_texel);
|
||||
}
|
||||
ivec3 dst_pos = ivec3(src_pos.x, src_pos.y, src_pos.z + b_stride - 1);
|
||||
vec4 c_texel = texelFetch(uInput, dst_pos, 0);
|
||||
vec4 out_texel = vec4(0, 0, 0, 0);
|
||||
for (int t = 0; t < tail_case_size; t++) {
|
||||
sum += c_texel[t];
|
||||
out_texel[t] = sum;
|
||||
}
|
||||
imageStore(uOutput, dst_pos, out_texel);
|
||||
}
|
||||
|
|
@ -10,28 +10,103 @@ namespace {
|
|||
|
||||
using namespace api::utils;
|
||||
|
||||
void set_cumsum_kernel_params(
|
||||
const long long num_dims,
|
||||
const long long dim,
|
||||
const IntArrayRef v_input_sizes,
|
||||
api::ShaderInfo& shader_descriptor,
|
||||
api::utils::ivec4& input_shader_extents,
|
||||
api::utils::ivec4& early_exit,
|
||||
api::utils::ivec4& input_dim_stride,
|
||||
api::utils::ivec4& input_tensor_dims) {
|
||||
if (num_dims == 1) {
|
||||
early_exit.data[0u] = 1;
|
||||
input_dim_stride.data[0u] = 1;
|
||||
shader_descriptor = VK_KERNEL(cumsum_batch_height_width);
|
||||
} else if (num_dims == 2) {
|
||||
// for height, width dim case, we can reuse a single shader
|
||||
// with vectorized parameters
|
||||
shader_descriptor = VK_KERNEL(cumsum_batch_height_width);
|
||||
if (dim == 0) {
|
||||
early_exit.data[1u] = 1;
|
||||
input_dim_stride.data[1u] = 1;
|
||||
} else { // dim == 1
|
||||
early_exit.data[0u] = 1;
|
||||
input_dim_stride.data[0u] = 1;
|
||||
}
|
||||
} else if (num_dims == 3) {
|
||||
for (uint32_t i = 0; i < num_dims; i++) {
|
||||
input_tensor_dims.data[i + 1] = safe_downcast<int32_t>(v_input_sizes[i]);
|
||||
}
|
||||
if (dim == 0) {
|
||||
early_exit.data[2u] = 1;
|
||||
input_dim_stride.data[2u] = 1;
|
||||
shader_descriptor = VK_KERNEL(cumsum_channel);
|
||||
} else if (dim == 1) {
|
||||
// for height, width dim case, we can reuse a single shader
|
||||
// with vectorized parameters
|
||||
early_exit.data[1u] = 1;
|
||||
input_dim_stride.data[1u] = 1;
|
||||
shader_descriptor = VK_KERNEL(cumsum_batch_height_width);
|
||||
} else { // dim == 2
|
||||
early_exit.data[0u] = 1;
|
||||
input_dim_stride.data[0u] = 1;
|
||||
shader_descriptor = VK_KERNEL(cumsum_batch_height_width);
|
||||
}
|
||||
} else {
|
||||
// assume num_dims is 4
|
||||
for (uint32_t i = 0; i < num_dims; i++) {
|
||||
input_tensor_dims.data[i] = safe_downcast<int32_t>(v_input_sizes[i]);
|
||||
}
|
||||
if (dim == 1) {
|
||||
// for 4-rank Tensor, scan along channel dim case, the memory layout
|
||||
// forces a different shader algorithm than other dims
|
||||
input_shader_extents.data[2u] =
|
||||
v_input_sizes[Layout::Activation4D::batch];
|
||||
shader_descriptor = VK_KERNEL(cumsum_channel);
|
||||
} else {
|
||||
// for batch, height, width dim case, we can reuse a single shader
|
||||
// with vectorized parameters
|
||||
if (dim == 0) {
|
||||
early_exit.data[2u] = safe_downcast<int32_t>(
|
||||
std::ceil(v_input_sizes[Layout::Activation4D::channels] / 4.0));
|
||||
input_dim_stride.data[2u] = safe_downcast<int32_t>(
|
||||
std::ceil(v_input_sizes[Layout::Activation4D::channels] / 4.0));
|
||||
} else if (dim == 2) {
|
||||
early_exit.data[1u] = 1;
|
||||
input_dim_stride.data[1u] = 1;
|
||||
} else { // dim == 3
|
||||
early_exit.data[0u] = 1;
|
||||
input_dim_stride.data[0u] = 1;
|
||||
}
|
||||
shader_descriptor = VK_KERNEL(cumsum_batch_height_width);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Tensor cumsum(
|
||||
const at::Tensor& input_arg,
|
||||
const int64_t dim,
|
||||
const int64_t dim_arg,
|
||||
const c10::optional<ScalarType> dtype) {
|
||||
TORCH_CHECK(
|
||||
input_arg.dim() <= 4, "Vulkan cumsum expects input dimension <= 4!");
|
||||
input_arg.dim() >= 1 && input_arg.dim() <= 4,
|
||||
"Vulkan cumsum expects 1 <= input dimension <= 4, Tensor input dimensions ",
|
||||
input_arg.dim());
|
||||
|
||||
TORCH_CHECK(
|
||||
get_dim<Dim4D::Batch>(input_arg) == 1,
|
||||
"Vulkan cumsum expects batch size <= 1!");
|
||||
dim_arg < input_arg.dim(),
|
||||
"cumsum dim input was ",
|
||||
dim_arg,
|
||||
" out of range for Tensor input with dimensions ",
|
||||
input_arg.dim());
|
||||
|
||||
TORCH_CHECK(dim < 4, "Vulkan cumsum expects dim < 4!");
|
||||
|
||||
if (dim <= 1) {
|
||||
// TODO: dim<0, dim=0, dim=1(z axis)
|
||||
TORCH_CHECK(false, "Not implemented!");
|
||||
}
|
||||
int64_t dim = utils::normalize(dim_arg, input_arg.dim());
|
||||
|
||||
api::Context* const context = api::context();
|
||||
|
||||
const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
|
||||
const vTensor& v_input = convert(input);
|
||||
const IntArrayRef v_input_sizes = v_input.sizes();
|
||||
|
||||
vTensor v_output{
|
||||
context,
|
||||
|
|
@ -39,24 +114,66 @@ Tensor cumsum(
|
|||
v_input.dtype(),
|
||||
};
|
||||
|
||||
const struct Block final {
|
||||
int32_t axis;
|
||||
} block{
|
||||
(3 - safe_downcast<int32_t>(dim)),
|
||||
const api::utils::uvec3 global_workgroup_extents = v_output.extents();
|
||||
api::utils::ivec4 input_shader_extents = {
|
||||
safe_downcast<int32_t>(v_input.extents().data[0u]),
|
||||
safe_downcast<int32_t>(v_input.extents().data[1u]),
|
||||
safe_downcast<int32_t>(v_input.extents().data[2u]),
|
||||
0 // zero pad
|
||||
};
|
||||
// early_exit is the global workgroup position-based condition for
|
||||
// unnecessary invocations to exit.
|
||||
api::utils::ivec4 early_exit = {
|
||||
safe_downcast<int32_t>(v_input.extents().data[0u]),
|
||||
safe_downcast<int32_t>(v_input.extents().data[1u]),
|
||||
safe_downcast<int32_t>(v_input.extents().data[2u]),
|
||||
0 // zero pad
|
||||
};
|
||||
// for batch/height/width, they share the same shader
|
||||
// vectorized by input_dim_stride for each dimension case
|
||||
api::utils::ivec4 input_dim_stride = {
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0, // zero pad
|
||||
};
|
||||
api::utils::ivec4 input_tensor_dims = {
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
};
|
||||
api::ShaderInfo shader_descriptor;
|
||||
set_cumsum_kernel_params(
|
||||
input_arg.dim(),
|
||||
dim,
|
||||
v_input_sizes,
|
||||
shader_descriptor,
|
||||
input_shader_extents,
|
||||
early_exit,
|
||||
input_dim_stride,
|
||||
input_tensor_dims);
|
||||
|
||||
const struct Block final {
|
||||
ivec4 input_shader_extents;
|
||||
ivec4 input_tensor_dims;
|
||||
ivec4 input_dim_stride;
|
||||
ivec4 early_exit;
|
||||
} block{
|
||||
input_shader_extents, input_tensor_dims, input_dim_stride, early_exit};
|
||||
|
||||
api::UniformParamsBuffer params(context, block);
|
||||
api::PipelineBarrier pipeline_barrier{};
|
||||
|
||||
context->submit_compute_job(
|
||||
// shader descriptor
|
||||
VK_KERNEL(cumsum),
|
||||
shader_descriptor,
|
||||
// pipeline barrier
|
||||
pipeline_barrier,
|
||||
// global work group size
|
||||
v_input.extents(),
|
||||
global_workgroup_extents,
|
||||
// local work group size
|
||||
adaptive_work_group_size(v_output.extents()),
|
||||
adaptive_work_group_size(global_workgroup_extents),
|
||||
// fence handle
|
||||
VK_NULL_HANDLE,
|
||||
// shader arguments
|
||||
|
|
|
|||
|
|
@ -2170,31 +2170,39 @@ TEST_F(VulkanAPITest, copy) {
|
|||
ASSERT_TRUE(check);
|
||||
}
|
||||
|
||||
TEST_F(VulkanAPITest, cumsum) {
|
||||
c10::InferenceMode mode;
|
||||
void test_cumsum(const at::IntArrayRef input_shape, const int64_t dim) {
|
||||
const auto in_cpu = at::rand(input_shape, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
|
||||
const auto in_cpu = at::rand({1, 17, 37, 49}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
// 0 do nothing
|
||||
// 1 frame
|
||||
// not implemented
|
||||
|
||||
// 2 height
|
||||
const auto out_cpu2 = at::cumsum(in_cpu, 2);
|
||||
const auto out_vulkan2 = at::cumsum(in_cpu.vulkan(), 2);
|
||||
const auto check2 = almostEqual(out_cpu2, out_vulkan2.cpu());
|
||||
if (!check2) {
|
||||
showRtol(out_cpu2, out_vulkan2.cpu());
|
||||
const auto out_cpu = at::cumsum(in_cpu, dim);
|
||||
const auto out_vulkan = at::cumsum(in_cpu.vulkan(), dim);
|
||||
const auto check = almostEqual(out_cpu, out_vulkan.cpu());
|
||||
if (!check) {
|
||||
showRtol(out_cpu, out_vulkan.cpu());
|
||||
}
|
||||
ASSERT_TRUE(check2);
|
||||
ASSERT_TRUE(check);
|
||||
}
|
||||
|
||||
// 3 width
|
||||
const auto out_cpu3 = at::cumsum(in_cpu, 3);
|
||||
const auto out_vulkan3 = at::cumsum(in_cpu.vulkan(), 3);
|
||||
const auto check3 = almostEqual(out_cpu3, out_vulkan3.cpu());
|
||||
if (!check3) {
|
||||
showRtol(out_cpu3, out_vulkan3.cpu());
|
||||
TEST_F(VulkanAPITest, cumsum_1d) {
|
||||
test_cumsum({37}, 0);
|
||||
test_cumsum({37}, -1);
|
||||
}
|
||||
|
||||
TEST_F(VulkanAPITest, cumsum_2d) {
|
||||
for (int64_t i = -1; i <= 1; i++) {
|
||||
test_cumsum({17, 37}, i);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(VulkanAPITest, cumsum_3d) {
|
||||
for (int64_t i = -2; i <= 2; i++) {
|
||||
test_cumsum({17, 37, 49}, i);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(VulkanAPITest, cumsum_4d) {
|
||||
for (int64_t i = -3; i <= 3; i++) {
|
||||
test_cumsum({12, 17, 37, 49}, i);
|
||||
}
|
||||
ASSERT_TRUE(check3);
|
||||
}
|
||||
|
||||
void test_div(const at::IntArrayRef input_shape, const at::IntArrayRef other_shape) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user