Avoid COW materialization in at::parallel_for/parallel_reduce (#120455)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120455
Approved by: https://github.com/albanD
This commit is contained in:
parent d053dcfa69
commit 13a54ce279
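The call-site changes below all follow one pattern; here is a minimal sketch of it (the helper name fill_zero_parallel and the float dtype are illustrative assumptions, not taken from the patch): fetch, and thereby materialize, any potentially copy-on-write data pointer before entering at::parallel_for, so that nothing inside the loop function has to materialize storage.

#include <ATen/Parallel.h>
#include <ATen/core/Tensor.h>

// Hypothetical helper: assumes `out` is a float CPU tensor whose storage may
// be copy-on-write (e.g. the result of a lazy clone).
void fill_zero_parallel(at::Tensor& out) {
  // Materialize if COW, since we cannot do so during parallel_for.
  float* out_data = out.mutable_data_ptr<float>();
  at::parallel_for(0, out.numel(), /*grain_size=*/2048,
                   [&](int64_t begin, int64_t end) {
    // Only the pre-fetched pointer is used here; no data_ptr() call means
    // no materialization can be triggered from a worker thread.
    for (int64_t i = begin; i < end; ++i) {
      out_data[i] = 0.0f;
    }
  });
}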
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <c10/util/Exception.h>
+#include <c10/util/ParallelGuard.h>
 #include <c10/util/SmallVector.h>
 
 namespace at {
@@ -24,13 +25,19 @@ inline void parallel_for(
       at::get_num_threads() > 1);
   if (!use_parallel) {
     internal::ThreadIdGuard tid_guard(0);
+    c10::ParallelGuard guard(true);
     f(begin, end);
     return;
   }
 
-  internal::invoke_parallel(begin, end, grain_size, f);
+  internal::invoke_parallel(
+      begin, end, grain_size, [&](int64_t begin, int64_t end) {
+        c10::ParallelGuard guard(true);
+        f(begin, end);
+      });
 #else
   internal::ThreadIdGuard tid_guard(0);
+  c10::ParallelGuard guard(true);
   f(begin, end);
 #endif
 }
@@ -56,6 +63,7 @@ inline scalar_t parallel_reduce(
       max_threads > 1);
   if (!use_parallel) {
     internal::ThreadIdGuard tid_guard(0);
+    c10::ParallelGuard guard(true);
     return f(begin, end, ident);
   }
 
@@ -66,6 +74,7 @@ inline scalar_t parallel_reduce(
       grain_size,
       [&](const int64_t my_begin, const int64_t my_end) {
         const auto tid = at::get_thread_num();
+        c10::ParallelGuard guard(true);
         results[tid] = f(my_begin, my_end, ident);
       });
 
@@ -76,6 +85,7 @@ inline scalar_t parallel_reduce(
   return result;
 #else
   internal::ThreadIdGuard tid_guard(0);
+  c10::ParallelGuard guard(true);
   return f(begin, end, ident);
 #endif
 }
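A small illustration, not part of the patch, of what the wrapping above provides: the thread-local flag installed by c10::ParallelGuard is visible to the loop function on every worker thread for every chunk, and is cleared again once parallel_for returns (the demo function name is hypothetical).

#include <ATen/Parallel.h>
#include <c10/util/Exception.h>
#include <c10/util/ParallelGuard.h>

void guard_visibility_demo() {
  TORCH_INTERNAL_ASSERT(!c10::ParallelGuard::is_enabled());  // not in a loop
  at::parallel_for(0, 64, /*grain_size=*/1, [](int64_t, int64_t) {
    // Set by the guard that parallel_for now constructs around `f`,
    // both on the serial fallback path and inside invoke_parallel.
    TORCH_INTERNAL_ASSERT(c10::ParallelGuard::is_enabled());
  });
  TORCH_INTERNAL_ASSERT(!c10::ParallelGuard::is_enabled());  // restored on exit
}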
@@ -430,12 +430,12 @@ void slow_conv3d_backward_out_cpu_template(
 
   AT_DISPATCH_FLOATING_TYPES_AND2(
       kBFloat16, kHalf, input.scalar_type(), "slow_conv3d_cpu_grad_input", [&] {
+        auto grad_input_a = grad_input.accessor<scalar_t, 5>();
+        auto grad_output_a = grad_output_contiguous.accessor<scalar_t, 5>();
+        auto fgrad_input_a = fgrad_input.accessor<scalar_t, 3>();
+        auto weight_2d_a = weight2d.accessor<scalar_t, 2>();
         at::parallel_for(0, batch_size, CONV3D_GRAIN_SALT,
                         [&](int64_t start, int64_t end) {
-          auto grad_input_a = grad_input.accessor<scalar_t, 5>();
-          auto grad_output_a = grad_output_contiguous.accessor<scalar_t, 5>();
-          auto fgrad_input_a = fgrad_input.accessor<scalar_t, 3>();
-          auto weight_2d_a = weight2d.accessor<scalar_t, 2>();
 
           for (const auto t : c10::irange(start, end)) {
             auto grad_input_t = grad_input_a[t];
@@ -1835,6 +1835,8 @@ static inline void bmm_out_or_baddbmm_(const Tensor& self_or_result_, const Tens
             r, r, batch1.select(0, b), batch2.select(0, b), 0, 1);
       }
     };
+    // Materialize if COW, since we cannot do so during parallel_for
+    self_or_result.mutable_data_ptr();
     at::parallel_for(0, bs, 1, bmm_out_fn);
   } else {
     for (const auto b : c10::irange(bs)) {
@@ -1851,6 +1853,8 @@ static inline void bmm_out_or_baddbmm_(const Tensor& self_or_result_, const Tens
             batch1.select(0, b), batch2.select(0, b), beta, alpha);
       }
     };
+    // Materialize if COW, since we cannot do so during parallel_for
+    self_or_result.mutable_data_ptr();
     at::parallel_for(0, bs, 1, bmm_fn);
   } else {
     for (const auto b : c10::irange(bs)) {
@@ -260,6 +260,7 @@ Tensor ctc_loss_backward_cpu_template(const Tensor& grad_out, const Tensor& log_
   auto gp = grad.permute({1,0,2});
   auto grad_a_global = gp.accessor<scalar_t, 3>();
   auto targets_data = targets.data_ptr<target_t>();
+  auto grad_out_a = grad_out.accessor<scalar_t, 1>();
 
   auto create_fill_iterator = [](const Tensor& tensor, IntArrayRef squash_dims) {
     return TensorIteratorConfig()
@@ -366,7 +367,7 @@ Tensor ctc_loss_backward_cpu_template(const Tensor& grad_out, const Tensor& log_
       // now we wrap up the calculation by adding in the remaining items of eq (16)
       // this could be a great target for further vectorization.
       // grad is the output gradient, nll is the loss. Note that the likelihood -nll is the Z of eq (16)
-      scalar_t gr = grad_out.accessor<scalar_t, 1>()[b];
+      scalar_t gr = grad_out_a[b];
       for (const auto t : c10::irange(input_length)) { // or go for the full thing?
         for (const auto c : c10::irange(num_labels)) {
           scalar_t& res = grad_a[t][c];
@@ -298,6 +298,9 @@ void slow_conv_transpose2d_out_cpu_template(
   }
   columns.zero_();
 
+  // Materialize if COW, since we cannot do so during parallel_for
+  output.mutable_data_ptr();
+
   AT_DISPATCH_FLOATING_TYPES_AND3(at::ScalarType::Long, at::ScalarType::BFloat16,
       at::ScalarType::Half, input.scalar_type(), "slow_conv_transpose2d_out_cpu", [&] {
 
@@ -125,6 +125,9 @@ Tensor fbgemm_linear_int8_weight_fp32_activation(
   auto& pack_b =
       cpp_custom_type_hack::cast<fbgemm::PackBMatrix<int8_t>>(packed);
 
+  int32_t* col_offsets_data = col_offsets.data_ptr<int32_t>();
+  float* bias_contig_data = bias_contig.data_ptr<float>();
+
   const int num_tasks = at::get_num_threads();
   at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
     // This operation does the following:
@@ -162,8 +165,8 @@
         /*Aq_zero_point=*/q_params.zero_point,
         /*Bq_zero_point=*/&weight_zero_point_int32,
         /*row_offsets=*/pack_a.getRowOffsetBuffer(),
-        /*col_offsets=*/col_offsets.data_ptr<int32_t>(),
-        /*bias=*/bias_contig.data_ptr<float>(),
+        /*col_offsets=*/col_offsets_data,
+        /*bias=*/bias_contig_data,
         /*nCol=*/N);
     // Do the GEMM
     fbgemm::fbgemmPacked(
@@ -2293,6 +2293,8 @@ Tensor& nonzero_out_cpu(const Tensor& self, Tensor& result) {
     return result;
   }
 
+  auto out_accessor = result.accessor<int64_t, 2>();
+
   // Pass 2: Write indexes
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
       kComplexHalf, kHalf, kBFloat16, kBool, self.scalar_type(), "nonzero_cpu", [&] {
@@ -2313,7 +2315,6 @@ Tensor& nonzero_out_cpu(const Tensor& self, Tensor& result) {
           }
         }
 
-        auto out_accessor = result.accessor<int64_t, 2>();
         auto out_ptr = out_accessor[thread_count_nonzero[tid]].data();
 
         auto loop = [&](char** data, const int64_t* strides, int64_t n1, int64_t n2) {
@@ -2066,7 +2066,7 @@ Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& in
           auto* ptr_tid_src_int_idx = src_int_idx.select(0, tid).data_ptr<int64_t>();
           auto* ptr_tid_sorted_int_idx = sorted_int_idx.select(0, tid).data_ptr<int64_t>();
           auto* ptr_tid_int_counts = int_counts.select(0, tid).data_ptr<int64_t>();
-          const auto* ptr_src = src.data_ptr<int64_t>() + start;
+          const auto* ptr_src = src.const_data_ptr<int64_t>() + start;
 
           for (const auto i : c10::irange(start, end)) {
             const auto src_val = *ptr_src++;
@@ -2124,9 +2124,9 @@ Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& in
           const auto start = tid * chunk_size_src;
           const auto end = std::min(start + chunk_size_src, src_len);
           const auto tid_offset = thread_offsets.const_data_ptr<int64_t>()[tid];
-          const auto* ptr_tid_src_int_idx = src_int_idx.select(0, tid).data_ptr<int64_t>();
-          const auto* ptr_tid_sorted_int_idx = sorted_int_idx.select(0, tid).data_ptr<int64_t>();
-          const auto* ptr_tid_int_counts = int_counts.select(0, tid).data_ptr<int64_t>();
+          const auto* ptr_tid_src_int_idx = src_int_idx.select(0, tid).const_data_ptr<int64_t>();
+          const auto* ptr_tid_sorted_int_idx = sorted_int_idx.select(0, tid).const_data_ptr<int64_t>();
+          const auto* ptr_tid_int_counts = int_counts.select(0, tid).const_data_ptr<int64_t>();
           auto* ptr_tid_selected_sorted = ptr_selected_sorted + tid_offset;
           auto* ptr_tid_selected_src = ptr_selected_src + tid_offset;
 
@@ -2324,9 +2324,9 @@ Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& in
           const auto end = std::min(start + chunk_size, src_len);
           auto* ptr_src_tid = ptr_src + start;
           const auto* ptr_src_counts_per_thread
-              = src_counts_per_thread.select(0, tid).data_ptr<int64_t>();
+              = src_counts_per_thread.select(0, tid).const_data_ptr<int64_t>();
           const auto* ptr_src_offset_counts_per_thread
-              = src_offset_counts_per_thread.select(0, tid).data_ptr<int64_t>();
+              = src_offset_counts_per_thread.select(0, tid).const_data_ptr<int64_t>();
           auto tid_counts = at::zeros({size}, src.options());
           auto* ptr_tid_counts = tid_counts.data_ptr<int64_t>();
 
@@ -4,6 +4,7 @@
 #include <ATen/core/Tensor.h>
 #include <ATen/FunctionalInverses.h>
 #include <ATen/ScalarOps.h>
+#include <ATen/Parallel.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -13,6 +14,7 @@
 #include <ATen/ops/_test_autograd_multiple_dispatch_native.h>
 #include <ATen/ops/_test_autograd_multiple_dispatch_view_native.h>
 #include <ATen/ops/_test_check_tensor_native.h>
+#include <ATen/ops/_test_parallel_materialize_native.h>
 #include <ATen/ops/_test_optional_filled_intlist_native.h>
 #include <ATen/ops/_test_optional_floatlist_native.h>
 #include <ATen/ops/_test_optional_intlist_native.h>
@@ -111,6 +113,21 @@ Tensor _test_check_tensor(const Tensor& self) {
   return self.clone();
 }
 
+Tensor _test_parallel_materialize(const Tensor& self, int64_t num_parallel, bool skip_first) {
+  at::parallel_for(0, num_parallel, 1, [&](int64_t begin, int64_t end){
+    // NOTE: skip_first is meant to avoid triggering the materialization from
+    // the first thread, to ensure that the subthreads throw the error
+    // correctly. On some platforms, the first thread is the main thread and it
+    // begins executing the loop function much earlier than the subthreads.
+    if (skip_first && begin == 0 && end == 1) {
+      return;
+    } else {
+      self.mutable_data_ptr();
+    }
+  });
+  return self;
+}
+
 } // namespace at::native
 
 namespace at::functionalization {
@@ -292,16 +292,21 @@ Tensor _convolution_depthwise3x3_winograd(
                                  bias_potentially_undefined :
                                  at::zeros({kernel_sizes[0]}, input.options());
 
+  auto input_data = input.data_ptr<float>();
+  auto kernel_data = kernel.data_ptr<float>();
+  auto bias_data = bias.data_ptr<float>();
+  auto output_data = output.data_ptr<float>();
+
   at::parallel_for(0, args.batch * args.out_channels, 0, [&](int64_t start, int64_t end) {
     for (const auto k : c10::irange(start, end)) {
       const int64_t g = k % args.out_channels;
       const int64_t i = k / (args.out_channels / groups);
       convolution_depthwise3x3_winograd_impl(
           args,
-          input.data_ptr<float>() + i * input_hxw,
-          kernel.data_ptr<float>() + g * 3 * 3,
-          bias.data_ptr<float>() + g,
-          output.data_ptr<float>() + k * output_hxw);
+          input_data + i * input_hxw,
+          kernel_data + g * 3 * 3,
+          bias_data + g,
+          output_data + k * output_hxw);
     }
   });
 
@@ -14120,6 +14120,12 @@
 # It is undocumented and should not be used outside of tests.
 - func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor
 
+# Note: for testing COW materialization within `at::parallel_for` loop function
+- func: _test_parallel_materialize(Tensor self, int num_parallel, bool skip_first=False) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _test_parallel_materialize
+
 # Note: this function is only for testing.
 - func: _test_optional_intlist(Tensor values, int[]? addends) -> Tensor
   python_module: nn
@@ -367,7 +367,7 @@ inline Tensor wrap_tensor_node(
           if (tensor_node.children(i).numel() > 0) {
             memcpy(
                 nt_buffer.mutable_data_ptr<scalar_t>() + start_offsets[i],
-                tensor_node.children(i).data_ptr<scalar_t>(),
+                tensor_node.children(i).const_data_ptr<scalar_t>(),
                 tensor_node.children(i).numel() * sizeof(scalar_t));
           }
         }
@@ -491,7 +491,7 @@ at::Tensor& embedding_bag_byte_impl(
             /*offsets_or_lengths=*/offsets_data + start_idx,
             /*weights=*/
             per_sample_weights_
-                ? per_sample_weights_.value().data_ptr<float>() +
+                ? per_sample_weights_.value().const_data_ptr<float>() +
                     offsets_data[start_idx]
                 : nullptr,
             /*out=*/output_data + start_idx * D);
@@ -123,6 +123,8 @@ at::Tensor& PackedLinearWeight::apply_impl(
   // Allocate a buffer for fbgemmPacked to use
   auto buffer = at::empty(out_sizes, output.options().dtype(at::kInt));
 
+  auto output_data = reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>());
+
   int num_tasks = at::get_num_threads();
   at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
     for (const auto task_id : c10::irange(begin, end)) {
@@ -184,7 +186,7 @@ at::Tensor& PackedLinearWeight::apply_impl(
       fbgemm::fbgemmPacked(
           /*packA=*/packA,
           /*packB=*/*packB,
-          /*C=*/reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
+          /*C=*/output_data,
           /*C_buffer=*/buffer.data_ptr<int32_t>(),
           /*ldc=*/N,
           /*outProcess=*/outputProcObj,
@@ -220,7 +222,7 @@ at::Tensor& PackedLinearWeight::apply_impl(
       fbgemm::fbgemmPacked(
           /*packA=*/packA,
           /*packB=*/*packB,
-          /*C=*/reinterpret_cast<uint8_t*>(output.data_ptr<c10::quint8>()),
+          /*C=*/output_data,
           /*C_buffer=*/buffer.data_ptr<int32_t>(),
           /*ldc=*/N,
           /*outProcess=*/outputProcObj,
@@ -358,6 +360,8 @@ at::Tensor PackedLinearWeight::apply_with_input_q_dq_qweight_dq_output_fp32_impl
       output.options().dtype(at::kInt),
       LEGACY_CONTIGUOUS_MEMORY_FORMAT);
 
+  auto output_data = output.data_ptr<float>();
+
   int num_tasks = at::get_num_threads();
   at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
     fbgemm::PackAWithQuantRowOffset<uint8_t> packA(
@@ -396,7 +400,7 @@ at::Tensor PackedLinearWeight::apply_with_input_q_dq_qweight_dq_output_fp32_impl
     fbgemm::fbgemmPacked(
         /*packA=*/packA,
         /*packB=*/*packB,
-        /*C=*/output.data_ptr<float>(),
+        /*C=*/output_data,
         /*C_buffer=*/buffer.data_ptr<int32_t>(),
         /*ldc=*/N,
         /*outProcess=*/outputProcObj,
@@ -428,7 +432,7 @@ at::Tensor PackedLinearWeight::apply_with_input_q_dq_qweight_dq_output_fp32_impl
     fbgemm::fbgemmPacked(
         /*packA=*/packA,
         /*packB=*/*packB,
-        /*C=*/output.data_ptr<float>(),
+        /*C=*/output_data,
         /*C_buffer=*/buffer.data_ptr<int32_t>(),
         /*ldc=*/N,
         /*outProcess=*/outputProcObj,
@@ -423,6 +423,8 @@ at::Tensor& PackedLinearWeightFp16::apply_dynamic_impl(
   // Resize output Tensor
   output.resize_(output_sizes);
 
+  auto output_data = output.data_ptr<float>();
+
   int num_tasks = at::get_num_threads();
   at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
     for (const auto task_id : c10::irange(begin, end)) {
@@ -433,7 +435,7 @@ at::Tensor& PackedLinearWeightFp16::apply_dynamic_impl(
           /*A=*/input_ptr,
           /*Bp=*/packed_weight_fp16,
           /*beta=*/0.0f,
-          /*C=*/output.data_ptr<float>(),
+          /*C=*/output_data,
           /*thread_id=*/static_cast<int>(task_id),
           /*num_threads=*/num_tasks);
     }
@@ -5,6 +5,7 @@
 #include <c10/core/alignment.h>
 #include <c10/core/impl/COWDeleter.h>
 #include <c10/util/Exception.h>
+#include <c10/util/ParallelGuard.h>
 #include <c10/util/UniqueVoidPtr.h>
 
 #include <memory>
@@ -109,6 +110,9 @@ c10::intrusive_ptr<StorageImpl> lazy_clone_storage(StorageImpl& storage) {
 }
 
 C10_API void materialize_cow_storage(StorageImpl& storage) {
+  TORCH_INTERNAL_ASSERT(
+      !c10::ParallelGuard::is_enabled(),
+      "Materializing a storage in the loop function of at::parallel_for is forbidden");
   const at::DataPtr& data_ptr = storage.data_ptr();
 
   auto* ctx = data_ptr.cast_context<cow::COWDeleterContext>(cow::cow_deleter);
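For illustration only (the setup assumes at::_lazy_clone is available to produce a copy-on-write alias; the function name is hypothetical): with the assertion above in place, any attempt to materialize a COW storage from inside the loop function now fails, which is exactly the behavior the new _test_parallel_materialize operator and the Python test further down exercise.

#include <ATen/ATen.h>
#include <ATen/Parallel.h>

void cow_materialize_in_loop_demo() {
  at::Tensor a = at::zeros({2, 2});
  at::Tensor b = at::_lazy_clone(a);  // b's storage is copy-on-write
  at::parallel_for(0, 4, /*grain_size=*/1, [&](int64_t, int64_t) {
    // Trips the TORCH_INTERNAL_ASSERT in materialize_cow_storage above.
    b.mutable_data_ptr();
  });
}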
c10/util/ParallelGuard.cpp (new file, 19 lines)
@@ -0,0 +1,19 @@
+#include <c10/util/ParallelGuard.h>
+
+namespace c10 {
+
+thread_local bool in_at_parallel = false;
+
+bool ParallelGuard::is_enabled() {
+  return in_at_parallel;
+}
+
+ParallelGuard::ParallelGuard(bool state) : previous_state_(is_enabled()) {
+  in_at_parallel = state;
+}
+
+ParallelGuard::~ParallelGuard() {
+  in_at_parallel = previous_state_;
+}
+
+} // namespace c10
c10/util/ParallelGuard.h (new file, 20 lines)
@@ -0,0 +1,20 @@
+#pragma once
+
+#include <c10/macros/Macros.h>
+
+namespace c10 {
+
+// RAII thread local guard that tracks whether code is being executed in
+// `at::parallel_for` or `at::parallel_reduce` loop function.
+class C10_API ParallelGuard {
+ public:
+  static bool is_enabled();
+
+  ParallelGuard(bool state);
+  ~ParallelGuard();
+
+ private:
+  bool previous_state_;
+};
+
+} // namespace c10
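A minimal sketch of the guard's save/restore semantics, written against the header above (the nesting scenario is hypothetical): the constructor records the current thread-local value before installing the new one, and the destructor puts the old value back, so guards nest cleanly.

#include <c10/util/ParallelGuard.h>

void nested_guards_demo() {
  // c10::ParallelGuard::is_enabled() == false here (thread-local default)
  c10::ParallelGuard outer(true);
  // is_enabled() == true
  {
    c10::ParallelGuard inner(false);
    // is_enabled() == false for this scope only
  }
  // is_enabled() == true again; back to false once `outer` is destroyed
}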
@@ -59,6 +59,7 @@ ignored_c_binding_in_graph_function_names = {
     "torch._C._data_address",
     "torch._C._is_cow_tensor",
     "torch._lazy_clone",
+    "torch._test_parallel_materialize",
     "torch._C._storage_address",
     "torch._C._pickle_save",
     "torch.cuda._get_device_properties",
@@ -540,6 +540,7 @@ aten::_test_optional_floatlist
 aten::_test_optional_floatlist.out
 aten::_test_optional_intlist
 aten::_test_optional_intlist.out
+aten::_test_parallel_materialize
 aten::_test_warn_in_autograd
 aten::_test_warn_in_autograd.out
 aten::_thnn_fused_gru_cell
@@ -5211,6 +5211,48 @@ else:
             self.assertTrue(torch._C._is_cow_tensor(t))
             self.assertTrue(torch._C._is_cow_tensor(clone))
 
+    # This tests that if a COW materialization is attempted inside an
+    # `at::parallel_for` loop function, then an error is raised. This test is
+    # implemented in Python rather than C++ because the C++ tests are built
+    # without multithreading support in `at::parallel_for`.
+    @skipXLA
+    @skipIfTorchDynamo("Torchdynamo fails and we do not need to test it here anyway")
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
+    def test_parallel_cow_materialize_error(self, device, dtype):
+
+        def run(num_threads, num_parallel, skip_first, should_error):
+            orig_num_threads = torch.get_num_threads()
+
+            try:
+                torch.set_num_threads(num_threads)
+
+                a = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)._lazy_clone()
+
+                if should_error:
+                    with self.assertRaisesRegex(RuntimeError, r'Materializing a storage'):
+                        torch._test_parallel_materialize(
+                            a, num_parallel, skip_first)
+                else:
+                    torch._test_parallel_materialize(a, num_parallel, skip_first)
+
+                # Error should not raise in any case if the tensor is not COW
+                b = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
+                torch._test_parallel_materialize(b, num_parallel, skip_first)
+
+            finally:
+                torch.set_num_threads(orig_num_threads)
+
+        run(1, 1, False, True)
+        run(1, 1, True, False)
+        run(1, 10, False, True)
+        run(1, 10, True, True)
+        run(10, 1, False, True)
+        run(10, 1, True, False)
+        run(10, 10, False, True)
+        run(10, 10, True, True)
+        run(10, 2, False, True)
+        run(10, 2, True, True)
+
     # FIXME: move to test distributions
     @skipIfMps
     @dtypesIfCUDA(torch.float, torch.double, torch.half)
@@ -330,6 +330,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
         aten.isposinf,
         aten.l1_loss,
         aten._lazy_clone,
+        aten._test_parallel_materialize,
         aten.leaky_relu_,
         aten.leaky_relu_backward,
         aten.lerp,