Use sum_integers and multiply_integers (#51146)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51146

Test Plan: Sandcastle tests

Reviewed By: ngimel

Differential Revision: D25903430

fbshipit-source-id: 329c14018c9e5192864eed88a8ed0a5068ff1c69
This commit is contained in:
parent bff8194522
commit fa325d7c9f
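Every hunk below follows the same pattern: an ad-hoc std::accumulate reduction over integer shape dimensions (or the old at::prod_intlist / sum_intlist helpers) is replaced by c10::multiply_integers or c10::sum_integers from <c10/util/accumulate.h>. A minimal before/after sketch, assuming only that the c10 header is available; the function names and the sizes variable are illustrative, not taken from the diff:

#include <c10/util/accumulate.h>

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Old spelling: write out the reduction and remember the int64_t init value
// so the running product is accumulated in 64 bits.
int64_t numel_old(const std::vector<int64_t>& sizes) {
  return std::accumulate(sizes.begin(), sizes.end(),
                         static_cast<int64_t>(1), std::multiplies<int64_t>());
}

// New spelling: the helper always accumulates into int64_t and accepts either
// a container or an iterator range (c10::sum_integers works the same way).
int64_t numel_new(const std::vector<int64_t>& sizes) {
  return c10::multiply_integers(sizes);
}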
@@ -3,6 +3,7 @@
 #include <ATen/MatrixRef.h>
 #include <ATen/VmapTransforms.h>
 #include <ATen/core/dispatch/Dispatcher.h>
+#include <c10/util/accumulate.h>
 #include <c10/util/llvmMathExtras.h>
 
 namespace at {
@@ -163,7 +164,7 @@ void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::j
   auto first_physical_view_sizes = input_physical_views.front().tensor().sizes();
   auto batch_sizes = ArrayRef<int64_t>(
       first_physical_view_sizes.begin(), first_physical_view_sizes.begin() + num_batch_dims);
-  const auto num_batches = prod_intlist(batch_sizes);
+  const auto num_batches = c10::multiply_integers(batch_sizes);
   // Without a shape-checking API, we're unable to compute the correct shape of
   // the output so we just error out.
   TORCH_CHECK(num_batches > 0,
@@ -296,7 +297,7 @@ void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Sta
   auto num_batch_dims = input_physical_views.front().numBatchDims();
   auto some_sizes = input_physical_views.front().tensor().sizes();
   auto batch_sizes = ArrayRef<int64_t>(some_sizes.begin(), some_sizes.begin() + num_batch_dims);
-  const auto num_batches = prod_intlist(batch_sizes);
+  const auto num_batches = c10::multiply_integers(batch_sizes);
   // Without a shape-checking API, we're unable to compute the correct shape of
   // the output so we just error out.
   TORCH_CHECK(num_batches > 0,
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
#include <ATen/ATen.h>
|
||||
#include <ATen/Config.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
|
|
@ -354,7 +354,7 @@ c10::optional<std::vector<int64_t>> computeStride(
|
|||
// we use the stride as if it were computed via resize.
|
||||
// This could perhaps be combined with the below code, but the complexity
|
||||
// didn't seem worth it.
|
||||
const int64_t numel = prod_intlist(oldshape);
|
||||
const int64_t numel = c10::multiply_integers(oldshape);
|
||||
if (numel == 0 && oldshape.equals(newshape)) {
|
||||
return oldstride.vec();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
#include <ATen/Utils.h>
|
||||
#include <ATen/Context.h>
|
||||
#include <ATen/detail/CUDAHooksInterface.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/detail/CUDAHooksInterface.h>
|
||||
#include <ATen/Utils.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
|
||||
|
||||
#include <stdarg.h>
|
||||
#include <cstdlib>
|
||||
#include <stdexcept>
|
||||
|
|
@ -39,7 +42,7 @@ Tensor empty_cpu(
|
|||
allocator = at::getCPUAllocator();
|
||||
}
|
||||
|
||||
int64_t nelements = prod_intlist(size);
|
||||
int64_t nelements = c10::multiply_integers(size);
|
||||
caffe2::TypeMeta dtype = scalarTypeToTypeMeta(dtype_or_default(dtype_opt));
|
||||
int64_t size_bytes = nelements * dtype.itemsize();
|
||||
auto storage_impl = c10::make_intrusive<StorageImpl>(
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@
|
|||
|
||||
#include <ATen/core/ATenGeneral.h>
|
||||
#include <ATen/core/Generator.h>
|
||||
#include <ATen/Formatting.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/core/StorageImpl.h>
|
||||
#include <c10/core/UndefinedTensorImpl.h>
|
||||
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <ATen/Formatting.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <c10/util/ArrayRef.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
|
|
@@ -89,22 +89,6 @@ std::array<int64_t, N> check_intlist(ArrayRef<int64_t> list, const char * name,
   return res;
 }
 
-inline int64_t sum_intlist(ArrayRef<int64_t> list) {
-  return std::accumulate(list.begin(), list.end(), 0ll);
-}
-
-//std::accumulate infers return type from `init` type, so if `init` type is not enough to hold the result, computation can overflow
-//the next 2 functions set `init` type to int64_t to avoid overflow.
-template<typename C, typename std::enable_if<std::is_integral<typename C::value_type>::value, int>::type = 0>
-inline int64_t prod_intlist(const C &container){
-  return std::accumulate(container.begin(), container.end(), static_cast<int64_t>(1), std::multiplies<int64_t>());
-}
-
-template<typename Iter,
-         typename std::enable_if<std::is_integral<typename std::iterator_traits<Iter>::value_type>::value, int>::type = 0>
-inline int64_t prod_intlist(Iter begin, Iter end){
-  return std::accumulate(begin, end, static_cast<int64_t>(1), std::multiplies<int64_t>());
-}
 /**
  * Utility function to static cast input Generator* to
  * the backend generator type (CPU/CUDAGeneratorImpl etc.)
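The comment deleted in this hunk explains why the replacement helpers exist: std::accumulate infers its accumulator type from the init argument, so a plain int init silently does the whole reduction in 32 bits. A small self-contained illustration of that hazard, with made-up dimension values (not from the diff); the c10 equivalent is mentioned only in a comment:

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  // 70000 * 70000 = 4.9e9, which does not fit in a 32-bit int.
  const std::vector<int64_t> sizes = {70000, 70000};

  // Bug: `1` is an int, so every partial product is narrowed back to int.
  const auto bad = std::accumulate(sizes.begin(), sizes.end(), 1,
                                   std::multiplies<int64_t>());

  // Fix: an int64_t init keeps the reduction in 64 bits, which is what
  // prod_intlist did and what c10::multiply_integers(sizes) does internally.
  const auto good = std::accumulate(sizes.begin(), sizes.end(),
                                    static_cast<int64_t>(1),
                                    std::multiplies<int64_t>());

  return static_cast<int64_t>(bad) == good ? 0 : 1;  // returns 1: they differ
}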
|
|
|
|||
|
|
@ -1,14 +1,16 @@
|
|||
#include <limits>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/native/ConvUtils.h>
|
||||
#include <ATen/native/cpu/DepthwiseConvKernel.h>
|
||||
#include <ATen/native/utils/ParamUtils.h>
|
||||
#include <ATen/native/ConvUtils.h>
|
||||
#include <ATen/native/xnnpack/Engine.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
|
||||
#include <ATen/Config.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
#if AT_NNPACK_ENABLED()
|
||||
#include <nnpack.h>
|
||||
#endif
|
||||
|
|
@@ -177,10 +179,10 @@ auto ConvParams::needs_64bit_indexing_no_split(const at::Tensor& input, const at
   int64_t outsize = 1;
   if (transposed) {
     std::vector<int64_t> o = conv_input_size(input.sizes(), weight.sizes(), padding, output_padding, stride, dilation, groups);
-    outsize = prod_intlist(o.begin() + 1, o.end());
+    outsize = c10::multiply_integers(o.begin() + 1, o.end());
   } else {
     std::vector<int64_t> o = conv_output_size(input.sizes(), weight.sizes(), padding, stride, dilation);
-    outsize = prod_intlist(o.begin() + 1, o.end());
+    outsize = c10::multiply_integers(o.begin() + 1, o.end());
   }
   return outsize > int_max;
 }
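Several call sites, like the convolution hunk just above, reduce over only part of the shape, everything except the leading batch dimension. The iterator-range overload covers that directly. A short sketch under the same assumptions as before (illustrative names, real c10 header):

#include <c10/util/accumulate.h>

#include <cstdint>
#include <vector>

// Product of all dimensions after the first, e.g. {N, C, H, W} -> C * H * W.
// The old spelling at these call sites was prod_intlist(o.begin() + 1, o.end()).
int64_t elements_per_batch(const std::vector<int64_t>& shape) {
  return c10::multiply_integers(shape.begin() + 1, shape.end());
}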
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
#include <ATen/ATen.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/NamedTensorUtils.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/NamedTensorUtils.h>
|
||||
#include <ATen/native/Distance.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
|
||||
namespace at { namespace native {
|
||||
|
||||
|
|
@ -74,7 +75,7 @@ static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, c10
|
|||
std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
|
||||
tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2});
|
||||
|
||||
const int64_t expand_batch_product = prod_intlist(expand_batch_portion);
|
||||
const int64_t expand_batch_product = c10::multiply_integers(expand_batch_portion);
|
||||
std::vector<int64_t> tensor1_view{expand_batch_product, r1, c1};
|
||||
std::vector<int64_t> tensor2_view{expand_batch_product, r2, c2};
|
||||
|
||||
|
|
@ -147,7 +148,7 @@ Tensor _cdist_backward(const Tensor& grad, const Tensor& x1, const Tensor& x2, c
|
|||
auto device2 = x2.device().type();
|
||||
TORCH_CHECK(device2 == kCPU || device2 == kCUDA, "_cdist_backward only supports CPU and CUDA devices, X2 got: ", device2);
|
||||
IntArrayRef batch_tensor1(x1.sizes().data(), std::max<int64_t>(x1.dim() - 2, 0));
|
||||
const int64_t batch_product = prod_intlist(batch_tensor1);
|
||||
const int64_t batch_product = c10::multiply_integers(batch_tensor1);
|
||||
Tensor grad_x1 =
|
||||
at::empty_like(x1, x1.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT)
|
||||
.view({batch_product, n, m});
|
||||
|
|
|
|||
|
|
@ -2,9 +2,10 @@
|
|||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/Fill.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/Utils.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
|
@ -104,7 +105,7 @@ Tensor& zero_cpu_(Tensor &self, int64_t nelements) {
|
|||
}
|
||||
|
||||
Tensor& zero_(Tensor &self) {
|
||||
int64_t nelements = at::prod_intlist(self.sizes());
|
||||
int64_t nelements = c10::multiply_integers(self.sizes());
|
||||
if (self.device() == at::kCPU &&
|
||||
self.is_non_overlapping_and_dense() &&
|
||||
nelements < internal::GRAIN_SIZE) {
|
||||
|
|
|
|||
|
|
@ -1,25 +1,28 @@
|
|||
#include <ATen/ATen.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/core/grad_mode.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/ExpandUtils.h>
|
||||
#include <ATen/LegacyTHFunctionsCPU.h>
|
||||
#include <ATen/NamedTensorUtils.h>
|
||||
#include <ATen/native/CPUBlas.h>
|
||||
#include <ATen/native/IndexingUtils.h>
|
||||
#include <ATen/native/LinearAlgebra.h>
|
||||
#include <ATen/native/LinearAlgebraUtils.h>
|
||||
#include <ATen/native/ReduceOpsUtils.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/TensorIterator.h>
|
||||
#include <ATen/native/LinearAlgebra.h>
|
||||
#include <ATen/native/IndexingUtils.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/LegacyTHFunctionsCPU.h>
|
||||
#include <ATen/core/grad_mode.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/Utils.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <c10/util/variant.h>
|
||||
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
#include <limits>
|
||||
#include <ATen/NamedTensorUtils.h>
|
||||
|
||||
#include <c10/util/variant.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
|
@ -952,7 +955,7 @@ Tensor matmul(
|
|||
tensor2_expand_size.insert(tensor2_expand_size.end(), {m2, p});
|
||||
|
||||
const int64_t expand_batch_product =
|
||||
prod_intlist(expand_batch_portion);
|
||||
c10::multiply_integers(expand_batch_portion);
|
||||
|
||||
std::vector<int64_t> tensor1_bmm_view({expand_batch_product});
|
||||
tensor1_bmm_view.insert(tensor1_bmm_view.end(), {n, m1});
|
||||
|
|
@ -2049,8 +2052,8 @@ Tensor linalg_tensorinv(const Tensor& self, int64_t ind) {
|
|||
// self[:ind]
|
||||
std::vector<int64_t> shape_start_ind = self.sizes().slice(0, ind).vec();
|
||||
|
||||
int64_t prod_ind_end = std::accumulate(shape_ind_end.cbegin(), shape_ind_end.cend(), int64_t{1}, std::multiplies<int64_t>());
|
||||
int64_t prod_start_ind = std::accumulate(shape_start_ind.cbegin(), shape_start_ind.cend(), int64_t{1}, std::multiplies<int64_t>());
|
||||
int64_t prod_ind_end = c10::multiply_integers(shape_ind_end.cbegin(), shape_ind_end.cend());
|
||||
int64_t prod_start_ind = c10::multiply_integers(shape_start_ind.cbegin(), shape_start_ind.cend());
|
||||
|
||||
// Check whether the self tensor can be reshaped to the 2D square matrix
|
||||
TORCH_CHECK(prod_ind_end == prod_start_ind,
|
||||
|
|
@ -2106,8 +2109,8 @@ Tensor linalg_tensorsolve(const Tensor& self, const Tensor& other, optional<IntA
|
|||
// result_shape is self_.sizes[-(an-other.dim):]
|
||||
std::vector<int64_t> result_shape = self_.sizes().slice(other.dim(), ndim - other.dim()).vec();
|
||||
|
||||
int64_t result_product = std::accumulate(result_shape.begin(), result_shape.end(), int64_t{1}, std::multiplies<int64_t>());
|
||||
int64_t other_product = std::accumulate(other.sizes().begin(), other.sizes().end(), int64_t{1}, std::multiplies<int64_t>());
|
||||
int64_t result_product = c10::multiply_integers(result_shape.begin(), result_shape.end());
|
||||
int64_t other_product = c10::multiply_integers(other.sizes().begin(), other.sizes().end());
|
||||
|
||||
// Check whether the self tensor can be reshaped to the 2D square matrix
|
||||
TORCH_CHECK(result_product == other_product,
|
||||
|
|
|
|||
|
|
@ -1,13 +1,14 @@
|
|||
|
||||
|
||||
#include <tuple>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Utils.h>
|
||||
#include <ATen/native/im2col.h>
|
||||
#include <ATen/native/vol2col.h>
|
||||
|
||||
#include <ATen/native/CPUBlas.h>
|
||||
#include <ATen/native/DilatedConvolutionUtils.h>
|
||||
#include <ATen/native/im2col.h>
|
||||
#include <ATen/native/vol2col.h>
|
||||
#include <ATen/Utils.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
|
||||
#include <tuple>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
|
@ -182,8 +183,8 @@ void slow_conv_dilated_all_cpu_template(
|
|||
// Temporary buffer:
|
||||
Tensor columns = at::empty({0}, options);
|
||||
if (output.defined() || grad_weight.defined() || grad_input.defined()) {
|
||||
const int64_t m = prod_intlist(kernel_size);
|
||||
const int64_t n = prod_intlist(output_size);
|
||||
const int64_t m = c10::multiply_integers(kernel_size);
|
||||
const int64_t n = c10::multiply_integers(output_size);
|
||||
columns.resize_({nInputPlane * m, n});
|
||||
}
|
||||
// Initialize
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@
|
|||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/native/ResizeCommon.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <TH/THTensor.hpp>
|
||||
|
||||
|
||||
namespace at { namespace native {
|
||||
|
||||
// TODO: make all operations that resize given outputs use this function
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
#include <ATen/quantized/QTensorImpl.h>
|
||||
#include <ATen/SparseTensorUtils.h>
|
||||
#include <ATen/WrapDimUtils.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
|
@ -1893,7 +1894,7 @@ Tensor flatten(const Tensor& self, int64_t start_dim, int64_t end_dim) {
|
|||
// of freedom we don't want; for example, consider shape [0, 1, 3, 0], with start_dim=1, end_dim=2.
|
||||
// It's clear we want result shape [0, 3, 0] but passing [0, -1, 0] to infer_size means the -1
|
||||
// can take on any value and satisfy the constraints.
|
||||
auto slice_numel = prod_intlist(self.sizes().slice(start_dim, end_dim - start_dim + 1));
|
||||
auto slice_numel = c10::multiply_integers(self.sizes().slice(start_dim, end_dim - start_dim + 1));
|
||||
std::vector<int64_t> shape;
|
||||
shape.reserve(self.dim() - end_dim + start_dim);
|
||||
for (const auto i : c10::irange(start_dim)) {
|
||||
|
|
@ -1948,7 +1949,7 @@ Tensor unflatten(const Tensor& self, int64_t dim, IntArrayRef sizes, c10::option
|
|||
TORCH_CHECK(sizes.size() > 0, "unflatten: sizes must be non-empty");
|
||||
TORCH_INTERNAL_ASSERT(!names || names->size() == sizes.size());
|
||||
|
||||
const int64_t numel = prod_intlist(sizes);
|
||||
const int64_t numel = c10::multiply_integers(sizes);
|
||||
if (self.has_names()) {
|
||||
TORCH_CHECK(numel == self.size(dim),
|
||||
"unflatten: Provided sizes ", sizes, " don't multiply up to the size of dim ",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#include <ATen/ATen.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/native/UpSample.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
|
||||
namespace at {
|
||||
namespace meta {
|
||||
|
|
@ -48,7 +49,7 @@ TORCH_META_FUNC(upsample_nearest2d) (
|
|||
|
||||
// Allow for empty batch size but not other dimensions
|
||||
TORCH_CHECK(
|
||||
input.numel() != 0 || prod_intlist(input.sizes().begin() + 1, input.sizes().end()),
|
||||
input.numel() != 0 || c10::multiply_integers(input.sizes().begin() + 1, input.sizes().end()),
|
||||
"Non-empty 4D data tensor expected but got a tensor with sizes ",
|
||||
input.sizes());
|
||||
|
||||
|
|
|
|||
|
|
@ -1,18 +1,20 @@
|
|||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/Config.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/native/cuda/CuFFTUtils.h>
|
||||
#include <ATen/native/utils/ParamsHash.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
|
||||
#include <list>
|
||||
#include <unordered_map>
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
#include <sstream>
|
||||
#include <limits>
|
||||
#include <cufft.h>
|
||||
#include <cufftXt.h>
|
||||
|
||||
#include <limits>
|
||||
#include <list>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace at { namespace native { namespace detail {
|
||||
|
||||
// Enum representing the FFT type
|
||||
|
|
@ -167,7 +169,7 @@ inline CuFFTDataLayout as_cufft_embed(IntArrayRef strides, IntArrayRef sizes, bo
|
|||
|
||||
const auto last_dim_size = onesided ?
|
||||
sizes[signal_ndim] / 2 + 1 : sizes[signal_ndim];
|
||||
const auto signal_numel = at::prod_intlist(sizes.slice(1, sizes.size() - 2)) * last_dim_size;
|
||||
const auto signal_numel = c10::multiply_integers(sizes.slice(1, sizes.size() - 2)) * last_dim_size;
|
||||
|
||||
// Zero stides are not allowed, even if the batch size is one.
|
||||
// If that happens just set a dummy case
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
||||
#include <ATen/cuda/CUDABlas.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/native/DilatedConvolutionUtils.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
||||
#include <tuple>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/native/cuda/im2col.cuh>
|
||||
#include <ATen/native/cuda/vol2col.cuh>
|
||||
#include <ATen/native/DilatedConvolutionUtils.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
|
||||
#include <tuple>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
|
@ -188,8 +190,8 @@ void slow_conv_dilated_all_cuda_template(
|
|||
int64_t nInputPlane = weight.size(1);
|
||||
int64_t nOutputPlane = weight.size(0);
|
||||
// Temporary buffers:
|
||||
const int64_t m = prod_intlist(kernel_size);
|
||||
const int64_t output_vsize = prod_intlist(output_size);
|
||||
const int64_t m = c10::multiply_integers(kernel_size);
|
||||
const int64_t output_vsize = c10::multiply_integers(output_size);
|
||||
Tensor columns = at::empty({0}, options);
|
||||
if (output.defined() || grad_weight.defined() || grad_input.defined()) {
|
||||
columns.resize_({nInputPlane * m, output_vsize});
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/NumericLimits.cuh>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/cuda/NumericLimits.cuh>
|
||||
#include <THC/THCNumerics.cuh>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <THC/THCGeneral.h>
|
||||
#include <THC/THCNumerics.cuh>
|
||||
|
||||
#include <cub/device/device_scan.cuh>
|
||||
|
||||
|
||||
|
|
@ -166,10 +168,10 @@ __host__ void scan_outer_dim_with_indices(const Tensor& self, Tensor& values, Te
|
|||
auto sizes = self.sizes();
|
||||
|
||||
// Treat all outer dimensions (i.e. dim_ < dim) as one.
|
||||
const int64_t num_orows = prod_intlist(sizes.begin(), sizes.begin() + dim);
|
||||
const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim);
|
||||
|
||||
// Treat all inner dimensions (i.e. dim > dimension) as one.
|
||||
const int64_t num_irows = prod_intlist(sizes.begin() + dim + 1, sizes.end());
|
||||
const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end());
|
||||
//for performance reasons, cuda kernels use uint32_t for loops over irows, orows and row,
|
||||
//make sure that input is not bigger than supported by uint32_t
|
||||
check_fits_in_unsigned(num_irows, "num_irows");
|
||||
|
|
@ -420,10 +422,10 @@ __host__ void scan_outer_dim(const Tensor& self, Tensor& result,
|
|||
auto sizes = self.sizes();
|
||||
|
||||
// Treat all outer dimensions (i.e. dim_ < dim) as one.
|
||||
const int64_t num_orows = prod_intlist(sizes.begin(), sizes.begin() + dim);
|
||||
const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim);
|
||||
|
||||
// Treat all inner dimensions (i.e. dim > dimension) as one.
|
||||
const int64_t num_irows = prod_intlist(sizes.begin() + dim + 1, sizes.end());
|
||||
const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end());
|
||||
|
||||
dim3 threads(std::min(512, int(num_irows)));
|
||||
int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
#include <ATen/native/SpectralOpsUtils.h>
|
||||
#include <ATen/native/cuda/CuFFTUtils.h>
|
||||
#include <ATen/native/cuda/CuFFTPlanCache.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <THC/THCTensorSort.cuh>
|
||||
#include <THC/THCThrustAllocator.cuh>
|
||||
|
||||
|
|
@ -19,8 +20,10 @@
|
|||
#include <thrust/unique.h>
|
||||
#include <cufft.h>
|
||||
#include <cufftXt.h>
|
||||
#include <vector>
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
|
||||
namespace at { namespace native {
|
||||
|
||||
|
|
@ -116,7 +119,7 @@ void _fft_fill_with_conjugate_symmetry_cuda_(
|
|||
HermitianSymmetryOffsetCalculator<int64_t> output_offset_calculator(
|
||||
signal_half_sizes, out_strides, mirror_dims, element_size);
|
||||
|
||||
const auto numel = at::prod_intlist(signal_half_sizes);
|
||||
const auto numel = c10::multiply_integers(signal_half_sizes);
|
||||
AT_DISPATCH_COMPLEX_TYPES(dtype, "_fft_fill_with_conjugate_symmetry", [&] {
|
||||
using namespace cuda::detail;
|
||||
_fft_conjugate_copy_kernel<<<
|
||||
|
|
@ -253,7 +256,7 @@ static inline Tensor _run_cufft(
|
|||
// rescale if requested
|
||||
auto size_last_signal_dim = checked_signal_sizes[signal_ndim - 1];
|
||||
if (norm != fft_norm_mode::none) {
|
||||
auto signal_numel = at::prod_intlist(checked_signal_sizes);
|
||||
auto signal_numel = c10::multiply_integers(checked_signal_sizes);
|
||||
double scale_denom;
|
||||
if (norm == fft_norm_mode::by_root_n) {
|
||||
scale_denom = std::sqrt(static_cast<double>(signal_numel));
|
||||
|
|
|
|||
|
|
@ -1,22 +1,23 @@
|
|||
#include <ATen/ATen.h>
|
||||
#include <ATen/InitialTensorOptions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/native/TensorFactories.h>
|
||||
#include <ATen/InitialTensorOptions.h>
|
||||
#include <ATen/native/cuda/Resize.cuh>
|
||||
#include <ATen/native/TensorFactories.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
#include <THC/THCGeneral.h>
|
||||
#include <THC/THCThrustAllocator.cuh>
|
||||
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/sort.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
#include <thrust/sequence.h>
|
||||
#include <thrust/sort.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
|
@ -47,7 +48,7 @@ Tensor empty_cuda(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::op
|
|||
check_size_nonnegative(size);
|
||||
|
||||
auto* allocator = at::cuda::getCUDADeviceAllocator();
|
||||
int64_t nelements = prod_intlist(size);
|
||||
int64_t nelements = c10::multiply_integers(size);
|
||||
auto dtype = dtype_or_default(dtype_opt);
|
||||
auto dtype_meta = scalarTypeToTypeMeta(dtype);
|
||||
int64_t size_bytes = nelements * dtype_meta.itemsize();
|
||||
|
|
|
|||
|
|
@ -1,13 +1,14 @@
|
|||
#include <ATen/native/RNN.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Config.h>
|
||||
#include <ATen/InitialTensorOptions.h>
|
||||
#include <ATen/MatrixRef.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/cuda/CUDAConfig.h>
|
||||
#include <ATen/cuda/CUDAEvent.h>
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <ATen/InitialTensorOptions.h>
|
||||
#include <ATen/MatrixRef.h>
|
||||
#include <ATen/native/RNN.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/TensorUtils.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
#if !AT_CUDNN_ENABLED()
|
||||
|
|
@ -423,7 +424,7 @@ namespace {
|
|||
TORCH_INTERNAL_ASSERT(offset_bytes % elem_size == 0, "offset_bytes = ", offset_bytes, "; elem_size = ", elem_size);
|
||||
size_t offset = offset_bytes / elem_size;
|
||||
|
||||
int mat_numel = prod_intlist(filter_dim_a, filter_dim_a + nb_dims);
|
||||
int mat_numel = c10::multiply_integers(filter_dim_a, filter_dim_a + nb_dims);
|
||||
// Generate a new parameter tensor which is a view into the weight_buf.
|
||||
std::initializer_list<int64_t> size = {mat_numel, 1};
|
||||
Tensor param = at::empty({0}, weight_buf.options()).set_(weight_buf.storage(), offset, size);
|
||||
|
|
@ -507,7 +508,7 @@ namespace {
|
|||
// (same for the hh weights, and the ih and hh biases).
|
||||
// Since we're storing all the weights in a single tensor anyway,
|
||||
// might as well merge the CUDNN ones into a single tensor as well
|
||||
int mat_numel = prod_intlist(filter_dim_a, filter_dim_a + nb_dims);
|
||||
int mat_numel = c10::multiply_integers(filter_dim_a, filter_dim_a + nb_dims);
|
||||
if (linear_id == 0 || linear_id == num_linear_layers / 2) {
|
||||
// We could also exclude bias params by restricting cudnn_methods to just { cudnnGetRNNLinLayerMatrixParams }
|
||||
// at the very top. However, to do so would throw off the cur_offset account, which is currently a strict
|
||||
|
|
|
|||
|
|
@ -1,4 +1,11 @@
|
|||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Config.h>
|
||||
#include <ATen/CPUApplyUtils.h>
|
||||
#include <ATen/native/group_norm.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
|
||||
#include <array>
|
||||
#include <functional>
|
||||
|
|
@ -6,12 +13,6 @@
|
|||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/CPUApplyUtils.h>
|
||||
#include <ATen/Config.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/Parallel.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
|
@ -107,7 +108,7 @@ Tensor group_norm(
|
|||
|
||||
const auto input_shape = input.sizes();
|
||||
const int64_t HxW =
|
||||
prod_intlist(input_shape.cbegin() + 2, input_shape.cend());
|
||||
c10::multiply_integers(input_shape.cbegin() + 2, input_shape.cend());
|
||||
|
||||
const Tensor kEmpty;
|
||||
const auto& X = input.is_contiguous() ? input : input.contiguous();
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/native/DispatchStub.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
|
@ -53,9 +54,9 @@ std::tuple<Tensor, Tensor, Tensor, int64_t, int64_t> _prepare_layer_norm_inputs(
|
|||
|
||||
const int axis = input_ndim - normalized_ndim;
|
||||
const int64_t M =
|
||||
prod_intlist(input_shape.cbegin(), input_shape.cbegin() + axis);
|
||||
c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis);
|
||||
const int64_t N =
|
||||
prod_intlist(input_shape.cbegin() + axis, input_shape.cend());
|
||||
c10::multiply_integers(input_shape.cbegin() + axis, input_shape.cend());
|
||||
|
||||
const auto& X = input.is_contiguous() ? input : input.contiguous();
|
||||
const auto& gamma = weight.is_contiguous() ? weight : weight.contiguous();
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#include <ATen/ATen.h>
|
||||
#include <ATen/core/op_registration/op_registration.h>
|
||||
#include <ATen/native/metal/MetalPrepackOpContext.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
|
||||
|
||||
#if (C10_IOS || TARGET_OS_MAC)
|
||||
|
|
@ -24,7 +25,7 @@ c10::intrusive_ptr<Conv2dOpContext> unpack(
|
|||
const auto ws = weightContig.sizes();
|
||||
auto packed_buffer = permuteWeights(weightContig.data_ptr<float>(), ws.vec());
|
||||
auto packedWeight = at::empty(ws);
|
||||
int64_t size_bytes = at::prod_intlist(ws) * sizeof(float);
|
||||
int64_t size_bytes = c10::multiply_integers(ws) * sizeof(float);
|
||||
memcpy(packedWeight.data_ptr(), packed_buffer.data(), size_bytes);
|
||||
return c10::make_intrusive<Conv2dOpContext>(
|
||||
std::move(packedWeight),
|
||||
|
|
|
|||
|
|
@ -3,6 +3,9 @@
|
|||
#import <ATen/native/metal/MetalUtils.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSImageWrapper.h>
|
||||
|
||||
#include <ATen/Utils.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
|
|
@ -12,11 +15,7 @@ class API_AVAILABLE(ios(10.0), macos(10.13)) MetalTensor::Impl {
|
|||
Impl(const std::vector<int64_t>& sizes, const std::vector<int64_t>& strides)
|
||||
: _sizes(sizes),
|
||||
_strides(strides),
|
||||
_numel(std::accumulate(
|
||||
std::begin(_sizes),
|
||||
std::end(_sizes),
|
||||
(int64_t)1,
|
||||
std::multiplies<int64_t>())),
|
||||
_numel(c10::multiply_integers(std::begin(_sizes), std::end(_sizes))),
|
||||
_textureImpl(std::make_unique<MPSImageWrapper>(sizes)) {}
|
||||
|
||||
IntArrayRef sizes() const {
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
#include <ATen/InferSize.h>
|
||||
#include <ATen/native/Pool.h>
|
||||
#include <ATen/native/UpSample.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
|
@ -743,8 +744,8 @@ Tensor flatten_using_ints(
|
|||
if (start_dim == end_dim) {
|
||||
return input;
|
||||
}
|
||||
auto slice_numel =
|
||||
prod_intlist(input.sizes().slice(start_dim, end_dim - start_dim + 1));
|
||||
const auto slice_numel =
|
||||
c10::multiply_integers(input.sizes().slice(start_dim, end_dim - start_dim + 1));
|
||||
shape.reserve(input.dim() - end_dim + start_dim);
|
||||
for (int64_t i = 0; i < start_dim; i++) {
|
||||
shape.push_back(input.size(i));
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#include <ATen/native/metal/mpscnn/MPSCNNContext.h>
|
||||
#include <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
|
||||
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <torch/script.h>
|
||||
|
||||
using namespace at::native;
|
||||
|
|
@ -99,7 +100,7 @@ using namespace at::native;
|
|||
auto fp32 = metal::fp16_to_fp32(fp16);
|
||||
std::vector<float> fp32_nchw = metal::NC4_to_NCHW(fp32.data(), outputSize);
|
||||
auto tensor = at::empty(outputSize);
|
||||
int64_t size_bytes = at::prod_intlist(outputSize) * sizeof(float);
|
||||
int64_t size_bytes = c10::multiply_integers(outputSize) * sizeof(float);
|
||||
memcpy(tensor.data_ptr(), fp32_nchw.data(), size_bytes);
|
||||
return tensor;
|
||||
}
|
||||
|
|
@ -233,7 +234,7 @@ using namespace at::native;
|
|||
|
||||
+ (MPSImage*)imageFromHost:(const float*)src
|
||||
Sizes:(const std::vector<int64_t>&)sizes {
|
||||
int64_t size_bytes = at::prod_intlist(sizes) * sizeof(float);
|
||||
int64_t size_bytes = c10::multiply_integers(sizes) * sizeof(float);
|
||||
// allocte buffer on CPU
|
||||
id<MTLBuffer> buff = [[MPSCNNContext sharedInstance].device
|
||||
newBufferWithLength:size_bytes
|
||||
|
|
@ -268,7 +269,7 @@ using namespace at::native;
|
|||
Sizes:(const std::vector<int64_t>&)sizes
|
||||
CommandBuffer:(MetalCommandBuffer*)cb {
|
||||
NSCAssert(cb, @"CommandBuffer is nil!");
|
||||
int64_t size_bytes = at::prod_intlist(sizes) * sizeof(float);
|
||||
int64_t size_bytes = c10::multiply_integers(sizes) * sizeof(float);
|
||||
// allocte buffer on CPU
|
||||
id<MTLBuffer> buff = [[MPSCNNContext sharedInstance].device
|
||||
newBufferWithLength:size_bytes
|
||||
|
|
@ -301,7 +302,7 @@ using namespace at::native;
|
|||
|
||||
+ (void)copyToHost:(float*)dst FromImage:(MPSImage*)image {
|
||||
auto&& sizes = [image sizes];
|
||||
int64_t size_bytes = at::prod_intlist(sizes) * sizeof(float);
|
||||
int64_t size_bytes = c10::multiply_integers(sizes) * sizeof(float);
|
||||
id<MTLBuffer> buffer = [[MPSCNNContext sharedInstance].device
|
||||
newBufferWithLength:size_bytes
|
||||
options:MTLResourceOptionCPUCacheModeDefault];
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@
|
|||
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Utils.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#import <ATen/native/metal/mpscnn/tests/MPSCNNTests.h>
|
||||
|
||||
#include <stdlib.h>
|
||||
|
|
@ -65,7 +67,7 @@ bool almostEqualVec(
|
|||
}
|
||||
|
||||
typedef bool (^Func)(void);
|
||||
bool TEST(const std::vector<int64_t>& sizes, std::string name, Func block) {
|
||||
bool TEST(const std::vector<int64_t>& sizes, std::string name, Func block) {
|
||||
std::stringstream ss;
|
||||
std::copy(sizes.begin(), sizes.end(), std::ostream_iterator<int>(ss, " "));
|
||||
__block std::string str1 = ss.str();
|
||||
|
|
@ -103,11 +105,7 @@ bool test_nchw_to_nc4_cpu() {
|
|||
__block std::vector<int64_t> size{N, C, H, W};
|
||||
bool b = TEST(size, __PRETTY_FUNCTION__, ^bool {
|
||||
auto t = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
int len = std::accumulate(
|
||||
std::begin(size),
|
||||
std::end(size),
|
||||
(int64_t)1,
|
||||
std::multiplies<int64_t>());
|
||||
const auto len = c10::multiply_integers(std::begin(size), std::end(size));
|
||||
auto buf =
|
||||
std::vector<float>{t.data_ptr<float>(), t.data_ptr<float>() + len};
|
||||
auto c4 = NCHW_to_NC4((float*)t.data_ptr<float>(), t.sizes().vec());
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
#include <ATen/ATen.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/Config.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#include <ATen/native/SpectralOpsUtils.h>
|
||||
#include <ATen/Config.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
|
||||
#if !AT_MKL_ENABLED()
|
||||
|
||||
|
|
@ -188,7 +189,7 @@ static void _fft_fill_with_conjugate_symmetry_cpu_(
|
|||
is_mirrored_dim[dim] = true;
|
||||
}
|
||||
|
||||
const auto numel = at::prod_intlist(signal_half_sizes);
|
||||
const auto numel = c10::multiply_integers(signal_half_sizes);
|
||||
AT_DISPATCH_COMPLEX_TYPES(dtype, "_fft_fill_with_conjugate_symmetry", [&] {
|
||||
at::parallel_for(0, numel, at::internal::GRAIN_SIZE,
|
||||
[&](int64_t begin, int64_t end) {
|
||||
|
|
@ -267,7 +268,7 @@ static DftiDescriptor _plan_mkl_fft(
|
|||
}
|
||||
// rescale if requested
|
||||
const auto norm = static_cast<fft_norm_mode>(normalization);
|
||||
int64_t signal_numel = at::prod_intlist(IntArrayRef(sizes.data() + 1, signal_ndim));
|
||||
int64_t signal_numel = c10::multiply_integers(IntArrayRef(sizes.data() + 1, signal_ndim));
|
||||
if (norm != fft_norm_mode::none) {
|
||||
const double scale = (
|
||||
(norm == fft_norm_mode::by_root_n) ?
|
||||
|
|
|
|||
|
|
@ -1,22 +1,18 @@
|
|||
#include <ATen/ATen.h>
|
||||
#include <ATen/native/TensorFactories.h>
|
||||
|
||||
#include <ATen/native/quantized/cpu/conv_packed_params.h>
|
||||
#include <ATen/native/quantized/cpu/conv_serialization.h>
|
||||
#include <ATen/native/quantized/cpu/embedding_packed_params.h>
|
||||
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
|
||||
#include <ATen/native/quantized/cpu/packed_params.h>
|
||||
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
|
||||
#include <ATen/native/TensorFactories.h>
|
||||
#include <ATen/quantized/QTensorImpl.h>
|
||||
#include <ATen/quantized/Quantizer.h>
|
||||
|
||||
#include <c10/core/QScheme.h>
|
||||
#include <c10/core/TensorOptions.h>
|
||||
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <torch/custom_class.h>
|
||||
|
||||
#include <ATen/native/quantized/cpu/embedding_packed_params.h>
|
||||
#include <ATen/native/quantized/cpu/packed_params.h>
|
||||
|
||||
torch::class_<LinearPackedParamsBase> register_linear_params();
|
||||
torch::class_<EmbeddingPackedParamsBase> register_embedding_params();
|
||||
|
||||
|
|
@ -143,7 +139,7 @@ Tensor MakeStridedQTensorCPU(
|
|||
AT_ASSERT(options.device().is_cpu());
|
||||
at::native::check_size_nonnegative(sizes);
|
||||
auto* allocator = at::getCPUAllocator();
|
||||
const int64_t nelements = at::prod_intlist(sizes);
|
||||
const int64_t nelements = c10::multiply_integers(sizes);
|
||||
auto dtype = options.dtype();
|
||||
TORCH_CHECK(
|
||||
isQIntType(typeMetaToScalarType(dtype)),
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
#include <ATen/ATen.h>
|
||||
#include <ATen/native/layer_norm.h>
|
||||
#include <ATen/native/quantized/cpu/quantized_ops.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <torch/library.h>
|
||||
#include <ATen/native/quantized/cpu/quantized_ops.h>
|
||||
#include <ATen/native/layer_norm.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
|
@ -72,7 +73,7 @@ Tensor quantized_group_norm_impl(
|
|||
const int64_t batches = input_shape[0];
|
||||
const int64_t num_channels = input_shape[1];
|
||||
const int64_t elements_per_batch =
|
||||
prod_intlist(input_shape.cbegin() + 1, input_shape.cend());
|
||||
c10::multiply_integers(input_shape.cbegin() + 1, input_shape.cend());
|
||||
|
||||
const int64_t M = batches * num_groups;
|
||||
const int64_t N = elements_per_batch / num_groups;
|
||||
|
|
|
|||
|
|
@ -1,13 +1,14 @@
|
|||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/SparseTensorUtils.h>
|
||||
#include <ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh>
|
||||
#include <ATen/AccumulateType.h>
|
||||
#include <ATen/cuda/CUDAApplyUtils.cuh>
|
||||
|
||||
#include <THC/THCThrustAllocator.cuh>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <THC/THCTensorSort.cuh>
|
||||
#include <THC/THCThrustAllocator.cuh>
|
||||
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/device_vector.h>
|
||||
|
|
@ -16,6 +17,7 @@
|
|||
#include <thrust/scan.h>
|
||||
#include <thrust/sequence.h>
|
||||
#include <thrust/sort.h>
|
||||
#include <thrust/system/cuda/execution_policy.h>
|
||||
#include <thrust/transform.h>
|
||||
#include <thrust/unique.h>
|
||||
#include <thrust/system/cuda/execution_policy.h>
|
||||
|
|
@ -132,7 +134,7 @@ SparseTensor coalesce_sparse_cuda(const SparseTensor& self) {
|
|||
if (newValues.numel() > 0) {
|
||||
const int SZ = 4;
|
||||
values = values.contiguous();
|
||||
int64_t stride = at::prod_intlist(values.sizes().slice(1));
|
||||
int64_t stride = c10::multiply_integers(values.sizes().slice(1));
|
||||
dim3 grid(THCCeilDiv(newNnz, (int64_t) SZ), THCCeilDiv(stride, (int64_t) C10_WARP_SIZE*SZ));
|
||||
dim3 block(C10_WARP_SIZE, SZ);
|
||||
AT_DISPATCH_ALL_TYPES_AND2(
|
||||
|
|
@ -206,8 +208,8 @@ Tensor sparse_mask_helper_cuda(
|
|||
`t` - coalesced sparse tensor input
|
||||
`mask_indices` - mask indices tensor
|
||||
|
||||
Note: The nnz in the output tensor will be same as the `mask_indices`. So it will
|
||||
works independently if the mask is coalesced or not.
|
||||
Note: The nnz in the output tensor will be same as the `mask_indices`. So it will
|
||||
works independently if the mask is coalesced or not.
|
||||
*/
|
||||
TORCH_CHECK(t.is_sparse(), "t: input is not a sparse tensor");
|
||||
TORCH_CHECK(t.is_coalesced(), "t: input is uncoalesced");
|
||||
|
|
@ -221,12 +223,12 @@ Tensor sparse_mask_helper_cuda(
|
|||
auto vsize = t_values.sizes().vec();
|
||||
vsize[0] = r_nnz;
|
||||
|
||||
|
||||
|
||||
if (t.sparse_dim() == 0) {
|
||||
Tensor t_values_expand = t_values;
|
||||
t_values_expand = t_values_expand.expand(vsize).contiguous();
|
||||
return t_values_expand;
|
||||
}
|
||||
}
|
||||
Tensor r_values = at::zeros({vsize}, t_values.options());
|
||||
auto t_indices = t._indices().contiguous();
|
||||
auto t_nnz = t._nnz();
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#include <ATen/NamedTensorUtils.h>
|
||||
#include <ATen/native/utils/Factory.h>
|
||||
#include <c10/core/CPUAllocator.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
|
@ -12,7 +13,7 @@ Tensor empty_with_tail_padding(
|
|||
const c10::MemoryFormat memory_format,
|
||||
const DimnameList maybe_names) {
|
||||
auto* const allocator_ptr = c10::GetDefaultMobileCPUAllocator();
|
||||
const int64_t nelements = prod_intlist(size);
|
||||
const int64_t nelements = c10::multiply_integers(size);
|
||||
size_t size_bytes = nelements * dtype.itemsize();
|
||||
|
||||
Tensor tensor(c10::make_intrusive<c10::TensorImpl>(
|
||||
|
|
|
|||
|
|
@ -1,13 +1,7 @@
|
|||
#include <stdio.h>
|
||||
#include <unistd.h>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
|
||||
#include <ATen/Utils.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <c10/util/ArrayRef.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <ATen/Utils.h>
|
||||
|
||||
#ifdef USE_VULKAN_WRAPPER
|
||||
#include <vulkan_wrapper.h>
|
||||
|
|
@ -25,6 +19,14 @@
|
|||
#include <ATen/native/vulkan/spv.h>
|
||||
#endif
|
||||
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <stdio.h>
|
||||
#include <unistd.h>
|
||||
|
||||
|
||||
#define VK_CHECK(f) \
|
||||
{ \
|
||||
VkResult res = (f); \
|
||||
|
|
@ -1169,7 +1171,7 @@ class VulkanTensor::Impl final {
|
|||
explicit Impl(std::vector<int64_t> sizes)
|
||||
: sizes_(std::move(sizes)),
|
||||
strides_(std::vector<int64_t>(sizes_.size())),
|
||||
numel_(prod_intlist(sizes_)) {
|
||||
numel_(c10::multiply_integers(sizes_)) {
|
||||
TORCH_CHECK(
|
||||
initVulkanContextOnce(), "Vulkan Failed to create Vulkan Context");
|
||||
}
|
||||
|
|
@ -1272,7 +1274,7 @@ class VulkanTensor::Impl final {
|
|||
|
||||
VkDeviceSize buffer_size_for_sizes(std::vector<int64_t> sizes) const {
|
||||
const auto d = sizes.size();
|
||||
const auto numel = prod_intlist(sizes);
|
||||
const auto numel = c10::multiply_integers(sizes);
|
||||
VkDeviceSize bufferSize{sizeof(float) * numel};
|
||||
// alignment to be able to copy between image and buffer
|
||||
if (d == 4) {
|
||||
|
|
|
|||
|
|
@ -1,17 +1,18 @@
|
|||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <ATen/InferSize.h>
|
||||
#include <ATen/Utils.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
#include <ATen/native/vulkan/Vulkan.h>
|
||||
#include <ATen/native/vulkan/VulkanCommon.h>
|
||||
#include <ATen/native/vulkan/VulkanConvolution.h>
|
||||
#include <ATen/native/vulkan/VulkanOps.h>
|
||||
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace vulkan {
|
||||
|
|
@ -520,7 +521,7 @@ void add(
|
|||
void add(VulkanTensor& output, const VulkanTensor& input, const float s) {
|
||||
const auto sizes = input.sizes();
|
||||
|
||||
const auto C = prod_intlist(sizes.cbegin(), sizes.cend() - 2);
|
||||
const auto C = c10::multiply_integers(sizes.cbegin(), sizes.cend() - 2);
|
||||
const auto C_4 = UP_DIV(C, 4);
|
||||
const auto H = sizes[2];
|
||||
const auto W = sizes[3];
|
||||
|
|
@ -567,7 +568,7 @@ void add(VulkanTensor& output, const VulkanTensor& input, const float s) {
|
|||
void mul(VulkanTensor& output, const VulkanTensor& input, const float s) {
|
||||
const auto sizes = input.sizes();
|
||||
|
||||
const auto C = prod_intlist(sizes.cbegin(), sizes.cend() - 2);
|
||||
const auto C = c10::multiply_integers(sizes.cbegin(), sizes.cend() - 2);
|
||||
const auto C_4 = UP_DIV(C, 4);
|
||||
const auto H = sizes[2];
|
||||
const auto W = sizes[3];
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
#include <ATen/native/vulkan/ops/Tensor.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
|
@ -165,7 +166,7 @@ VkDeviceSize buffer_bytes(
|
|||
size *= extents.data[0u] * extents.data[1u] * (4u * extents.data[2u]);
|
||||
}
|
||||
else {
|
||||
size *= prod_intlist(sizes);
|
||||
size *= c10::multiply_integers(sizes);
|
||||
}
|
||||
|
||||
return size;
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
#include <ATen/native/vulkan/ops/Common.h>
|
||||
#include <ATen/native/vulkan/VulkanOpaqueTensorImpl.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
|
@ -507,7 +508,7 @@ inline IntArrayRef vTensor::sizes() const {
|
|||
|
||||
inline size_t vTensor::nbytes() const {
|
||||
return c10::elementSize(c10::typeMetaToScalarType(options().dtype())) *
|
||||
prod_intlist(sizes());
|
||||
c10::multiply_integers(sizes());
|
||||
}
|
||||
|
||||
inline IntArrayRef vTensor::strides() const {
|
||||
|
|
|
|||
|
|
@ -1,14 +1,16 @@
|
|||
#include <ATen/quantized/Quantizer.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/detail/CUDAHooksInterface.h>
|
||||
#include <ATen/native/TensorFactories.h>
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/native/quantized/affine_quantizer.h>
|
||||
#include <ATen/native/TensorFactories.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/quantized/QTensorImpl.h>
|
||||
#include <ATen/quantized/Quantizer.h>
|
||||
#include <c10/core/CPUAllocator.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <typeinfo>
|
||||
|
||||
|
|
@ -106,7 +108,7 @@ inline Tensor new_qtensor(
|
|||
|
||||
at::DispatchKey tensorDispatchKey = options.computeDispatchKey();
|
||||
native::check_size_nonnegative(sizes);
|
||||
int64_t nelements = at::prod_intlist(sizes);
|
||||
int64_t nelements = c10::multiply_integers(sizes);
|
||||
auto dtype = options.dtype();
|
||||
TORCH_CHECK(
|
||||
isQIntType(typeMetaToScalarType(dtype)),
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
#include <gtest/gtest.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Utils.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
|
|
@ -53,8 +56,7 @@ void test(DeprecatedTypeProperties &T) {
|
|||
ASSERT_EQ((size_t)t.ndimension(), s->size());
|
||||
ASSERT_TRUE(t.sizes().equals(*s));
|
||||
ASSERT_EQ(t.strides().size(), s->size());
|
||||
auto numel =
|
||||
std::accumulate(s->begin(), s->end(), 1, std::multiplies<int64_t>());
|
||||
const auto numel = c10::multiply_integers(s->begin(), s->end());
|
||||
ASSERT_EQ(t.numel(), numel);
|
||||
// verify we can output
|
||||
std::stringstream ss;
|
||||
|
|
|
|||
|
|
@ -6,19 +6,18 @@
|
|||
#include <numeric>
|
||||
|
||||
#include <c10/core/Backend.h>
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
#include <c10/core/Storage.h>
|
||||
#include <c10/core/TensorOptions.h>
|
||||
#include <c10/core/CopyBytes.h>
|
||||
#include <c10/core/DispatchKeySet.h>
|
||||
#include <c10/core/impl/LocalDispatchKeySet.h>
|
||||
#include <c10/core/impl/SizesAndStrides.h>
|
||||
#include <c10/core/CopyBytes.h>
|
||||
|
||||
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
#include <c10/core/Storage.h>
|
||||
#include <c10/core/TensorOptions.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <c10/util/Flags.h>
|
||||
#include <c10/util/Logging.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <c10/util/python_stub.h>
|
||||
|
||||
// A global boolean variable to control whether we free memory when a Tensor
|
||||
|
|
@@ -1111,11 +1110,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
      Resize(newDims);
      return;
    }
-    auto newNumel = std::accumulate(
-        newDims.begin(),
-        newDims.end(),
-        static_cast<int64_t>(1),
-        std::multiplies<int64_t>());
+    const auto newNumel = c10::multiply_integers(newDims.begin(), newDims.end());
    if (newNumel * data_type_.itemsize() <= storage_.nbytes()) {
      sizes_and_strides_.set_sizes(newDims);
      numel_ = newNumel;
@@ -1172,11 +1167,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    // TODO: eliminate newCapacity.
    SmallVector<int64_t, 5> newCapacity(sizes_and_strides_.sizes_begin(), sizes_and_strides_.sizes_end());
    newCapacity[0] = outer_dim;
-    auto newNumel = std::accumulate(
-        newCapacity.begin(),
-        newCapacity.end(),
-        static_cast<int64_t>(1),
-        std::multiplies<int64_t>());
+    auto newNumel = c10::multiply_integers(newCapacity);
    if (newNumel * data_type_.itemsize() <= storage_.nbytes()) {
      return;
    }
|
|
|
|||
|
|
@ -1,17 +1,20 @@
|
|||
#ifndef CAFFE2_OPERATORS_BATCH_MATMUL_OP_H_
|
||||
#define CAFFE2_OPERATORS_BATCH_MATMUL_OP_H_
|
||||
|
||||
#include <ATen/Utils.h>
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <fbgemm/FbgemmConvert.h>
|
||||
|
||||
#include "caffe2/contrib/fakelowp/fp16_gemm_utils.h"
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <fbgemm/FbgemmConvert.h>
|
||||
#include "caffe2/contrib/fakelowp/fp16_gemm_utils.h"
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
|
||||
C10_DECLARE_bool(caffe2_fbgemm_fake_fp16_clamp);
|
||||
|
||||
namespace caffe2 {
|
||||
|
|
@ -266,21 +269,15 @@ class BatchMatMulFP16FakeOp final : public Operator<Context> {
|
|||
CAFFE_ENFORCE(broadcast_);
|
||||
}
|
||||
|
||||
const std::int64_t A_batch_size = std::accumulate(
|
||||
const std::int64_t A_batch_size = c10::multiply_integers(
|
||||
A_broadcast_dims.cbegin(),
|
||||
A_broadcast_dims.cbegin() + batch_dim,
|
||||
1LL,
|
||||
std::multiplies<std::int64_t>());
|
||||
const std::int64_t B_batch_size = std::accumulate(
|
||||
A_broadcast_dims.cbegin() + batch_dim);
|
||||
const std::int64_t B_batch_size = c10::multiply_integers(
|
||||
B_broadcast_dims.cbegin(),
|
||||
B_broadcast_dims.cbegin() + batch_dim,
|
||||
1LL,
|
||||
std::multiplies<std::int64_t>());
|
||||
const std::int64_t Y_batch_size = std::accumulate(
|
||||
B_broadcast_dims.cbegin() + batch_dim);
|
||||
const std::int64_t Y_batch_size = c10::multiply_integers(
|
||||
Y_broadcast_dims.cbegin(),
|
||||
Y_broadcast_dims.cbegin() + batch_dim,
|
||||
1LL,
|
||||
std::multiplies<std::int64_t>());
|
||||
Y_broadcast_dims.cbegin() + batch_dim);
|
||||
if (Y_batch_size == 0) {
|
||||
fbgemm::RoundToFloat16(
|
||||
reinterpret_cast<const float*>(Y_data),
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
#include "caffe2/contrib/tensorrt/tensorrt_op_trt.h"
|
||||
|
||||
#include <numeric>
|
||||
#include <unordered_map>
|
||||
|
||||
#include <c10/util/accumulate.h>
|
||||
#include "caffe2/contrib/tensorrt/tensorrt_tranformer.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
#include "onnx/onnx_pb.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <numeric>
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
namespace {
|
||||
|
|
@ -134,13 +135,8 @@ void TensorRTOp::MaybeAdjustOutputShape(
|
|||
const auto it = output_size_hints_.find(output_idx);
|
||||
if (it != output_size_hints_.end()) {
|
||||
const auto& dims_hint = it->second;
|
||||
auto total_trt = std::accumulate(
|
||||
dims->begin(), dims->end(), (int64_t)(1), std::multiplies<int64_t>());
|
||||
auto total_c2 = std::accumulate(
|
||||
dims_hint.begin(),
|
||||
dims_hint.end(),
|
||||
(int64_t)(1),
|
||||
std::multiplies<int64_t>());
|
||||
const auto total_trt = c10::multiply_integers(*dims);
|
||||
const auto total_c2 = c10::multiply_integers(*dims_hint);
|
||||
CAFFE_ENFORCE_EQ(
|
||||
total_trt,
|
||||
total_c2,
|
||||
|
|
|
|||
|
|
@ -1,16 +1,17 @@
|
|||
#ifndef CAFFE2_CORE_QTENSOR_H_
|
||||
#define CAFFE2_CORE_QTENSOR_H_
|
||||
|
||||
#include "caffe2/core/common.h"
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/tensor.h"
|
||||
#include <c10/util/accumulate.h>
|
||||
#include <c10/util/typeid.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <climits>
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
|
||||
#include "caffe2/core/common.h"
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/tensor.h"
|
||||
#include <c10/util/typeid.h>
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
template <class Context>
|
||||
|
|
@ -57,8 +58,7 @@ class C10_EXPORT QTensor {
|
|||
|
||||
void Resize(at::ArrayRef<int> dim_source) {
|
||||
if (dims_ != dim_source) {
|
||||
size_t source_size = std::accumulate(
|
||||
dim_source.begin(), dim_source.end(), 1, std::multiplies<int>());
|
||||
const auto source_size = c10::multiply_integers(dim_source);
|
||||
if ((source_size * (precision_ + signed_)) > capacity_) {
|
||||
data_ptr_.clear();
|
||||
capacity_ = 0;
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
#include "caffe2/operators/cosh_op.h"
|
||||
|
||||
#include <c10/util/accumulate.h>
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
|
||||
#include "caffe2/core/context_gpu.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
namespace {
|
||||
|
|
@ -34,8 +35,7 @@ bool CoshGradientFunctor<CUDAContext>::Forward(
|
|||
const T* X,
|
||||
T* dX,
|
||||
CUDAContext* context) const {
|
||||
const int size = std::accumulate(
|
||||
X_dims.cbegin(), X_dims.cend(), 1, std::multiplies<int>());
|
||||
const auto size = c10::multiply_integers(X_dims.cbegin(), X_dims.cend());
|
||||
CoshGradientCUDAKernel<<<
|
||||
CAFFE_GET_BLOCKS(size),
|
||||
CAFFE_CUDA_NUM_THREADS,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
#include <c10/util/accumulate.h>
|
||||
|
||||
#include "caffe2/operators/elementwise_mul_op.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
|
@ -21,12 +23,9 @@ void ComputeMulGradient(
|
|||
TGrad* dA,
|
||||
TGrad* dB,
|
||||
CPUContext* context) {
|
||||
const int A_size =
|
||||
std::accumulate(A_dims, A_dims + ndim, 1, std::multiplies<int>());
|
||||
const int B_size =
|
||||
std::accumulate(B_dims, B_dims + ndim, 1, std::multiplies<int>());
|
||||
const int C_size =
|
||||
std::accumulate(C_dims, C_dims + ndim, 1, std::multiplies<int>());
|
||||
const auto A_size = c10::multiply_integers(A_dims, A_dims + ndim);
|
||||
const auto B_size = c10::multiply_integers(B_dims, B_dims + ndim);
|
||||
const auto C_size = c10::multiply_integers(C_dims, C_dims + ndim);
|
||||
math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
|
||||
math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
|
||||
std::vector<int> index(ndim, 0);
|
||||
|
|
@ -96,8 +95,7 @@ bool MulFunctor<CPUContext>::Backward(
|
|||
TGrad* dB,
|
||||
CPUContext* context) const {
|
||||
if (A_dims == B_dims) {
|
||||
const int size = std::accumulate(
|
||||
A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
|
||||
const auto size = c10::multiply_integers(A_dims);
|
||||
math::Mul(size, dC, B, dA, context);
|
||||
math::Mul(size, dC, A, dB, context);
|
||||
return true;
|
||||
|
|
@ -126,10 +124,8 @@ bool MulFunctor<CPUContext>::Backward(
|
|||
1,
|
||||
std::multiplies<int>());
|
||||
if (C_size == 0) {
|
||||
const int A_size = std::accumulate(
|
||||
A_dims.cbegin(), A_dims.cend(), 1, std::multiplies<int>());
|
||||
const int B_size = std::accumulate(
|
||||
B_dims.cbegin(), B_dims.cend(), 1, std::multiplies<int>());
|
||||
const auto A_size = c10::multiply_integers(A_dims);
|
||||
const auto B_size = c10::multiply_integers(B_dims);
|
||||
math::Set<TGrad, CPUContext>(A_size, TGrad(0), dA, context);
|
||||
math::Set<TGrad, CPUContext>(B_size, TGrad(0), dB, context);
|
||||
return true;
|
||||
|
|
|
|||
|
|
@@ -1,13 +1,14 @@
#include "caffe2/operators/pool_op.h"

#include <c10/util/accumulate.h>
#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/math.h"

#include <array>
#include <functional>
#include <limits>
#include <numeric>

#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

namespace {

@@ -835,8 +836,7 @@ bool AveragePoolFunctor<CUDAContext>::Forward<float, StorageOrder::NHWC>(
CUDAContext* context) const {
// Each CUDA block handles one point, one thread per channel.
const int ndim = X_dims.size();
const int Y_HxW = std::accumulate(
Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
const auto Y_HxW = c10::multiply_integers(Y_dims.cbegin(), Y_dims.cend());
switch (ndim) {
case 1: {
AveragePool1DForwardNHWCCUDAKernel<float>

@@ -1064,8 +1064,7 @@ bool AveragePoolFunctor<CUDAContext>::Backward<float, StorageOrder::NHWC>(
float* dX,
CUDAContext* context) const {
const int ndim = X_dims.size();
const int X_HxW = std::accumulate(
X_dims.cbegin(), X_dims.cend(), 1, std::multiplies<int>());
const auto X_HxW = c10::multiply_integers(X_dims.cbegin(), X_dims.cend());
const int num_blocks = N * X_HxW;
switch (ndim) {
case 1: {

@@ -1867,8 +1866,7 @@ bool MaxPoolFunctor<CUDAContext>::Forward<float, StorageOrder::NHWC>(
CUDAContext* context) const {
// Each CUDA block handles one point, one thread per channel.
const int ndim = X_dims.size();
const int Y_HxW = std::accumulate(
Y_dims.cbegin(), Y_dims.cend(), 1, std::multiplies<int>());
const auto Y_HxW = c10::multiply_integers(Y_dims.cbegin(), Y_dims.cend());
switch (ndim) {
case 1: {
MaxPool1DForwardNHWCCUDAKernel<float>

@@ -2063,8 +2061,7 @@ bool MaxPoolFunctor<CUDAContext>::Backward<float, StorageOrder::NHWC>(
float* dX,
CUDAContext* context) const {
const int ndim = X_dims.size();
const int X_HxW = std::accumulate(
X_dims.cbegin(), X_dims.cend(), 1, std::multiplies<int>());
const auto X_HxW = c10::multiply_integers(X_dims.cbegin(), X_dims.cend());
switch (ndim) {
case 1: {
MaxPool1DBackwardNHWCCUDAKernel<float>
@@ -1,11 +1,12 @@
#include "caffe2/operators/reduce_ops.h"

#include <c10/util/accumulate.h>
#include "caffe2/utils/math.h"

#include <algorithm>
#include <functional>
#include <vector>

#include "caffe2/utils/math.h"

namespace caffe2 {

namespace {

@@ -18,8 +19,7 @@ void ComputeReduceMinMaxGradient(
const T* X_data,
const T* Y_data,
T* dX_data) {
const int dX_size = std::accumulate(
dX_dims.cbegin(), dX_dims.cend(), 1, std::multiplies<int>());
const auto dX_size = c10::multiply_integers(dX_dims.cbegin(), dX_dims.cend());
const int ndim = dX_dims.size();
std::vector<int> index(ndim, 0);
for (int dX_index = 0; dX_index < dX_size; ++dX_index) {

@@ -342,8 +342,7 @@ bool L1Reducer<CPUContext>::Backward(
T* dX_data,
CPUContext* /* context */) const {
const float kEps = 1e-12f;
const int dX_size = std::accumulate(
dX_dims.cbegin(), dX_dims.cend(), 1, std::multiplies<int>());
const auto dX_size = c10::multiply_integers(dX_dims.cbegin(), dX_dims.cend());
const int ndim = dX_dims.size();
std::vector<int> index(ndim, 0);
for (int dX_index = 0; dX_index < dX_size; ++dX_index) {

@@ -373,8 +372,7 @@ bool L2Reducer<CPUContext>::Backward(
T* dX_data,
CPUContext* /* context */) const {
const float kEps = 1e-12f;
const int dX_size = std::accumulate(
dX_dims.cbegin(), dX_dims.cend(), 1, std::multiplies<int>());
const auto dX_size = c10::multiply_integers(dX_dims.cbegin(), dX_dims.cend());
const int ndim = dX_dims.size();
std::vector<int> index(ndim, 0);
for (int dX_index = 0; dX_index < dX_size; ++dX_index) {
@@ -1,11 +1,5 @@
#include "caffe2/utils/math/reduce.h"

#include <algorithm>
#include <cstring>
#include <functional>
#include <numeric>
#include <vector>

#ifdef CAFFE2_USE_ACCELERATE
#include <Accelerate/Accelerate.h>
#endif // CAFFE2_USE_ACCELERATE

@@ -14,12 +8,19 @@
#include <mkl.h>
#endif // CAFFE2_USE_MKL

#include <c10/util/accumulate.h>
#include "caffe2/core/context.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
#include "caffe2/utils/math/elementwise.h"
#include "caffe2/utils/math/utils.h"

#include <algorithm>
#include <cstring>
#include <functional>
#include <numeric>
#include <vector>

namespace caffe2 {
namespace math {

@@ -265,10 +266,8 @@ void ReduceTensorImpl(
const T* X,
T* Y,
CPUContext* context) {
const int X_size =
std::accumulate(X_dims, X_dims + ndim, 1, std::multiplies<int>());
const int Y_size =
std::accumulate(Y_dims, Y_dims + ndim, 1, std::multiplies<int>());
const auto X_size = c10::multiply_integers(X_dims, X_dims + ndim);
const auto Y_size = c10::multiply_integers(Y_dims, Y_dims + ndim);
Set<T, CPUContext>(Y_size, init, Y, context);
std::vector<int> index(ndim, 0);
for (int X_index = 0; X_index < X_size; ++X_index) {

@@ -296,8 +295,7 @@ void ReduceMinImpl(
X,
Y,
context);
const int Y_size =
std::accumulate(Y_dims, Y_dims + ndim, 1, std::multiplies<int>());
const auto Y_size = c10::multiply_integers(Y_dims, Y_dims + ndim);
Scale<T, T, CPUContext>(Y_size, alpha, Y, Y, context);
}

@@ -319,8 +317,7 @@ void ReduceMaxImpl(
X,
Y,
context);
const int Y_size =
std::accumulate(Y_dims, Y_dims + ndim, 1, std::multiplies<int>());
const auto Y_size = c10::multiply_integers(Y_dims, Y_dims + ndim);
Scale<T, T, CPUContext>(Y_size, alpha, Y, Y, context);
}

@@ -334,8 +331,7 @@ void ReduceSumImpl(
T* Y,
CPUContext* context) {
ReduceTensorImpl(ndim, X_dims, Y_dims, std::plus<T>(), T(0), X, Y, context);
const int Y_size =
std::accumulate(Y_dims, Y_dims + ndim, 1, std::multiplies<int>());
const auto Y_size = c10::multiply_integers(Y_dims, Y_dims + ndim);
Scale<T, T, CPUContext>(Y_size, alpha, Y, Y, context);
}

@@ -349,10 +345,8 @@ void ReduceMeanImpl(
T* Y,
CPUContext* context) {
ReduceTensorImpl(ndim, X_dims, Y_dims, std::plus<T>(), T(0), X, Y, context);
const int X_size =
std::accumulate(X_dims, X_dims + ndim, 1, std::multiplies<int>());
const int Y_size =
std::accumulate(Y_dims, Y_dims + ndim, 1, std::multiplies<int>());
const auto X_size = c10::multiply_integers(X_dims, X_dims + ndim);
const auto Y_size = c10::multiply_integers(Y_dims, Y_dims + ndim);
Scale<T, T, CPUContext>(
Y_size,
alpha * static_cast<T>(Y_size) / static_cast<T>(X_size),

@@ -379,8 +373,7 @@ void ReduceL1Impl(
X,
Y,
context);
const int Y_size =
std::accumulate(Y_dims, Y_dims + ndim, 1, std::multiplies<int>());
const auto Y_size = c10::multiply_integers(Y_dims, Y_dims + ndim);
Scale<T, T, CPUContext>(Y_size, alpha, Y, Y, context);
}

@@ -402,8 +395,7 @@ void ReduceL2Impl(
X,
Y,
context);
const int Y_size =
std::accumulate(Y_dims, Y_dims + ndim, 1, std::multiplies<int>());
const auto Y_size = c10::multiply_integers(Y_dims, Y_dims + ndim);
EigenVectorArrayMap<T> Y_arr(Y, Y_size);
Y_arr = Y_arr.sqrt() * alpha;
}

@@ -479,10 +471,8 @@ void MomentsImpl(
T* mean,
T* var,
CPUContext* /* context */) {
const int X_size =
std::accumulate(X_dims, X_dims + ndim, 1, std::multiplies<int>());
const int Y_size =
std::accumulate(Y_dims, Y_dims + ndim, 1, std::multiplies<int>());
const auto X_size = c10::multiply_integers(X_dims, X_dims + ndim);
const auto Y_size = c10::multiply_integers(Y_dims, Y_dims + ndim);
if (X_size == 0) {
std::memset(mean, 0, sizeof(T) * Y_size);
std::memset(var, 0, sizeof(T) * Y_size);
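Note (not part of the diff): in ReduceMeanImpl above, the reduced sum is rescaled by alpha * Y_size / X_size, that is, divided by the number of input elements folded into each output element. A worked sketch with invented shapes:

    #include <c10/util/accumulate.h>

    #include <cassert>

    int main() {
      // Reduce a 2x3x4 input over its last two axes; the output keeps dims {2, 1, 1}.
      const int X_dims[3] = {2, 3, 4};
      const int Y_dims[3] = {2, 1, 1};
      const auto X_size = c10::multiply_integers(X_dims, X_dims + 3);
      const auto Y_size = c10::multiply_integers(Y_dims, Y_dims + 3);
      assert(X_size == 24 && Y_size == 2);
      // Each output element sums X_size / Y_size = 12 inputs, so the summed
      // result is scaled by alpha * Y_size / X_size, the factor passed to Scale<>.
      const float alpha = 1.0f;
      const float scale =
          alpha * static_cast<float>(Y_size) / static_cast<float>(X_size);
      (void)scale;  // 1/12 for this example
      return 0;
    }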
@@ -1,3 +1,5 @@
#include <c10/util/accumulate.h>
#include "caffe2/core/logging.h"
#include "caffe2/utils/math/utils.h"

#include <algorithm>

@@ -5,8 +7,6 @@
#include <numeric>
#include <vector>

#include "caffe2/core/logging.h"

namespace caffe2 {
namespace math {
namespace utils {

@@ -183,12 +183,10 @@ bool IsRowwiseBroadcastBinaryOp(
}
const int pivot = std::max(A_pivot, B_pivot);
if (A_pivot > B_pivot) {
*rows = std::accumulate(
B_dims + B_pivot, B_dims + pivot, 1, std::multiplies<int>());
*rows = c10::multiply_integers(B_dims + B_pivot, B_dims + pivot);
*broadcast_1st = true;
} else {
*rows = std::accumulate(
A_dims + A_pivot, A_dims + pivot, 1, std::multiplies<int>());
*rows = c10::multiply_integers(A_dims + A_pivot, A_dims + pivot);
*broadcast_1st = false;
}
*cols = 1;

@@ -224,12 +222,10 @@ bool IsColwiseBroadcastBinaryOp(
++B_pivot;
const int pivot = std::min(A_pivot, B_pivot);
if (A_pivot < B_pivot) {
*cols = std::accumulate(
B_dims + pivot, B_dims + B_pivot, 1, std::multiplies<int>());
*cols = c10::multiply_integers(B_dims + pivot, B_dims + B_pivot);
*broadcast_1st = true;
} else {
*cols = std::accumulate(
A_dims + pivot, A_dims + A_pivot, 1, std::multiplies<int>());
*cols = c10::multiply_integers(A_dims + pivot, A_dims + A_pivot);
*broadcast_1st = false;
}
*rows = 1;

@@ -271,16 +267,12 @@ bool IsBothEndsBroadcastBinaryOp(
return false;
}
if (A_pre > B_pre && A_nxt < B_nxt) {
*pre = std::accumulate(
B_dims + B_pre, B_dims + A_pre, 1, std::multiplies<int>());
*nxt = std::accumulate(
B_dims + A_nxt, B_dims + B_nxt, 1, std::multiplies<int>());
*pre = c10::multiply_integers(B_dims + B_pre, B_dims + A_pre);
*nxt = c10::multiply_integers(B_dims + A_nxt, B_dims + B_nxt);
*broadcast_1st = true;
} else if (A_pre < B_pre && A_nxt > B_nxt) {
*pre = std::accumulate(
A_dims + A_pre, A_dims + B_pre, 1, std::multiplies<int>());
*nxt = std::accumulate(
A_dims + B_nxt, A_dims + A_nxt, 1, std::multiplies<int>());
*pre = c10::multiply_integers(A_dims + A_pre, A_dims + B_pre);
*nxt = c10::multiply_integers(A_dims + B_nxt, A_dims + A_nxt);
*broadcast_1st = false;
} else {
return false;
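Note (not part of the diff): the broadcast helpers above collapse a sub-range of axes into a single extent with a partial product; the pivot bookkeeping itself is not reproduced here. A generic sketch with an invented shape and split point:

    #include <c10/util/accumulate.h>

    #include <cassert>

    int main() {
      const int dims[4] = {2, 3, 4, 5};  // invented shape
      const int pivot = 2;               // invented split point
      // Leading axes flatten into a row count, trailing axes into a column count.
      const auto rows = c10::multiply_integers(dims, dims + pivot);
      const auto cols = c10::multiply_integers(dims + pivot, dims + 4);
      assert(rows == 6 && cols == 20);
      return 0;
    }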
@@ -2,19 +2,19 @@
#include <torch/csrc/autograd/variable.h>

#include <ATen/ATen.h>
#include <ATen/BatchedTensorImpl.h>
#include <ATen/core/Reduction.h>
#include <ATen/Dispatch.h>
#include <ATen/ExpandUtils.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/ScalarOps.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/Utils.h>
#include <c10/core/TensorOptions.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/ExpandUtils.h>
#include <ATen/core/Reduction.h>
#include <ATen/BatchedTensorImpl.h>
#include <ATen/Dispatch.h>
#include <ATen/ScalarOps.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/native/IndexingUtils.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/accumulate.h>

#include <ciso646>
#include <algorithm>

@@ -2736,9 +2736,9 @@ infinitely_differentiable_native_layer_norm_backward(
const auto input_ndim = X.dim();
const int axis = input_ndim - normalized_ndim;
const int64_t M =
at::prod_intlist(input_shape.cbegin(), input_shape.cbegin() + axis);
c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis);
const int64_t N =
at::prod_intlist(input_shape.cbegin() + axis, input_shape.cend());
c10::multiply_integers(input_shape.cbegin() + axis, input_shape.cend());

Tensor dX;
Tensor dgamma;
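Note (not part of the diff): the layer-norm backward above splits input_shape at axis into an outer extent M and an inner, normalized extent N; their product is the total element count. Sketch with an invented shape:

    #include <c10/util/accumulate.h>

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      // Invented stand-in: an {8, 16, 32} input normalized over its last two axes.
      const std::vector<int64_t> input_shape = {8, 16, 32};
      const int axis = 1;  // plays the role of input_ndim - normalized_ndim
      const int64_t M =
          c10::multiply_integers(input_shape.cbegin(), input_shape.cbegin() + axis);
      const int64_t N =
          c10::multiply_integers(input_shape.cbegin() + axis, input_shape.cend());
      assert(M == 8 && N == 16 * 32);
      assert(M * N == c10::multiply_integers(input_shape));
      return 0;
    }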
@@ -1,6 +1,7 @@
#include <ATen/Utils.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <c10/util/accumulate.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>

@@ -208,7 +209,7 @@ Tensor resizeConstantScalarOrTensorToShape(
ret_tensor = ret_tensor.reshape({1});
ret_tensor = ret_tensor.expand(shape);
} else {
TORCH_INTERNAL_ASSERT(ret_tensor.numel() == at::prod_intlist(shape));
TORCH_INTERNAL_ASSERT(ret_tensor.numel() == c10::multiply_integers(shape));
ret_tensor = ret_tensor.view(shape);
}
return ret_tensor;
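Note (not part of the diff): the assertion above only swaps the helper; for the contiguous constant handled there, the view is valid precisely because numel() equals the product of the target shape. Sketch using ATen, with an invented tensor and target shape:

    #include <ATen/ATen.h>
    #include <c10/util/accumulate.h>

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      const std::vector<int64_t> shape = {3, 4};  // invented target shape
      at::Tensor t = at::ones({1, 12});
      // Same invariant as the TORCH_INTERNAL_ASSERT in the hunk above.
      assert(t.numel() == c10::multiply_integers(shape));
      at::Tensor viewed = t.view(shape);
      assert(viewed.dim() == 2 && viewed.size(0) == 3 && viewed.size(1) == 4);
      return 0;
    }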
@@ -1,26 +1,26 @@
#include <torch/csrc/jit/serialization/export.h>

#include <ATen/ATen.h>
#include <ATen/Utils.h>
#include <ATen/core/functional.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/accumulate.h>
#include <torch/csrc/autograd/symbolic.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <torch/csrc/jit/serialization/import_export_helpers.h>
#include <torch/csrc/jit/serialization/onnx.h>
#include <torch/csrc/onnx/onnx.h>

#include <ATen/core/functional.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/instruction.h>

#include <onnx/checker.h>
#include <onnx/onnx_pb.h>
#include <onnx/proto_utils.h>

#include <ATen/ATen.h>
#include <c10/util/Optional.h>

#include <fstream>
#include <memory>
#include <regex>

@@ -839,11 +839,8 @@ void GraphEncoder::EncodeTensor(
tensor_proto->set_raw_data("__EXTERNAL");
} else {
AT_ASSERT(t.is_contiguous());
size_t tensorSize = static_cast<size_t>(std::accumulate(
std::begin(tensor.sizes()),
std::end(tensor.sizes()),
static_cast<int64_t>(1),
std::multiplies<int64_t>()));
size_t tensorSize = static_cast<size_t>(c10::multiply_integers(
std::begin(tensor.sizes()), std::end(tensor.sizes())));
if (use_external_data_format &&
tensorSize > ParamSizeThresholdForExternalStorage) {
AT_ASSERT(!onnx_file_path.empty());
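Note (not part of the diff): in EncodeTensor above, tensorSize is the element count of the tensor being serialized and is what gets compared against ParamSizeThresholdForExternalStorage; a byte count would additionally multiply by the element width. Sketch with an invented tensor:

    #include <ATen/ATen.h>
    #include <c10/util/accumulate.h>

    #include <cassert>
    #include <cstddef>
    #include <iterator>

    int main() {
      at::Tensor t = at::zeros({4, 5}, at::kFloat);  // invented parameter tensor
      const auto numel = static_cast<size_t>(
          c10::multiply_integers(std::begin(t.sizes()), std::end(t.sizes())));
      assert(numel == static_cast<size_t>(t.numel()));
      // Raw byte size, if one were needed: element count times element width.
      const size_t bytes = numel * static_cast<size_t>(t.element_size());
      assert(bytes == 20 * sizeof(float));
      return 0;
    }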
@@ -1,19 +1,7 @@
#pragma once

#include <sys/types.h>

#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <string>
#include <system_error>
#include <tuple>
#include <vector>

#include <ATen/ATen.h>
#include <c10/util/accumulate.h>
#include <c10d/Types.hpp>

#ifdef _WIN32

@@ -29,6 +17,18 @@ typedef SSIZE_T ssize_t;
#include <fcntl.h>
#endif

#include <sys/types.h>

#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <string>
#include <system_error>
#include <tuple>
#include <vector>

namespace c10d {

// Turns at::IntArrayRef into "(1, 2, 3, 4)".

@@ -405,7 +405,7 @@ inline void checkSplitSizes(
TORCH_CHECK(
split_sizes.size() == group_size,
"Number of tensor splits not equal to group size");
int sum = std::accumulate(split_sizes.begin(), split_sizes.end(), 0);
const auto sum = c10::sum_integers(split_sizes);
TORCH_CHECK(
sum == tensor.size(0), "Split sizes doesn't match total dim 0 size");
}
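Note (not part of the diff): the lone c10::sum_integers call in this commit checks that user-provided split sizes add up to the extent of dim 0. Sketch with invented sizes:

    #include <c10/util/accumulate.h>

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      const std::vector<int64_t> split_sizes = {2, 3, 5};  // invented split spec
      const int64_t dim0 = 10;                             // invented dim 0 extent
      // Replaces std::accumulate(split_sizes.begin(), split_sizes.end(), 0).
      const auto sum = c10::sum_integers(split_sizes);
      assert(sum == dim0);
      return 0;
    }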