chunk_size should always be int64_t for Foreach functors (#156872)

See https://github.com/pytorch/pytorch/issues/156261#issuecomment-3002394773

Testing is a valid question — it is pretty expensive to test such large tensors for all these ops.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156872
Approved by: https://github.com/Skylion007, https://github.com/eqy
ghstack dependencies: #156876, #156871
This commit is contained in:
Jane Xu 2025-06-27 09:26:47 -07:00 committed by PyTorch MergeBot
parent 5a0926a26e
commit d283fc79b1
4 changed files with 16 additions and 16 deletions

View File

@ -208,7 +208,7 @@ struct BinaryOpScalarFunctor {
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
int64_t chunk_size,
TensorListMetadata<depth>& tl,
Op op,
opmath_t scalar) {
@ -232,7 +232,7 @@ struct BinaryOpScalarListFunctor {
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
int64_t chunk_size,
TensorListScalarListMetadata<opmath_t, depth>& tl,
Op op) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@ -256,7 +256,7 @@ struct BinaryOpListAlphaFunctor {
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
int64_t chunk_size,
TensorListMetadata<depth>& tl,
Op op,
opmath_t alpha) {
@ -308,7 +308,7 @@ struct BinaryOpScalarTensorFunctor {
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
int64_t chunk_size,
TensorListMetadata<depth>& tl,
Op op,
T* scalar,
@ -364,7 +364,7 @@ struct BinaryOpScalarTensorFunctor {
template <typename T, int depth, int r_args_depth, int res_arg_index>
struct ZeroFunctor {
__device__ __forceinline__ void operator()(
int chunk_size,
int64_t chunk_size,
TensorListMetadata<1>& tl) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@ -406,7 +406,7 @@ struct UnaryOpFunctor {
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
int64_t chunk_size,
TensorListMetadata<depth>& tl,
Op op) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@ -458,7 +458,7 @@ struct PointwiseOpScalarFunctor {
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
int64_t chunk_size,
TensorListMetadata<depth>& tl,
Op op,
opmath_t scalar) {
@ -482,7 +482,7 @@ struct PointwiseOpScalarListFunctor {
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
int64_t chunk_size,
TensorListScalarListMetadata<opmath_t, depth>& tl,
Op op) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@ -506,7 +506,7 @@ struct PointwiseOpListFunctor {
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
int64_t chunk_size,
TensorListMetadata<depth>& tl,
Op op) {
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@ -557,7 +557,7 @@ struct TernaryOpListFunctor {
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
int64_t chunk_size,
TensorListMetadata<depth>& tl,
Op op) {
static_assert(depth == 3 || depth == 4, "");
@ -611,7 +611,7 @@ struct TernaryOpScalarFunctor {
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
int64_t chunk_size,
TensorListMetadata<depth>& tl,
Op op,
opmath_t alpha) {
@ -668,7 +668,7 @@ struct TernaryOpScalarListFunctor {
using opmath_t = at::opmath_type<T>;
template <typename Op>
__device__ __forceinline__ void operator()(
int chunk_size,
int64_t chunk_size,
TensorListScalarListMetadata<opmath_t, depth>& tl,
Op op) {
static_assert(depth == 2 || depth == 3, "");

View File

@ -53,7 +53,7 @@ template <
int res_arg_index = 0>
struct LpMaxFunctor {
__device__ __forceinline__ void operator()(
int chunk_size,
int64_t chunk_size,
TensorListMetadata<depth>& tl,
T* output_per_tensor_ptr,
const int max_chunks_per_tensor) {
@ -243,7 +243,7 @@ template <
struct LpNormFunctor {
using out_opmath_t = typename at::opmath_type<out_t>;
__device__ __forceinline__ void operator()(
int chunk_size,
int64_t chunk_size,
TensorListMetadata<depth>& tl,
out_opmath_t* output_per_tensor_ptr,
const int max_chunks_per_tensor) {

View File

@ -62,7 +62,7 @@ struct FusedSgdMathFunctor {
depth == 2 || depth == 3,
"depth of 2 for SGD w/ momentum == 0, 3 for SGD w/ momentum != 0");
C10_DEVICE __forceinline__ void operator()(
const int chunk_size,
const int64_t chunk_size,
TensorListMetadata<depth>& tl,
const double weight_decay,
const double momentum,

View File

@ -108,7 +108,7 @@ struct FusedAdamMathFunctor {
"depth of 4 for Adam, depth of 5 for Adam with AMSGrad.");
using opmath_t = at::opmath_type<scalar_type>;
C10_DEVICE __forceinline__ void operator()(
int chunk_size,
int64_t chunk_size,
FusedOptimizerTensorListMetadata<depth>& tl,
const float* lr_ptr,
const double& lr,