Use int64_t for indexing in multi_tensor_apply (#101760)
Fixes #101449

I found it better to either imitate the combo of `TensorIterator::can_use_32bit_indexing` and `TensorIterator::with_32bit_indexing`, or to adroitly choose the index type depending on `Tensor::numel`, in the future.

---

Used `nsys nvprof` to casually see the effect of `int64_t` indexing:

```python
import torch

params = [
    {"params": [torch.randn(32, 32, device="cuda") for _ in range(100)]},
    {"params": [torch.randn(32, 32, device="cuda") for _ in range(100)]},
]
grads = [
    [torch.randn(32, 32, device="cuda") for _ in range(100)],
    [torch.randn(32, 32, device="cuda") for _ in range(100)],
]
optimizer = torch.optim.Adam(params, fused=True)

for _ in range(100):
    for i, param_groups in enumerate(params):
        for p, g in zip(param_groups["params"], grads[i]):
            p.grad = g
    optimizer.step()
    optimizer.zero_grad()
```

Environment

```
Collecting environment information...
PyTorch version: 2.1.0a0+gitf994d0b
Is debug build: False
CUDA used to build PyTorch: 12.1
Python version: 3.10.9 (main, May 17 2023, 00:46:40) [GCC 11.3.0] (64-bit runtime)
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-80GB
```

---

Average kernel time of this PR relative to the current main branch (from the `nsys` traces below):

- `multi_tensor_apply_kernel<at::native::<unnamed>::FusedOptimizerTensor` -> 1.02x
- `multi_tensor_apply_kernel<at::native::<unnamed>::TensorListMetadata<(in…` -> 1.04x

Current main branch:

```
 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)  Name
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     64.9          5787610        600    9646.0    9632.0      9503      9888         52.9  void at::native::<unnamed>::multi_tensor_apply_kernel<at::native::<unnamed>::FusedOptimizerTensorLi…
      ...
      8.1           720575        200    3602.9    3584.0      3551      4320         63.4  void at::native::<unnamed>::multi_tensor_apply_kernel<at::native::<unnamed>::TensorListMetadata<(in…
```

this PR:

```
 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)  Name
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     65.0          5876847        600    9794.7    9792.0      9632     10080         58.1  void at::native::<unnamed>::multi_tensor_apply_kernel<at::native::<unnamed>::FusedOptimizerTensorLi…
      ...
      8.3           748313        200    3741.6    3744.0      3711      4479         60.0  void at::native::<unnamed>::multi_tensor_apply_kernel<at::native::<unnamed>::TensorListMetadata<(in…
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101760
Approved by: https://github.com/ngimel
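As a rough illustration of that follow-up idea (not something this PR implements), index-width dispatch could look like the sketch below. `launch_multi_tensor_kernels`, `can_use_32bit_indexing_for`, and `multi_tensor_apply_sketch` are hypothetical names; a real version would sit next to `multi_tensor_apply` and mirror what `TensorIterator::can_use_32bit_indexing` checks.

```cpp
// Hypothetical sketch of numel-based index-type selection (not in this PR).
#include <cstdint>
#include <limits>
#include <vector>

// Stand-in for the real launcher, templated on the index type the device
// code would use for chunk/element arithmetic.
template <typename index_t>
void launch_multi_tensor_kernels(const std::vector<int64_t>& numels) {
  // ... build metadata with index_t-typed fields and launch the CUDA kernels ...
}

// True when every tensor in the list is small enough for 32-bit indexing.
inline bool can_use_32bit_indexing_for(const std::vector<int64_t>& numels) {
  for (const int64_t n : numels) {
    if (n > std::numeric_limits<int32_t>::max()) {
      return false;
    }
  }
  return true;
}

void multi_tensor_apply_sketch(const std::vector<int64_t>& numels) {
  if (can_use_32bit_indexing_for(numels)) {
    launch_multi_tensor_kernels<int32_t>(numels);  // cheaper index arithmetic on device
  } else {
    launch_multi_tensor_kernels<int64_t>(numels);  // safe for tensors past INT_MAX elements
  }
}
```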
parent b8e2e0e907
commit 401109a243
```diff
@@ -21,9 +21,9 @@ template<int depth, typename T>
 __device__ bool init_args(
     T** args,
     TensorListMetadata<depth>& tl,
-    int chunk_idx,
-    int chunk_size,
-    int tensor_loc) {
+    const int64_t chunk_idx,
+    const int64_t chunk_size,
+    const int64_t tensor_loc) {
   bool all_aligned = true;
   for (int i = 0; i < depth; i++) {
     args[i] = (T*)tl.addresses[i][tensor_loc];
@@ -41,9 +41,9 @@ template<int depth, typename T, typename T2>
 __device__ bool init_args(
     T** args,
     TensorListScalarListMetadata<T2, depth>& tl,
-    int chunk_idx,
-    int chunk_size,
-    int tensor_loc) {
+    const int64_t chunk_idx,
+    const int64_t chunk_size,
+    const int64_t tensor_loc) {
   bool all_aligned = true;
   for (int i = 0; i < depth; i++) {
     args[i] = (T*)tl.addresses[i][tensor_loc];
@@ -60,9 +60,9 @@ template<int depth, typename T>
 __device__ bool init_args(
     T** args,
     FusedOptimizerTensorListMetadata<depth>& tl,
-    int chunk_idx,
-    int chunk_size,
-    int tensor_loc) {
+    const int64_t chunk_idx,
+    const int64_t chunk_size,
+    const int64_t tensor_loc) {
   bool all_aligned = true;
   for (int i = 0; i < depth; i++) {
     args[i] = (T*)tl.addresses[i][tensor_loc];
@@ -76,10 +76,10 @@ __device__ bool init_args(
 }
 
 template<int depth, typename T>
-__device__ void load_args(T r_args[][kILP], T** args, int i_start, int chunk_size, int n) {
+__device__ void load_args(T r_args[][kILP], T** args, const int64_t i_start, const int64_t chunk_size, const int64_t n) {
 #pragma unroll
   for(int ii = 0; ii < kILP; ii++) {
-    int i = i_start + threadIdx.x + ii * blockDim.x;
+    const auto i = i_start + threadIdx.x + ii * blockDim.x;
     for (int r_index = 0; r_index < depth; r_index++) {
       r_args[r_index][ii] = 0;
       if(i < n && i < chunk_size) {
@@ -90,10 +90,10 @@ __device__ void load_args(T r_args[][kILP], T** args, int i_start, int chunk_siz
 }
 
 template<typename T>
-__device__ void store_args(T* dst, T* src, int i_start, int chunk_size, int n) {
+__device__ void store_args(T* dst, T* src, const int64_t i_start, const int64_t chunk_size, const int64_t n) {
 #pragma unroll
   for(int ii = 0; ii < kILP; ii++) {
-    int i = i_start + threadIdx.x + ii * blockDim.x;
+    const int64_t i = i_start + threadIdx.x + ii * blockDim.x;
     if(i < n && i < chunk_size)
       dst[i] = src[ii];
   }
@@ -104,13 +104,13 @@ __device__ __forceinline__ void binary_op_scalar(
     T r_args[][kILP],
     T** args,
     opmath_t scalar,
-    int n,
-    int chunk_size,
-    bool all_aligned,
+    const int64_t n,
+    const int64_t chunk_size,
+    const bool all_aligned,
     Op op) {
   // to make things simple, we put aligned case in a different code path
   if(n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
-    for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
+    for(int64_t i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
       // load
       load_store(r_args[0], args[0], 0, i_start);
 #pragma unroll
@@ -123,7 +123,7 @@ __device__ __forceinline__ void binary_op_scalar(
     }
   }
   else {
-    for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
+    for(int64_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
       // Regardless if depth is 1 (for inplace) or 2 (for out of place), r_args has depth 1
       load_args<1>(r_args, args, i_start, chunk_size, n);
 #pragma unroll
@@ -141,13 +141,13 @@ __device__ __forceinline__ void pointwise_op_scalar(
     T r_args[][kILP],
     T** args,
     opmath_t scalar,
-    int n,
-    int chunk_size,
-    bool all_aligned,
+    const int64_t n,
+    const int64_t chunk_size,
+    const bool all_aligned,
     Op op) {
   // to make things simple, we put aligned case in a different code path
   if(n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
-    for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
+    for(int64_t i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
       // load
       load_store(r_args[0], args[0], 0, i_start);
       load_store(r_args[1], args[1], 0, i_start);
@@ -163,7 +163,7 @@ __device__ __forceinline__ void pointwise_op_scalar(
     }
   }
   else {
-    for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
+    for(int64_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
      // Regardless if depth is 3 (for inplace) or 4 (for out of place), r_args has depth 3
      load_args<3>(r_args, args, i_start, chunk_size, n);
 #pragma unroll
@@ -188,12 +188,12 @@ struct BinaryOpScalarFunctor {
       TensorListMetadata<depth>& tl,
       Op op,
       opmath_t scalar) {
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.numel_for_tensor[tensor_loc];
+    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
 
     T* args[depth];
-    bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+    const bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
     n -= chunk_idx * chunk_size;
     T r_args[r_args_depth][kILP];
 
@@ -208,12 +208,12 @@ struct BinaryOpScalarListFunctor {
       int chunk_size,
       TensorListScalarListMetadata<opmath_t, depth>& tl,
       Op op) {
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.numel_for_tensor[tensor_loc];
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
 
     T* args[depth];
-    bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+    const bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
     opmath_t scalar = tl.scalar_vals[tensor_loc];
     n -= chunk_idx * chunk_size;
     T r_args[r_args_depth][kILP];
@@ -230,18 +230,18 @@ struct BinaryOpListAlphaFunctor {
       TensorListMetadata<depth>& tl,
       Op op,
       opmath_t alpha) {
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.numel_for_tensor[tensor_loc];
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
 
     T* args[depth];
-    bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+    const bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
     n -= chunk_idx * chunk_size;
     T r_args[r_args_depth][kILP];
 
     // to make things simple, we put aligned case in a different code path
     if(n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
-      for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
+      for(int64_t i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
@@ -255,7 +255,7 @@ struct BinaryOpListAlphaFunctor {
       }
     }
     else {
-      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
+      for(int64_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
 #pragma unroll
        for(int ii = 0; ii < kILP; ii++) {
@@ -277,18 +277,18 @@ struct ZeroFunctor {
   __device__ __forceinline__ void operator() (
       int chunk_size,
       TensorListMetadata<1>& tl) {
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.numel_for_tensor[tensor_loc];
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
 
     T* args[depth];
-    bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+    const auto all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
     n -= chunk_idx * chunk_size;
     T r_args[r_args_depth][kILP];
 
     // to make things simple, we put aligned case in a different code path
     if(n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
-      for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
+      for(int64_t i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
 #pragma unroll
        for(int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = 0;
@@ -298,7 +298,7 @@ struct ZeroFunctor {
       }
     }
     else {
-      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
+      for(int64_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
 #pragma unroll
        for(int ii = 0; ii < kILP; ii++) {
          r_args[0][ii] = 0;
@@ -316,9 +316,9 @@ struct UnaryOpFunctor {
       int chunk_size,
       TensorListMetadata<depth>& tl,
       Op op) {
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.numel_for_tensor[tensor_loc];
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
 
     T* args[depth];
     bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
@@ -327,7 +327,7 @@ struct UnaryOpFunctor {
 
     // to make things simple, we put aligned case in a different code path
     if(n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
-      for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
+      for(int64_t i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
 #pragma unroll
@@ -339,7 +339,7 @@ struct UnaryOpFunctor {
       }
     }
     else {
-      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
+      for(int64_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
 #pragma unroll
        for(int ii = 0; ii < kILP; ii++) {
@@ -363,12 +363,12 @@ struct PointwiseOpScalarFunctor {
       TensorListMetadata<depth>& tl,
       Op op,
       opmath_t scalar) {
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.numel_for_tensor[tensor_loc];
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
 
     T* args[depth];
-    bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+    const bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
     n -= chunk_idx * chunk_size;
     T r_args[r_args_depth][kILP];
 
@@ -383,12 +383,12 @@ struct PointwiseOpScalarListFunctor {
       int chunk_size,
       TensorListScalarListMetadata<opmath_t, depth>& tl,
       Op op) {
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.numel_for_tensor[tensor_loc];
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
 
     T* args[depth];
-    bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+    const bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
     opmath_t scalar = tl.scalar_vals[tensor_loc];
     n -= chunk_idx * chunk_size;
     T r_args[r_args_depth][kILP];
@@ -404,18 +404,18 @@ struct PointwiseOpListFunctor {
       int chunk_size,
       TensorListMetadata<depth>& tl,
       Op op) {
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.numel_for_tensor[tensor_loc];
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
 
     T* args[depth];
-    bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+    const bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
     n -= chunk_idx * chunk_size;
     T r_args[depth - 1][kILP];
 
     // to make things simple, we put aligned case in a different code path
     if(n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
-      for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
+      for(int64_t i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
@@ -429,7 +429,7 @@ struct PointwiseOpListFunctor {
       }
     }
     else {
-      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
+      for(int64_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
        load_args<depth - 1>(r_args, args, i_start, chunk_size, n);
 #pragma unroll
        for(int ii = 0; ii < kILP; ii++) {
@@ -452,9 +452,9 @@ struct TernaryOpListFunctor {
     static_assert(depth == 3 || depth == 4, "");
     static_assert(depth >= r_args_depth, "");
     static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.numel_for_tensor[tensor_loc];
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
 
     T* args[depth];
     const bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
@@ -462,7 +462,7 @@ struct TernaryOpListFunctor {
     T r_args[r_args_depth][kILP];
 
     if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
-      for (int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
+      for (int64_t i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
        load_store(r_args[2], args[2], 0, i_start);
@@ -477,7 +477,7 @@ struct TernaryOpListFunctor {
        load_store(args[res_arg_index], r_args[0], i_start, 0);
       }
     } else {
-      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
+      for (int64_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
 #pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
@@ -504,18 +504,18 @@ struct TernaryOpScalarFunctor {
     static_assert(depth == 2 || depth == 3, "");
     static_assert(depth >= r_args_depth, "");
     static_assert(res_arg_index == depth - 1 || res_arg_index == 0, "");
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.numel_for_tensor[tensor_loc];
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
 
     T* args[depth];
-    bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
+    const bool all_aligned = init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc);
     n -= chunk_idx * chunk_size;
     T r_args[r_args_depth][kILP];
 
     // to make things simple, we put aligned case in a different code path
     if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
-      for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
+      for(int64_t i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
        // load
        load_store(r_args[0], args[0], 0, i_start);
        load_store(r_args[1], args[1], 0, i_start);
@@ -532,7 +532,7 @@ struct TernaryOpScalarFunctor {
       }
     }
     else {
-      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
+      for(int64_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
        load_args<r_args_depth>(r_args, args, i_start, chunk_size, n);
 #pragma unroll
        for(int ii = 0; ii < kILP; ii++) {
```
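The loops above are the core of the fix: once a tensor's `numel` no longer fits in 32 bits, `int` offsets such as `i_start * kILP` wrap around and the bounds checks break. A minimal standalone illustration (not PyTorch code; the wrap-around shown is what common two's-complement platforms produce):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int kILP = 4;                       // same ILP factor as the kernels above
  const int64_t n = (int64_t(1) << 31) + 1024;  // a tensor slightly past INT_MAX elements

  // What the int64_t loops compute for the last batch of elements.
  const int64_t i_start = n / kILP;
  const int64_t offset64 = i_start * kILP;      // 2147484672, still exact

  // Squeezing the same product into 32 bits wraps to a negative value,
  // so comparisons like `i < n` and the vectorized offsets stop making sense.
  const int32_t offset32 = static_cast<int32_t>(offset64);

  std::printf("64-bit offset: %lld\n", static_cast<long long>(offset64));
  std::printf("32-bit offset: %d\n", offset32);
  return 0;
}
```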
```diff
@@ -31,9 +31,9 @@ struct LpNormFunctor {
       opmath_t* output_per_tensor,
       const int max_chunks_per_tensor
   ) {
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.numel_for_tensor[tensor_loc];
+    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
+    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
+    auto n = tl.numel_for_tensor[tensor_loc];
 
     T* x = (T*)tl.addresses[0][tensor_loc];
     x += chunk_idx * chunk_size;
@@ -48,7 +48,7 @@ struct LpNormFunctor {
     }
 
     if (n % kILP == 0 && (chunk_size & kILP) == 0 && is_aligned(x)) {
-      for (int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
+      for (int64_t i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
        // load
        load_store(r_x, x, 0, i_start);
 #pragma unroll
@@ -58,7 +58,7 @@ struct LpNormFunctor {
        }
       }
     } else {
-      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
+      for (int64_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
 #pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          int i = i_start + threadIdx.x + ii * blockDim.x;
@@ -110,7 +110,7 @@ std::vector<Tensor> foreach_tensor_norm_cuda(TensorList tensors, const Scalar& o
   } else if (ord.isFloatingPoint()) {
     p = ord.to<double>();
   } else {
-    AT_ERROR("foreach_tensor_norm_cuda expects ord to be integer or float");
+    TORCH_CHECK(false, "foreach_tensor_norm_cuda expects ord to be integer or float");
   }
   check_foreach_api_restrictions(tensors);
   const bool has_int_or_complex = std::any_of(tensors.begin(), tensors.end(), [](const auto & t) {
@@ -174,7 +174,7 @@ std::vector<Tensor> foreach_tensor_norm_cuda(TensorList tensors, const Scalar& o
           C10_CUDA_KERNEL_LAUNCH_CHECK();
         });
   } else {
-    AT_ERROR("foreach_tensor_norm_cuda fast path got unexpected ord value: ", p);
+    TORCH_CHECK(false, "foreach_tensor_norm_cuda fast path got unexpected ord value: ", p);
   }
 
   std::vector<Tensor> result;
```
```diff
@@ -19,6 +19,7 @@ static constexpr int64_t kBlockSize = 512;
 static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
 static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
 static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30};
+static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = {72, 60};
 
 template<typename T>
 __device__ __forceinline__ bool is_aligned(T* p){
@@ -26,14 +27,14 @@ __device__ __forceinline__ bool is_aligned(T* p){
 }
 
 template<typename T>
-__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
+__device__ __forceinline__ void load_store(T* dst, T* src, int64_t dst_offset, int64_t src_offset){
   using LT = at::native::memory::aligned_vector<T, kILP>;
   ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
 }
 
 template<int n> struct TensorListMetadata {
   const void* addresses[n][depth_to_max_tensors[n-1]];
-  int numel_for_tensor[depth_to_max_tensors[n-1]];
+  int64_t numel_for_tensor[depth_to_max_tensors[n-1]];
   unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
   int block_to_chunk[depth_to_max_blocks[n-1]];
   int start_tensor_this_launch;
@@ -41,28 +42,34 @@ template<int n> struct TensorListMetadata {
 
 template<typename scalar_vals_t, int n> struct TensorListScalarListMetadata {
   const void* addresses[n][depth_to_max_tensors_scalarlist[n-1]];
-  int numel_for_tensor[depth_to_max_tensors_scalarlist[n-1]];
+  int64_t numel_for_tensor[depth_to_max_tensors_scalarlist[n-1]];
   scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n-1]];
   unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
   int block_to_chunk[depth_to_max_blocks[n-1]];
 };
 
-// note(mkozuki): `n` of 96 and `scalar_vals_t` of `c10::complex<double>`
-// violates the cuda kernel argument size limitation of 4kb.
-// 80 is a number that does not violate this limitation.
+// note(mkozuki): `n` of 1&2 violate the limit of cuda kernel argument size of 4kb with `c10::complex<double>`
 template<> struct TensorListScalarListMetadata<c10::complex<double>, 1> {
-  const void* addresses[1][80];
-  int numel_for_tensor[80];
-  c10::complex<double> scalar_vals[80];
+  const void* addresses[1][depth_to_max_tensors_scalarlist_of_complex_double[0]];
+  int64_t numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[0]];
+  c10::complex<double> scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[0]];
   unsigned char block_to_tensor[depth_to_max_blocks[1-1]];
   int block_to_chunk[depth_to_max_blocks[1-1]];
 };
 
+template<> struct TensorListScalarListMetadata<c10::complex<double>, 2> {
+  const void* addresses[2][depth_to_max_tensors_scalarlist_of_complex_double[1]];
+  int64_t numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[1]];
+  c10::complex<double> scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[1]];
+  unsigned char block_to_tensor[depth_to_max_blocks[2-1]];
+  int block_to_chunk[depth_to_max_blocks[2-1]];
+};
+
 // NOTE(crcrpar): This is a conservative resolution to handle `state_steps`
 // whose each element is `at::Tensor` of 1 element representing the number of `step`s called so far.
 template<int n> struct FusedOptimizerTensorListMetadata {
   const void* addresses[n][depth_to_max_tensors[n-1]];
-  int numel_for_tensor[depth_to_max_tensors[n-1]];
+  int64_t numel_for_tensor[depth_to_max_tensors[n-1]];
   const void* state_steps_addresses[depth_to_max_tensors_scalarlist[n-1]];
   unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
   int block_to_chunk[depth_to_max_blocks[n-1]];
@@ -118,8 +125,8 @@ void multi_tensor_apply(
       if (n_zero_tensors == n_tensors) {
         continue;
       }
-      const int chunks = (tensor_lists[0][t - static_cast<size_t>((t == n_tensors - 1) && (tensor_lists[0][t].numel() == 0))].numel() + kChunkSize - 1)/kChunkSize;
-      for (int chunk = 0; chunk < chunks; chunk++) {
+      const auto chunks = (tensor_lists[0][t - static_cast<size_t>((t == n_tensors - 1) && (tensor_lists[0][t].numel() == 0))].numel() + kChunkSize - 1)/kChunkSize;
+      for (auto chunk = 0; chunk < chunks; chunk++) {
         tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
         tensorListMeta.block_to_chunk[loc_block_info] = chunk;
         loc_block_info++;
@@ -185,8 +192,8 @@ void multi_tensor_apply(
       if (n_zero_tensors == n_tensors) {
         continue;
       }
-      const int chunks = (tensor_lists[0][t - static_cast<size_t>((t == n_tensors - 1) && (tensor_lists[0][t].numel() == 0))].numel() + kChunkSize - 1)/kChunkSize;
-      for (int chunk = 0; chunk < chunks; chunk++) {
+      const auto chunks = (tensor_lists[0][t - static_cast<size_t>((t == n_tensors - 1) && (tensor_lists[0][t].numel() == 0))].numel() + kChunkSize - 1)/kChunkSize;
+      for (auto chunk = 0; chunk < chunks; chunk++) {
         tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
         tensorListMeta.block_to_chunk[loc_block_info] = chunk;
         loc_block_info++;
@@ -253,7 +260,8 @@ void multi_tensor_apply_for_fused_optimizer(
     if (static_cast<decltype(num_tensors)>(num_zero_tensors) == num_tensors) {
       continue;
     }
-    const auto chunks = (tensor_lists[0][tensor_index - static_cast<decltype(tensor_index)>((tensor_index == num_tensors - 1) && (tensor_lists[0][tensor_index].numel() == 0))].numel() + kChunkSize - 1) / kChunkSize;
+    const int64_t chunks = (tensor_lists[0][tensor_index - static_cast<decltype(tensor_index)>((tensor_index == num_tensors - 1) && (tensor_lists[0][tensor_index].numel() == 0))].numel() + kChunkSize - 1) / kChunkSize;
+    TORCH_CHECK(chunks > -1);
     for (const auto & chunk : c10::irange(chunks)) {
       tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
       tensorListMeta.block_to_chunk[loc_block_info] = chunk;
```
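These metadata structs are passed to the kernels by value, so widening `numel_for_tensor` to `int64_t` also grows the kernel argument block; that is why the `c10::complex<double>` specializations cap the tensor counts at 72 and 60. A rough size check under stated assumptions (a sketch only: `std::complex<double>` stands in for `c10::complex<double>`, padding may differ, and the functor plus the remaining kernel arguments share the 4 KiB budget mentioned in the note above):

```cpp
#include <complex>
#include <cstdint>

// Mirrors the shape of TensorListScalarListMetadata<c10::complex<double>, 1>
// after this PR; member sizes are approximate, real structs may pad differently.
constexpr int kMaxTensors = 72;  // depth_to_max_tensors_scalarlist_of_complex_double[0]
constexpr int kMaxBlocks = 320;  // depth_to_max_blocks[0]

struct ComplexDoubleMetadataSketch {
  const void* addresses[1][kMaxTensors];          // 72 * 8  = 576 bytes
  int64_t numel_for_tensor[kMaxTensors];          // 72 * 8  = 576 bytes
  std::complex<double> scalar_vals[kMaxTensors];  // 72 * 16 = 1152 bytes
  unsigned char block_to_tensor[kMaxBlocks];      // 320 * 1 = 320 bytes
  int block_to_chunk[kMaxBlocks];                 // 320 * 4 = 1280 bytes
};

// Roughly 3.9 KB, below the 4 KiB limit cited in the note; with the old cap of
// 80 tensors, the int64_t numel array alone would push the total past 4 KiB.
static_assert(sizeof(ComplexDoubleMetadataSketch) <= 4096,
              "metadata must fit in the CUDA kernel argument space");
```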
```diff
@@ -137,7 +137,7 @@ struct FusedAdamMathFunctor {
     scalar_type r_args[depth][kILP];
 
     if ((n % kILP == 0) && (chunk_size % kILP == 0) && all_aligned) {
-      for (int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
+      for (int64_t i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
 #pragma unroll
         for (int i = 0; i < depth; i++) {
           load_store(r_args[i], args[i], 0, i_start);
@@ -152,7 +152,7 @@ struct FusedAdamMathFunctor {
         }
       }
     } else {
-      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
+      for (int64_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
         load_args<depth>(r_args, args, i_start, chunk_size, n);
         adam_math<scalar_type, opmath_t, depth>(
             r_args, step_count, lr, beta1, beta2, weight_decay, eps, maximize, amsgrad, grad_scale_ptr, found_inf_ptr, adam_mode);
```
```diff
@@ -28,6 +28,7 @@ from torch.testing._internal.common_utils import (
     skipIfTorchDynamo
 )
 from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
+from torch.testing._internal.common_device_type import largeTensorTest
 from typing import Dict, Any, Tuple
 from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
 from unittest.mock import patch
@@ -760,8 +761,9 @@ class TestOptim(TestCase):
             for k in st_p_state:
                 self.assertEqual(st_p_state[k], mt_p_state[k])
 
-    def test_multi_tensor_optimizers(self):
-        optimizer_pairs_with_flags = [
+    @property
+    def _multi_tensor_optimizer_configs(self):
+        return [
             (optim.Adam, dict(weight_decay=1.0, amsgrad=True, fused=False)),
             (optim.Adam, dict(weight_decay=1.0, amsgrad=False, fused=False)),
             (optim.Adam, dict(weight_decay=0.0, amsgrad=True, fused=False)),
@@ -800,53 +802,29 @@ class TestOptim(TestCase):
             (optim.Adagrad, dict(weight_decay=0)),
             (optim.Adagrad, dict(weight_decay=1)),
         ]
-        self._test_derived_optimizers(optimizer_pairs_with_flags, "foreach")
 
+    def test_multi_tensor_optimizers(self):
+        self._test_derived_optimizers(self._multi_tensor_optimizer_configs, "foreach")
+
     @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
     def test_multi_tensor_optimizers_with_varying_tensors(self):
-        optimizer_pairs_with_flags = [
-            (optim.Adam, dict(weight_decay=1.0, amsgrad=True, fused=False)),
-            (optim.Adam, dict(weight_decay=1.0, amsgrad=False, fused=False)),
-            (optim.Adam, dict(weight_decay=0.0, amsgrad=True, fused=False)),
-            (optim.Adam, dict(weight_decay=0.0, amsgrad=False, fused=False)),
-            (optim.AdamW, dict(weight_decay=1.0, amsgrad=True)),
-            (optim.AdamW, dict(weight_decay=1.0, amsgrad=False)),
-            (optim.AdamW, dict(weight_decay=0.0, amsgrad=True)),
-            (optim.AdamW, dict(weight_decay=0.0, amsgrad=False)),
-            (optim.NAdam, dict(weight_decay=0.0, momentum_decay=6e-3)),
-            (optim.NAdam, dict(weight_decay=1.0, momentum_decay=6e-3)),
-            (optim.NAdam, dict(weight_decay=0.0, momentum_decay=4e-3)),
-            (optim.NAdam, dict(weight_decay=0.01, momentum_decay=4e-3)),
-            (
-                optim.SGD,
-                dict(lr=0.2, momentum=1, dampening=0, weight_decay=1, nesterov=True),
-            ),
-            (
-                optim.SGD,
-                dict(lr=0.2, momentum=1, dampening=0.5, weight_decay=1, nesterov=False),
-            ),
-            (optim.RAdam, dict(weight_decay=0, eps=1e-6)),
-            (optim.RAdam, dict(weight_decay=0)),
-            (optim.RAdam, dict(weight_decay=1, eps=1e-6)),
-            (optim.RAdam, dict(weight_decay=1)),
-            (optim.RMSprop, dict(weight_decay=1, momentum=1, centered=True)),
-            (optim.RMSprop, dict(weight_decay=1, momentum=0, centered=True)),
-            (optim.RMSprop, dict(weight_decay=1, momentum=1, centered=False)),
-            (optim.RMSprop, dict(weight_decay=0, momentum=1, centered=False)),
-            (optim.Rprop, dict(lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50))),
-            (optim.ASGD, dict(weight_decay=0)),
-            (optim.ASGD, dict(weight_decay=1)),
-            (optim.Adamax, dict(weight_decay=0)),
-            (optim.Adamax, dict(weight_decay=1)),
-            (optim.Adadelta, dict(weight_decay=0)),
-            (optim.Adadelta, dict(weight_decay=1)),
-            (optim.Adagrad, dict(weight_decay=0)),
-            (optim.Adagrad, dict(weight_decay=1)),
-        ]
-        self._test_derived_optimizers_varying_tensors(optimizer_pairs_with_flags, "foreach")
+        self._test_derived_optimizers_varying_tensors(self._multi_tensor_optimizer_configs, "foreach")
 
-    def test_fused_optimizers(self):
-        optimizer_pairs_with_flags = tuple(itertools.product(
+    @unittest.skipIf(not torch.cuda.is_available(), "Requires a GPU")
+    @largeTensorTest("72GB", "cuda")
+    def test_multi_tensor_optimizers_with_large_tensors(self):
+        for optimizer_ctor, optimizer_params in self._multi_tensor_optimizer_configs:
+            # note(crcrpar): H100 wasn't sufficient for Adamax, surprisingly
+            if optimizer_ctor == optim.Adamax:
+                continue
+            params = [torch.ones(2 ** 32, device="cuda", dtype=torch.float16)]
+            params[0].grad = torch.zeros_like(params[0])
+            optimizer = optimizer_ctor(params, foreach=True, **optimizer_params)
+            optimizer.step()
+
+    @property
+    def _fused_optimizer_configs(self):
+        return tuple(itertools.product(
             (optim.Adam, optim.AdamW),
             (
                 dict(weight_decay=1., amsgrad=False),
@@ -855,20 +833,22 @@ class TestOptim(TestCase):
                 dict(weight_decay=0., amsgrad=True),
             ),
         ))
-        self._test_derived_optimizers(optimizer_pairs_with_flags, "fused")
 
+    def test_fused_optimizers(self):
+        self._test_derived_optimizers(self._fused_optimizer_configs, "fused")
+
     @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
     def test_fused_optimizers_with_varying_tensors(self):
-        optimizer_pairs_with_flags = tuple(itertools.product(
-            (optim.Adam, optim.AdamW),
-            (
-                dict(weight_decay=1., amsgrad=False),
-                dict(weight_decay=1., amsgrad=True),
-                dict(weight_decay=0., amsgrad=False),
-                dict(weight_decay=0., amsgrad=True),
-            ),
-        ))
-        self._test_derived_optimizers_varying_tensors(optimizer_pairs_with_flags, "fused")
+        self._test_derived_optimizers_varying_tensors(self._fused_optimizer_configs, "fused")
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Requires a GPU")
+    @largeTensorTest("64GB", "cuda")
+    def test_fused_optimizers_with_large_tensors(self):
+        for optimizer_ctor, optimizer_params in self._fused_optimizer_configs:
+            params = [torch.ones(2 ** 32, device="cuda", dtype=torch.float16)]
+            params[0].grad = torch.zeros_like(params[0])
+            optimizer = optimizer_ctor(params, fused=True, **optimizer_params)
+            optimizer.step()
+
     def test_adam(self):
         self._test_basic_cases(
```