chunk_size should always be int64_t for Foreach functors (#156872)

See https://github.com/pytorch/pytorch/issues/156261#issuecomment-3002394773

Testing is a valid question -- it is pretty expensive to test such large tensors for all these ops.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156872
Approved by: https://github.com/Skylion007, https://github.com/eqy
ghstack dependencies: #156876, #156871
This commit is contained in:
Jane Xu 2025-06-27 09:26:47 -07:00 committed by PyTorch MergeBot
parent 5a0926a26e
commit d283fc79b1
4 changed files with 16 additions and 16 deletions

View File

@@ -208,7 +208,7 @@ struct BinaryOpScalarFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       Op op,
       opmath_t scalar) {
@@ -232,7 +232,7 @@ struct BinaryOpScalarListFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListScalarListMetadata<opmath_t, depth>& tl,
       Op op) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@@ -256,7 +256,7 @@ struct BinaryOpListAlphaFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       Op op,
       opmath_t alpha) {
@@ -308,7 +308,7 @@ struct BinaryOpScalarTensorFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       Op op,
       T* scalar,
@@ -364,7 +364,7 @@ struct BinaryOpScalarTensorFunctor {
 template <typename T, int depth, int r_args_depth, int res_arg_index>
 struct ZeroFunctor {
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<1>& tl) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
     const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@@ -406,7 +406,7 @@ struct UnaryOpFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       Op op) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@@ -458,7 +458,7 @@ struct PointwiseOpScalarFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       Op op,
       opmath_t scalar) {
@@ -482,7 +482,7 @@ struct PointwiseOpScalarListFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListScalarListMetadata<opmath_t, depth>& tl,
       Op op) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@@ -506,7 +506,7 @@ struct PointwiseOpListFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       Op op) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@@ -557,7 +557,7 @@ struct TernaryOpListFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       Op op) {
     static_assert(depth == 3 || depth == 4, "");
@@ -611,7 +611,7 @@ struct TernaryOpScalarFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       Op op,
       opmath_t alpha) {
@@ -668,7 +668,7 @@ struct TernaryOpScalarListFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListScalarListMetadata<opmath_t, depth>& tl,
       Op op) {
     static_assert(depth == 2 || depth == 3, "");

View File

@@ -53,7 +53,7 @@ template <
     int res_arg_index = 0>
 struct LpMaxFunctor {
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       T* output_per_tensor_ptr,
       const int max_chunks_per_tensor) {
@@ -243,7 +243,7 @@ template <
 struct LpNormFunctor {
   using out_opmath_t = typename at::opmath_type<out_t>;
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       out_opmath_t* output_per_tensor_ptr,
       const int max_chunks_per_tensor) {

View File

@@ -62,7 +62,7 @@ struct FusedSgdMathFunctor {
       depth == 2 || depth == 3,
       "depth of 2 for SGD w/ momentum == 0, 3 for SGD w/ momentum != 0");
   C10_DEVICE __forceinline__ void operator()(
-      const int chunk_size,
+      const int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       const double weight_decay,
       const double momentum,

View File

@@ -108,7 +108,7 @@ struct FusedAdamMathFunctor {
       "depth of 4 for Adam, depth of 5 for Adam with AMSGrad.");
   using opmath_t = at::opmath_type<scalar_type>;
   C10_DEVICE __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       FusedOptimizerTensorListMetadata<depth>& tl,
       const float* lr_ptr,
       const double& lr,