Using std::vector or c10::SmallVector instead of CArray (#160959)

As the title states.
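In concrete terms: heap-allocated C-style arrays created via std::make_unique<T[]> (or a raw new[]) are replaced with std::vector, or with c10::SmallVector where the element count has a compile-time upper bound. A minimal sketch of the before/after pattern (names here are illustrative, not taken from the diff):

    #include <cstddef>
    #include <memory>
    #include <vector>

    void before(std::size_t n) {
      auto buf = std::make_unique<float[]>(n); // sizeless buffer; needs .get()
      float* p = buf.get();
      (void)p;
    }

    void after(std::size_t n) {
      std::vector<float> buf(n); // same contiguous, zeroed storage,
      float* p = buf.data();     // but with .data()/.size() and bounds info
      (void)p;
    }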

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160959
Approved by: https://github.com/Skylion007
FFFrog 2025-08-19 18:45:58 +08:00 committed by PyTorch MergeBot
parent 576a0e64ed
commit 0f801a510f
6 changed files with 38 additions and 37 deletions

View File

@@ -7,6 +7,7 @@
 #include <algorithm>
 #include <iterator>
 #include <numeric>
+#include <vector>
 
 #include <ATen/Dispatch.h>
 #include <ATen/Parallel.h>
@@ -647,10 +648,10 @@ _vec_softmax(
   parallel_for(
       0, outer_size * inner_size, 0, [&](int64_t begin, int64_t end) {
         int64_t idx = begin;
-        auto temp_vec_input = std::make_unique<float[]>(dim_size * vectorized_step);
-        auto temp_vec_output = std::make_unique<float[]>(dim_size * vectorized_step);
-        float* temp_vec_input_data = temp_vec_input.get();
-        float* temp_vec_output_data = temp_vec_output.get();
+        std::vector<float> temp_vec_input(dim_size * vectorized_step);
+        std::vector<float> temp_vec_output(dim_size * vectorized_step);
+        float* temp_vec_input_data = temp_vec_input.data();
+        float* temp_vec_output_data = temp_vec_output.data();
         while (idx < end) {
           int64_t outer_idx = idx / inner_size;
           int64_t inner_idx = idx % inner_size;
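One property worth noting about this hunk (my reading, not stated in the PR): std::vector<float>(n) value-initializes its elements exactly as std::make_unique<float[]>(n) does, so the scratch buffers start out zeroed either way, and the raw pointers are still hoisted out of the while loop, keeping the hot path identical. A small equivalence sketch:

    #include <cstddef>
    #include <memory>
    #include <vector>

    void scratch(std::size_t n) {
      auto a = std::make_unique<float[]>(n); // zero-initialized
      std::vector<float> b(n);               // also zero-initialized
      float* pa = a.get();                   // hoisted raw pointer, as above
      float* pb = b.data();                  // same, plus b.size() for free
      (void)pa;
      (void)pb;
    }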

View File

@@ -285,7 +285,7 @@ struct algorithm_search<cudnnConvolutionFwdAlgoPerf_t> {
         sizeof(algos) / sizeof(algos[0]) == num_algos,
         "Missing cuDNN convolution forward algorithms");
     int perf_count;
-    auto perf_results = std::make_unique<perf_t[]>(num_algos);
+    c10::SmallVector<perf_t, CUDNN_CONVOLUTION_FWD_ALGO_COUNT> perf_results;
     if (!benchmark) {
       AT_CUDNN_CHECK_WITH_SHAPES(
           cudnnGetConvolutionForwardAlgorithm_v7(
@@ -296,7 +296,7 @@ struct algorithm_search<cudnnConvolutionFwdAlgoPerf_t> {
               args.odesc.desc(),
               num_algos,
               &perf_count,
-              perf_results.get()),
+              perf_results.data()),
           args);
     } else {
       size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
@@ -314,7 +314,7 @@ struct algorithm_search<cudnnConvolutionFwdAlgoPerf_t> {
               args.output.data_ptr(),
               num_algos,
               &perf_count,
-              perf_results.get(),
+              perf_results.data(),
               ws.data,
               ws.size),
           args);
@@ -324,7 +324,7 @@ struct algorithm_search<cudnnConvolutionFwdAlgoPerf_t> {
       // memory, e.g. a few GBs.
       c10::cuda::CUDACachingAllocator::emptyCache();
     }
-    return getValidAlgorithms<perf_t>(perf_results.get(), args, perf_count);
+    return getValidAlgorithms<perf_t>(perf_results.data(), args, perf_count);
   }
 
   static void getWorkspaceSize(
@@ -369,7 +369,8 @@ struct algorithm_search<cudnnConvolutionBwdDataAlgoPerf_t> {
         sizeof(algos) / sizeof(algos[0]) == num_algos,
         "Missing cuDNN convolution backward data algorithms.");
     int perf_count;
-    auto perf_results = std::make_unique<perf_t[]>(num_algos);
+    c10::SmallVector<perf_t, CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT>
+        perf_results;
     if (!benchmark) {
       AT_CUDNN_CHECK_WITH_SHAPES(
           cudnnGetConvolutionBackwardDataAlgorithm_v7(
@@ -380,7 +381,7 @@ struct algorithm_search<cudnnConvolutionBwdDataAlgoPerf_t> {
               args.idesc.desc(),
               num_algos,
               &perf_count,
-              perf_results.get()),
+              perf_results.data()),
           args);
     } else {
       size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
@@ -398,7 +399,7 @@ struct algorithm_search<cudnnConvolutionBwdDataAlgoPerf_t> {
               args.input.data_ptr(),
               num_algos,
               &perf_count,
-              perf_results.get(),
+              perf_results.data(),
               ws.data,
               ws.size),
           args);
@@ -408,7 +409,7 @@ struct algorithm_search<cudnnConvolutionBwdDataAlgoPerf_t> {
       // memory, e.g. a few GBs.
       c10::cuda::CUDACachingAllocator::emptyCache();
     }
-    return getValidAlgorithms<perf_t>(perf_results.get(), args, perf_count);
+    return getValidAlgorithms<perf_t>(perf_results.data(), args, perf_count);
   }
 
   static void getWorkspaceSize(
@@ -456,7 +457,8 @@ struct algorithm_search<cudnnConvolutionBwdFilterAlgoPerf_t> {
     static_assert(
         sizeof(algos) / sizeof(algos[0]) == num_algos,
         "Missing cuDNN convolution backward filter algorithms.");
-    auto perf_results = std::make_unique<perf_t[]>(num_algos);
+    c10::SmallVector<perf_t, CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT>
+        perf_results;
     int perf_count;
     if (!benchmark) {
       AT_CUDNN_CHECK_WITH_SHAPES(
@@ -468,7 +470,7 @@ struct algorithm_search<cudnnConvolutionBwdFilterAlgoPerf_t> {
               args.wdesc.desc(),
               num_algos,
               &perf_count,
-              perf_results.get()),
+              perf_results.data()),
           args);
     } else {
       size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
@@ -486,7 +488,7 @@ struct algorithm_search<cudnnConvolutionBwdFilterAlgoPerf_t> {
               args.weight.data_ptr(),
               num_algos,
               &perf_count,
-              perf_results.get(),
+              perf_results.data(),
               ws.data,
               ws.size),
           args);
@@ -496,7 +498,7 @@ struct algorithm_search<cudnnConvolutionBwdFilterAlgoPerf_t> {
       // memory, e.g. a few GBs.
       c10::cuda::CUDACachingAllocator::emptyCache();
     }
-    return getValidAlgorithms<perf_t>(perf_results.get(), args, perf_count);
+    return getValidAlgorithms<perf_t>(perf_results.data(), args, perf_count);
   }
 
   static void getWorkspaceSize(
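The motivation for c10::SmallVector here (hedged: this is the standard LLVM SmallVector contract, not text from the PR): SmallVector<T, N> keeps up to N elements in an inline buffer inside the object and only heap-allocates once it grows past N. Since num_algos matches the compile-time CUDNN_*_ALGO_COUNT constants, these perf-result buffers should never allocate at all. An illustrative sketch with a made-up count:

    #include <c10/util/SmallVector.h>

    constexpr int kMaxAlgos = 8; // stand-in for a CUDNN_*_ALGO_COUNT constant

    void perf_buffer_demo() {
      c10::SmallVector<int, kMaxAlgos> results;
      results.resize(kMaxAlgos); // stays within inline storage: no allocation
      int* out = results.data(); // contiguous, usable as a C-style out-param
      (void)out;
    }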

View File

@@ -17,6 +17,7 @@
 #include <c10/util/irange.h>
 #include <cstring>
+#include <vector>
 
 namespace at::native {
@@ -53,8 +54,8 @@ static void upsample_nearest2d_out_frame(
     return;
   }
 
-  auto input_offset_arr = std::make_unique<int64_t[]>(output_width);
-  int64_t* input_offset = input_offset_arr.get();
+  std::vector<int64_t> input_offset_arr(output_width);
+  int64_t* input_offset = input_offset_arr.data();
 
   for (const auto w2 : c10::irange(output_width)) {
     const int64_t w1 = nn_compute_source_index_fn(width_scale, w2, input_width);

View File

@@ -800,7 +800,7 @@ Tensor& bmm_out_sparse_cuda(const SparseTensor& self, const Tensor& mat2, Tensor
   Tensor indices_dim1 = indices[1].to(ScalarType::Int);
   Tensor indices_dim2 = indices[2].to(ScalarType::Int);
 
-  auto mat_el_end_indices_host = std::make_unique<int64_t[]>(num_matrices);
+  std::vector<int64_t> mat_el_end_indices_host(num_matrices);
 
   {
     auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
@@ -809,14 +809,14 @@ Tensor& bmm_out_sparse_cuda(const SparseTensor& self, const Tensor& mat2, Tensor
     search_end_matrix_indices(mat_el_end_indices_device, num_matrices, indices_dim0);
     AT_CUDA_CHECK(cudaMemcpy(
-        mat_el_end_indices_host.get(),
+        mat_el_end_indices_host.data(),
         mat_el_end_indices_device,
         num_matrices*sizeof(int64_t),
         cudaMemcpyDeviceToHost
     ));
   }
 
   // Need a pointer to an array to access within a lambda
-  int64_t* mat_el_end_indices = &mat_el_end_indices_host[0];
+  int64_t* mat_el_end_indices = mat_el_end_indices_host.data();
 
   Scalar beta = 0;
   Scalar alpha = 1;
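The .data() substitution also works as a cudaMemcpy destination because std::vector guarantees contiguous storage. A hedged, simplified sketch of the device-to-host copy shape used above (error handling and the caching-allocator plumbing are omitted):

    #include <cuda_runtime.h>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Copy n int64_t values from a device buffer into host memory.
    std::vector<std::int64_t> copy_to_host(
        const std::int64_t* device_ptr, std::size_t n) {
      std::vector<std::int64_t> host(n);
      cudaMemcpy(host.data(), device_ptr, n * sizeof(std::int64_t),
                 cudaMemcpyDeviceToHost);
      return host;
    }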

View File

@@ -528,16 +528,16 @@ std::shared_ptr<::gloo::transport::Device> ProcessGroupGloo::
   // use. Note: if the hostname does not resolve to an address (e.g.
   // because of misconfigured /etc/hosts file), this will not work.
   const auto hostNameMax = sysconf(_SC_HOST_NAME_MAX);
-  auto hostname = std::make_unique<char[]>(hostNameMax);
-  auto rv = gethostname(hostname.get(), hostNameMax);
+  std::string hostname(hostNameMax, '\0');
+  auto rv = gethostname(hostname.data(), hostNameMax);
   if (rv != 0) {
     C10_THROW_ERROR(DistBackendError, c10::utils::str_error(errno));
   }
 
   // Use this machine's hostname if it resolves to an address.
-  if (doesHostnameResolveToUsableAddress(hostname.get())) {
+  if (doesHostnameResolveToUsableAddress(hostname.data())) {
     return ::c10d::GlooDeviceFactory::makeDeviceForHostname(
-        hostname.get(), lazyInit);
+        hostname.data(), lazyInit);
   }
 
   // Otherwise, use the loopback address.
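One subtlety in this hunk (my observation, not from the PR): after gethostname writes a NUL-terminated name, the std::string still reports size() == hostNameMax, which is harmless here because every consumer receives the char* from .data() and stops at the first NUL. A hedged POSIX sketch of the pattern, with an explicit trim for code that passes the std::string itself around:

    #include <cstddef>
    #include <string>
    #include <unistd.h>

    std::string host_name() {
      const long max = sysconf(_SC_HOST_NAME_MAX);
      std::string name(max > 0 ? static_cast<std::size_t>(max) : 255, '\0');
      if (gethostname(name.data(), name.size()) != 0) {
        return {}; // errno describes the failure
      }
      const auto end = name.find('\0');
      if (end != std::string::npos) {
        name.resize(end); // trim to the actual C-string length
      }
      return name;
    }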

View File

@@ -351,16 +351,14 @@ c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw(
         _storage_nbytes);
   }
 
-  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
-  std::unique_ptr<char[]> cpu_data;
+  std::string cpu_data;
 
   uint8_t* data{};
   if (storage->device_type() == at::kCPU) {
     data = static_cast<uint8_t*>(storage->mutable_data());
   } else {
-    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
-    cpu_data = std::make_unique<char[]>(nbytes);
-    data = (uint8_t*)cpu_data.get();
+    cpu_data.resize(nbytes);
+    data = (uint8_t*)cpu_data.data();
   }
 
   // fast track for bytes and little endian
@@ -370,24 +368,23 @@ c10::intrusive_ptr<c10::StorageImpl> THPStorage_readFileRaw(
     doRead(file, data, storage->nbytes());
   } else {
     int64_t buffer_size = std::min(size, (int64_t)5000);
-    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
-    std::unique_ptr<uint8_t[]> le_buffer(
-        new uint8_t[buffer_size * element_size]);
+    std::vector<uint8_t> le_buffer;
+    le_buffer.resize(buffer_size * element_size);
 
     for (int64_t i = 0; i < size; i += buffer_size) {
       size_t to_convert = std::min(size - i, buffer_size);
-      doRead(file, le_buffer.get(), element_size * to_convert);
+      doRead(file, le_buffer.data(), element_size * to_convert);
 
       // NOLINTNEXTLINE(bugprone-branch-clone)
       if (element_size == 2) {
         torch::utils::THP_decodeBuffer(
-            (int16_t*)data + i, le_buffer.get(), true, to_convert);
+            (int16_t*)data + i, le_buffer.data(), true, to_convert);
       } else if (element_size == 4) {
         torch::utils::THP_decodeBuffer(
-            (int32_t*)data + i, le_buffer.get(), true, to_convert);
+            (int32_t*)data + i, le_buffer.data(), true, to_convert);
       } else if (element_size == 8) {
         torch::utils::THP_decodeBuffer(
-            (int64_t*)data + i, le_buffer.get(), true, to_convert);
+            (int64_t*)data + i, le_buffer.data(), true, to_convert);
       }
     }
   }
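A closing stylistic note (hedged; my reading only): default-constructing le_buffer and then resizing is equivalent to constructing it with its size directly, and either way the chunked read-and-decode loop now works against a bounds-aware buffer rather than a bare new[]. A generic sketch of that chunked-decode shape, with read/decode as hypothetical stand-ins for doRead and THP_decodeBuffer:

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Process `size` elements in chunks of at most `chunk`, reusing one buffer.
    template <typename ReadFn, typename DecodeFn>
    void chunked_decode(std::int64_t size, std::int64_t chunk,
                        std::size_t element_size, ReadFn read, DecodeFn decode) {
      std::vector<std::uint8_t> buffer(
          static_cast<std::size_t>(chunk) * element_size);
      for (std::int64_t i = 0; i < size; i += chunk) {
        const std::int64_t to_convert = std::min(size - i, chunk);
        read(buffer.data(),
             element_size * static_cast<std::size_t>(to_convert)); // bytes in
        decode(i, buffer.data(), to_convert);                      // swapped out
      }
    }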