chunk_size should always be int64_t for Foreach functors (#156872)
See https://github.com/pytorch/pytorch/issues/156261#issuecomment-3002394773

Testing is a valid question--it is pretty expensive to test such large tensors for all these ops.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156872
Approved by: https://github.com/Skylion007, https://github.com/eqy
ghstack dependencies: #156876, #156871
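For context (this is commentary, not part of the diff): per the linked issue, these functors compute per-chunk offsets with arithmetic like `chunk_idx * chunk_size`. When `chunk_size` is a 32-bit `int` and the other operands are also 32-bit, the product is evaluated in 32-bit arithmetic and overflows once a tensor exceeds 2^31 elements. A minimal host-side sketch of that failure mode; the variable names and numbers are hypothetical stand-ins for the per-block values the functors derive from TensorListMetadata:

    // Sketch only: mimics the chunk-offset arithmetic inside the functors.
    #include <cstdint>
    #include <cstdio>

    int main() {
      const int chunk_idx = 40000;      // a block deep into a > 2^31-element tensor
      const int chunk_size_i32 = 65536; // hypothetical chunk size
      // All-int product: 40000 * 65536 = 2621440000 > INT_MAX, so the
      // multiplication overflows (signed overflow is undefined behavior),
      // and any pointer offset derived from it is garbage.
      // With an int64_t chunk_size, the usual arithmetic conversions promote
      // the whole expression to 64 bits and the offset is exact:
      const int64_t chunk_size_i64 = chunk_size_i32;
      const int64_t offset = chunk_idx * chunk_size_i64;
      std::printf("offset = %lld\n", static_cast<long long>(offset));
      return 0;
    }

Widening the parameter at the functor boundary, as this PR does, means every expression that mixes `chunk_size` with 32-bit indices promotes to `int64_t` without having to audit each call site individually.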
This commit is contained in:
parent 5a0926a26e
commit d283fc79b1
@@ -208,7 +208,7 @@ struct BinaryOpScalarFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       Op op,
       opmath_t scalar) {
@@ -232,7 +232,7 @@ struct BinaryOpScalarListFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListScalarListMetadata<opmath_t, depth>& tl,
       Op op) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@@ -256,7 +256,7 @@ struct BinaryOpListAlphaFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       Op op,
       opmath_t alpha) {
@@ -308,7 +308,7 @@ struct BinaryOpScalarTensorFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       Op op,
       T* scalar,
@@ -364,7 +364,7 @@ struct BinaryOpScalarTensorFunctor {
 template <typename T, int depth, int r_args_depth, int res_arg_index>
 struct ZeroFunctor {
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<1>& tl) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
     const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
@@ -406,7 +406,7 @@ struct UnaryOpFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       Op op) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@@ -458,7 +458,7 @@ struct PointwiseOpScalarFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       Op op,
       opmath_t scalar) {
@@ -482,7 +482,7 @@ struct PointwiseOpScalarListFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListScalarListMetadata<opmath_t, depth>& tl,
       Op op) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@@ -506,7 +506,7 @@ struct PointwiseOpListFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       Op op) {
     const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
@@ -557,7 +557,7 @@ struct TernaryOpListFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       Op op) {
     static_assert(depth == 3 || depth == 4, "");
@@ -611,7 +611,7 @@ struct TernaryOpScalarFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       Op op,
       opmath_t alpha) {
@@ -668,7 +668,7 @@ struct TernaryOpScalarListFunctor {
   using opmath_t = at::opmath_type<T>;
   template <typename Op>
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListScalarListMetadata<opmath_t, depth>& tl,
       Op op) {
     static_assert(depth == 2 || depth == 3, "");

@@ -53,7 +53,7 @@ template <
     int res_arg_index = 0>
 struct LpMaxFunctor {
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       T* output_per_tensor_ptr,
       const int max_chunks_per_tensor) {
@@ -243,7 +243,7 @@ template <
 struct LpNormFunctor {
   using out_opmath_t = typename at::opmath_type<out_t>;
   __device__ __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       out_opmath_t* output_per_tensor_ptr,
       const int max_chunks_per_tensor) {

@@ -62,7 +62,7 @@ struct FusedSgdMathFunctor {
       depth == 2 || depth == 3,
       "depth of 2 for SGD w/ momentum == 0, 3 for SGD w/ momentum != 0");
   C10_DEVICE __forceinline__ void operator()(
-      const int chunk_size,
+      const int64_t chunk_size,
       TensorListMetadata<depth>& tl,
       const double weight_decay,
       const double momentum,

@@ -108,7 +108,7 @@ struct FusedAdamMathFunctor {
       "depth of 4 for Adam, depth of 5 for Adam with AMSGrad.");
   using opmath_t = at::opmath_type<scalar_type>;
   C10_DEVICE __forceinline__ void operator()(
-      int chunk_size,
+      int64_t chunk_size,
       FusedOptimizerTensorListMetadata<depth>& tl,
       const float* lr_ptr,
       const double& lr,