mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "_foreach_copy with different src/dst dtypes (#121717)"
This reverts commit da2a9a0512.
Reverted https://github.com/pytorch/pytorch/pull/121717 on behalf of https://github.com/janeyx99 due to Causing IMAs on V100s internally :C ([comment](https://github.com/pytorch/pytorch/pull/121717#issuecomment-2025553295))
This commit is contained in:
parent
8698121636
commit
958dbb876c
|
|
@ -102,13 +102,12 @@ inline void check_foreach_api_restrictions(
|
|||
// corresponding tensors (aligning in index across the tensorLists) share the
|
||||
// same device and dtype.
|
||||
inline bool _check_tensors_share_device_and_dtype(
|
||||
ArrayRef<TensorList> tensorLists,
|
||||
const bool skip_dtype_check = false) {
|
||||
ArrayRef<TensorList> tensorLists) {
|
||||
const auto expected_dtype = tensorLists[0][0].dtype();
|
||||
const auto expected_device = tensorLists[0][0].device();
|
||||
|
||||
auto is_tensor_okay = [&](const Tensor& tensor) {
|
||||
return (skip_dtype_check || tensor.dtype() == expected_dtype) &&
|
||||
return tensor.dtype() == expected_dtype &&
|
||||
tensor.device() == expected_device && tensor.layout() == at::kStrided &&
|
||||
tensor.is_non_overlapping_and_dense();
|
||||
};
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
#include <ATen/native/cuda/ForeachFunctors.cuh>
|
||||
#include <ATen/native/cuda/ForeachMinMaxFunctors.cuh>
|
||||
#include <functional>
|
||||
#include <type_traits>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/NativeFunctions.h>
|
||||
|
|
@ -251,152 +250,20 @@ FOREACH_BINARY_OP_LIST(
|
|||
power_functor,
|
||||
/*division_op*/ true);
|
||||
|
||||
template <typename dst_t, typename src_t = dst_t>
|
||||
struct Copy {
|
||||
__device__ __forceinline__ dst_t operator()(const src_t& x) {
|
||||
return static_cast<dst_t>(x);
|
||||
template <typename T>
|
||||
struct Identity {
|
||||
__device__ __forceinline__ T operator()(const T& x) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename dst_t>
|
||||
struct Copy<dst_t, c10::complex<double>> {
|
||||
__device__ __forceinline__ dst_t operator()(const c10::complex<double>& x) {
|
||||
if constexpr (!(std::is_same_v<dst_t, c10::complex<double>> ||
|
||||
std::is_same_v<dst_t, c10::complex<float>>)) {
|
||||
return static_cast<dst_t>(x.real());
|
||||
} else {
|
||||
return static_cast<dst_t>(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename dst_t>
|
||||
struct Copy<dst_t, c10::complex<float>> {
|
||||
__device__ __forceinline__ dst_t operator()(const c10::complex<float>& x) {
|
||||
if constexpr (!(std::is_same_v<dst_t, c10::complex<double>> ||
|
||||
std::is_same_v<dst_t, c10::complex<float>>)) {
|
||||
return static_cast<dst_t>(x.real());
|
||||
} else {
|
||||
return static_cast<dst_t>(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#define AT_DISPATCH_SOURCE_TYPES(TYPE, NAME, ...) \
|
||||
AT_DISPATCH_SWITCH( \
|
||||
TYPE, \
|
||||
NAME, \
|
||||
AT_PRIVATE_CASE_TYPE_USING_HINT( \
|
||||
at::ScalarType::Byte, src_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE_USING_HINT( \
|
||||
at::ScalarType::Char, src_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE_USING_HINT( \
|
||||
at::ScalarType::Long, src_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE_USING_HINT( \
|
||||
at::ScalarType::Short, src_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE_USING_HINT( \
|
||||
at::ScalarType::Double, src_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE_USING_HINT( \
|
||||
at::ScalarType::Float, src_t, __VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE_USING_HINT( \
|
||||
at::ScalarType::ComplexDouble, \
|
||||
src_t, \
|
||||
__VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE_USING_HINT( \
|
||||
at::ScalarType::ComplexFloat, \
|
||||
src_t, \
|
||||
__VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE_USING_HINT( \
|
||||
at::ScalarType::Half, \
|
||||
src_t, \
|
||||
__VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE_USING_HINT( \
|
||||
at::ScalarType::BFloat16, \
|
||||
src_t, \
|
||||
__VA_ARGS__) \
|
||||
AT_PRIVATE_CASE_TYPE_USING_HINT( \
|
||||
at::ScalarType::Bool, \
|
||||
src_t, \
|
||||
__VA_ARGS__))
|
||||
|
||||
namespace {
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename src_t,
|
||||
int depth,
|
||||
int r_args_depth,
|
||||
int res_arg_index>
|
||||
struct CopyFunctor {
|
||||
static_assert(depth == 2 && r_args_depth == 1 && res_arg_index == 1);
|
||||
template <typename Op>
|
||||
__device__ __forceinline__ void operator()(
|
||||
int chunk_size,
|
||||
TensorListMetadata<depth>& tl,
|
||||
Op op) {
|
||||
const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
|
||||
const auto chunk_idx = tl.block_to_chunk[blockIdx.x];
|
||||
auto n = tl.numel_for_tensor[tensor_loc];
|
||||
|
||||
src_t* src_ptr = (src_t*)tl.addresses[0][tensor_loc];
|
||||
src_ptr += chunk_idx * chunk_size;
|
||||
T* self_ptr = (T*)tl.addresses[1][tensor_loc];
|
||||
self_ptr += chunk_idx * chunk_size;
|
||||
|
||||
const bool all_aligned{is_aligned(src_ptr) && is_aligned(self_ptr)};
|
||||
|
||||
n -= chunk_idx * chunk_size;
|
||||
src_t src_args[kILP];
|
||||
T r_args[kILP];
|
||||
|
||||
// to make things simple, we put aligned case in a different code path
|
||||
if (n % kILP == 0 && chunk_size % kILP == 0 && all_aligned) {
|
||||
for (int64_t i_start = threadIdx.x;
|
||||
i_start * kILP < n && i_start * kILP < chunk_size;
|
||||
i_start += blockDim.x) {
|
||||
// load
|
||||
load_store(src_args, src_ptr, 0, i_start);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < kILP; ii++) {
|
||||
r_args[ii] = static_cast<T>(op(src_args[ii]));
|
||||
}
|
||||
// store
|
||||
load_store(self_ptr, r_args, i_start, 0);
|
||||
}
|
||||
} else {
|
||||
for (int64_t i_start = 0; i_start < n && i_start < chunk_size;
|
||||
i_start += blockDim.x * kILP) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < kILP; ii++) {
|
||||
const auto i = i_start + threadIdx.x + ii * blockDim.x;
|
||||
src_args[ii] = src_ptr[i];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < kILP; ii++) {
|
||||
r_args[ii] = static_cast<T>(op(src_args[ii]));
|
||||
}
|
||||
store_args(self_ptr, r_args, i_start, chunk_size, n);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
void foreach_tensor_copy_list_kernel_cuda_(
|
||||
TensorList self,
|
||||
TensorList src,
|
||||
const bool non_blocking) {
|
||||
check_foreach_api_restrictions(self, src);
|
||||
if (!(_check_tensors_share_device_and_dtype(
|
||||
{self, src}, /* skip_dtype_check */ true) &&
|
||||
std::all_of(
|
||||
src.cbegin(),
|
||||
src.cend(),
|
||||
[&](const auto& t) -> bool {
|
||||
return t.dtype() == src[0].dtype();
|
||||
}) &&
|
||||
_check_tensors_share_sizes_and_strides({self, src}))) {
|
||||
if (!can_use_fast_route(
|
||||
self, src, /* does_op_promote_integer_inputs_to_float */ false)) {
|
||||
return at::native::foreach_tensor_copy_list_kernel_slow_(
|
||||
self, src, non_blocking);
|
||||
}
|
||||
|
|
@ -411,8 +278,6 @@ void foreach_tensor_copy_list_kernel_cuda_(
|
|||
"foreach_tensor_copy",
|
||||
[&]() {
|
||||
using opmath_t = at::opmath_type<scalar_t>;
|
||||
AT_DISPATCH_SOURCE_TYPES(src[0].scalar_type(), "foreach_tensor_copy", [&] {
|
||||
if constexpr (std::is_same_v<scalar_t, src_t>) {
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
UnaryOpFunctor<
|
||||
|
|
@ -420,29 +285,9 @@ void foreach_tensor_copy_list_kernel_cuda_(
|
|||
/* depth */ 2,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 1>(),
|
||||
Copy<opmath_t, opmath_t>());
|
||||
} else {
|
||||
// Ref:
|
||||
// https://github.com/pytorch/pytorch/blob/656134c38f4737d13c3f43fc5c59470bc23c1d2f/aten/src/ATen/native/Copy.cpp#L299-L301
|
||||
if (!self[0].is_complex() && src[0].is_complex()) {
|
||||
TORCH_WARN_ONCE(
|
||||
"Casting complex values to real discards the imaginary part");
|
||||
}
|
||||
multi_tensor_apply<2>(
|
||||
tensor_lists,
|
||||
CopyFunctor<
|
||||
scalar_t,
|
||||
src_t,
|
||||
/* depth */ 2,
|
||||
/* r_args_depth */ 1,
|
||||
/* res_arg_index */ 1>(),
|
||||
Copy<scalar_t, src_t>());
|
||||
}
|
||||
});
|
||||
Identity<opmath_t>());
|
||||
});
|
||||
increment_version(self);
|
||||
}
|
||||
|
||||
#undef AT_DISPATCH_SOURCE_TYPES
|
||||
|
||||
} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -838,20 +838,6 @@ class TestForeach(TestCase):
|
|||
copy_(t, s, non_blocking)
|
||||
self.assertEqual(ref_input, sample.input)
|
||||
|
||||
@onlyCUDA
|
||||
@ops(filter(lambda op: op.name == "_foreach_copy", foreach_binary_op_db))
|
||||
def test_foreach_copy_with_multi_dtypes(self, device, dtype, op):
|
||||
# check (a) multi_tensor_apply is called and (b) numerical parity with for-loop and Tensor.copy_
|
||||
foreach_copy_ = ForeachFuncWrapper(op.inplace_variant)
|
||||
for sample in op.sample_inputs(device, dtype, noncontiguous=False):
|
||||
for src_dtype in floating_types_and(torch.half, torch.bfloat16):
|
||||
if src_dtype == dtype:
|
||||
continue
|
||||
self_tensors = [t.clone() for t in sample.input]
|
||||
src_tensors = [t.to(src_dtype) for t in self_tensors]
|
||||
out = foreach_copy_((self_tensors, src_tensors), is_cuda=True, expect_fastpath=True)
|
||||
self.assertEqual(out, [torch.empty_like(t).copy_(s) for t, s in zip(self_tensors, src_tensors)])
|
||||
|
||||
# Test reverse-mode & forward-mode AD if supported.
|
||||
@onlyCUDA
|
||||
@ops(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user