mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Allow int vals to go down the fastpath for _foreach_max (#127303)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127303 Approved by: https://github.com/albanD ghstack dependencies: #127187
This commit is contained in:
parent
601c5e085d
commit
05e99154ee
|
|
@ -68,8 +68,8 @@ struct LpMaxFunctor {
|
||||||
T vals[kILP];
|
T vals[kILP];
|
||||||
T r_x[kILP];
|
T r_x[kILP];
|
||||||
for (int64_t i = 0; i < kILP; i++) {
|
for (int64_t i = 0; i < kILP; i++) {
|
||||||
vals[i] = T(-INFINITY);
|
vals[i] = T(std::numeric_limits<T>::lowest());
|
||||||
r_x[i] = T(-INFINITY);
|
r_x[i] = T(std::numeric_limits<T>::lowest());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (n % kILP == 0 && (chunk_size & kILP) == 0 && is_aligned(x)) {
|
if (n % kILP == 0 && (chunk_size & kILP) == 0 && is_aligned(x)) {
|
||||||
|
|
@ -96,7 +96,7 @@ struct LpMaxFunctor {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto val = T(-INFINITY);
|
auto val = T(std::numeric_limits<T>::lowest());
|
||||||
for (int i = 0; i < kILP; i++) {
|
for (int i = 0; i < kILP; i++) {
|
||||||
val = max_propagate_nan(val, vals[i]);
|
val = max_propagate_nan(val, vals[i]);
|
||||||
}
|
}
|
||||||
|
|
@ -118,7 +118,7 @@ __global__ void lpmax_cleanup(
|
||||||
__shared__ T vals[512];
|
__shared__ T vals[512];
|
||||||
const T* output_this_tensor =
|
const T* output_this_tensor =
|
||||||
output_per_tensor + blockIdx.x * max_chunks_per_tensor;
|
output_per_tensor + blockIdx.x * max_chunks_per_tensor;
|
||||||
T val = -INFINITY;
|
T val = std::numeric_limits<T>::lowest();
|
||||||
for (size_t i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) {
|
for (size_t i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) {
|
||||||
val = max_propagate_nan(val, output_this_tensor[i]);
|
val = max_propagate_nan(val, output_this_tensor[i]);
|
||||||
}
|
}
|
||||||
|
|
@ -130,21 +130,11 @@ __global__ void lpmax_cleanup(
|
||||||
|
|
||||||
std::vector<Tensor> foreach_tensor_max_cuda(TensorList tensors) {
|
std::vector<Tensor> foreach_tensor_max_cuda(TensorList tensors) {
|
||||||
check_foreach_api_restrictions(tensors);
|
check_foreach_api_restrictions(tensors);
|
||||||
// we currently use -INF as the identity value to compare against, which
|
if (!can_use_fast_route(tensors)) {
|
||||||
// does not work for int8, int16, nor bool. Fall back to slow path here.
|
|
||||||
const bool has_small_int_or_bool =
|
|
||||||
std::any_of(tensors.begin(), tensors.end(), [](const auto& t) {
|
|
||||||
const auto scalar_type = t.scalar_type();
|
|
||||||
return scalar_type == at::ScalarType::Short ||
|
|
||||||
scalar_type == at::ScalarType::Char ||
|
|
||||||
scalar_type == at::ScalarType::Bool;
|
|
||||||
});
|
|
||||||
if (!can_use_fast_route(tensors) || has_small_int_or_bool) {
|
|
||||||
return foreach_tensor_max_slow(tensors);
|
return foreach_tensor_max_slow(tensors);
|
||||||
}
|
}
|
||||||
|
|
||||||
// for parity with max in ReduceAllOps.cpp, though I think max(empty) should
|
// for parity with max in ReduceAllOps.cpp, as max(empty) is ???
|
||||||
// eventually be allowed.
|
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
std::all_of(
|
std::all_of(
|
||||||
tensors.begin(),
|
tensors.begin(),
|
||||||
|
|
|
||||||
|
|
@ -103,7 +103,7 @@ __inline__ __device__ T BlockReduceMax(T val, T* shared) {
|
||||||
shared[wid] = val;
|
shared[wid] = val;
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
val = (tid < B::Warps()) ? shared[lid] : T(-INFINITY);
|
val = (tid < B::Warps()) ? shared[lid] : T(std::numeric_limits<T>::lowest());
|
||||||
if (wid == 0) {
|
if (wid == 0) {
|
||||||
val = WarpReduceMax(val);
|
val = WarpReduceMax(val);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1015,7 +1015,7 @@ class TestForeach(TestCase):
|
||||||
def test_foreach_reduce_large_input(self, device, dtype, op):
|
def test_foreach_reduce_large_input(self, device, dtype, op):
|
||||||
# test inputs larger than kChunkSize = 65536
|
# test inputs larger than kChunkSize = 65536
|
||||||
N = 65536 * 2
|
N = 65536 * 2
|
||||||
disable_fastpath = dtype in (torch.int8, torch.int16, torch.bool)
|
disable_fastpath = False
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if op.name == "_foreach_norm":
|
if op.name == "_foreach_norm":
|
||||||
ord = 2
|
ord = 2
|
||||||
|
|
|
||||||
|
|
@ -9380,7 +9380,7 @@ class foreach_max_sample_func(foreach_inputs_sample_func):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype):
|
def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype):
|
||||||
return dtype in (torch.int8, torch.int16, torch.bool)
|
return False
|
||||||
|
|
||||||
|
|
||||||
class foreach_norm_sample_func(foreach_inputs_sample_func):
|
class foreach_norm_sample_func(foreach_inputs_sample_func):
|
||||||
|
|
@ -11125,6 +11125,7 @@ foreach_reduce_op_db: List[ForeachFuncInfo] = [
|
||||||
supports_inplace_autograd=True,
|
supports_inplace_autograd=True,
|
||||||
supports_forward_ad=True,
|
supports_forward_ad=True,
|
||||||
decorators=(
|
decorators=(
|
||||||
|
# no complex support for ordering ops like max
|
||||||
DecorateInfo(
|
DecorateInfo(
|
||||||
unittest.expectedFailure,
|
unittest.expectedFailure,
|
||||||
"TestForeach",
|
"TestForeach",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user