diff --git a/aten/src/ATen/Parallel-inl.h b/aten/src/ATen/Parallel-inl.h
index 62f287fc33c..a5e682281ab 100644
--- a/aten/src/ATen/Parallel-inl.h
+++ b/aten/src/ATen/Parallel-inl.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include
+#include <c10/util/ParallelGuard.h>
 #include
 
 namespace at {
@@ -24,13 +25,19 @@ inline void parallel_for(
        at::get_num_threads() > 1);
   if (!use_parallel) {
     internal::ThreadIdGuard tid_guard(0);
+    c10::ParallelGuard guard(true);
     f(begin, end);
     return;
   }
 
-  internal::invoke_parallel(begin, end, grain_size, f);
+  internal::invoke_parallel(
+      begin, end, grain_size, [&](int64_t begin, int64_t end) {
+        c10::ParallelGuard guard(true);
+        f(begin, end);
+      });
 #else
   internal::ThreadIdGuard tid_guard(0);
+  c10::ParallelGuard guard(true);
   f(begin, end);
 #endif
 }
@@ -56,6 +63,7 @@ inline scalar_t parallel_reduce(
       max_threads > 1);
   if (!use_parallel) {
     internal::ThreadIdGuard tid_guard(0);
+    c10::ParallelGuard guard(true);
     return f(begin, end, ident);
   }
 
@@ -66,6 +74,7 @@ inline scalar_t parallel_reduce(
       grain_size,
       [&](const int64_t my_begin, const int64_t my_end) {
         const auto tid = at::get_thread_num();
+        c10::ParallelGuard guard(true);
         results[tid] = f(my_begin, my_end, ident);
       });
 
@@ -76,6 +85,7 @@ inline scalar_t parallel_reduce(
   return result;
 #else
   internal::ThreadIdGuard tid_guard(0);
+  c10::ParallelGuard guard(true);
   return f(begin, end, ident);
 #endif
 }
diff --git a/aten/src/ATen/native/ConvolutionMM3d.cpp b/aten/src/ATen/native/ConvolutionMM3d.cpp
index c194721acd4..b0a97536f68 100644
--- a/aten/src/ATen/native/ConvolutionMM3d.cpp
+++ b/aten/src/ATen/native/ConvolutionMM3d.cpp
@@ -430,12 +430,12 @@ void slow_conv3d_backward_out_cpu_template(
 
   AT_DISPATCH_FLOATING_TYPES_AND2(
       kBFloat16, kHalf, input.scalar_type(), "slow_conv3d_cpu_grad_input", [&] {
+        auto grad_input_a = grad_input.accessor();
+        auto grad_output_a = grad_output_contiguous.accessor();
+        auto fgrad_input_a = fgrad_input.accessor();
+        auto weight_2d_a = weight2d.accessor();
         at::parallel_for(0, batch_size, CONV3D_GRAIN_SALT,
                          [&](int64_t start, int64_t end) {
-          auto grad_input_a = grad_input.accessor();
-          auto grad_output_a = grad_output_contiguous.accessor();
-          auto fgrad_input_a = fgrad_input.accessor();
-          auto weight_2d_a = weight2d.accessor();
 
           for (const auto t : c10::irange(start, end)) {
             auto grad_input_t = grad_input_a[t];
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 1a995f0d350..55d6f2788d2 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -1835,6 +1835,8 @@ static inline void bmm_out_or_baddbmm_(const Tensor& self_or_result_, const Tens
             r, r, batch1.select(0, b), batch2.select(0, b), 0, 1);
       }
     };
+    // Materialize if COW, since we cannot do so during parallel_for
+    self_or_result.mutable_data_ptr();
     at::parallel_for(0, bs, 1, bmm_out_fn);
   } else {
     for (const auto b : c10::irange(bs)) {
@@ -1851,6 +1853,8 @@ static inline void bmm_out_or_baddbmm_(const Tensor& self_or_result_, const Tens
             batch1.select(0, b), batch2.select(0, b), beta, alpha);
       }
     };
+    // Materialize if COW, since we cannot do so during parallel_for
+    self_or_result.mutable_data_ptr();
     at::parallel_for(0, bs, 1, bmm_fn);
   } else {
     for (const auto b : c10::irange(bs)) {
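
The hunks above show the two complementary patterns this patch applies throughout ATen: `Parallel-inl.h` arms a thread-local `c10::ParallelGuard` around every loop body, and kernels that write to a tensor inside `at::parallel_for` now obtain their mutable pointers or accessors before entering the parallel region. The following standalone sketch is not part of the patch -- the function name, tensor shape, and dtype are invented for illustration -- but it shows the calling pattern a kernel is expected to follow once the guard is in place:

// Hypothetical kernel; assumes `out` is a contiguous 2D float tensor that may
// still be a pending lazy (COW) clone.
#include <ATen/ATen.h>
#include <ATen/Parallel.h>

void double_rows(at::Tensor& out) {
  // Materialize if COW before entering parallel_for: mutable_data_ptr()
  // forces the copy on the calling thread, while no ParallelGuard is active.
  float* out_data = out.mutable_data_ptr<float>();
  const int64_t rows = out.size(0);
  const int64_t cols = out.size(1);
  at::parallel_for(0, rows, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
    // Calling out.mutable_data_ptr() here instead, on a still-lazy tensor,
    // would now hit the TORCH_INTERNAL_ASSERT added in c10/core/impl/COW.cpp.
    for (int64_t r = begin; r < end; ++r) {
      for (int64_t c = 0; c < cols; ++c) {
        out_data[r * cols + c] *= 2.0f;
      }
    }
  });
}
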
diff --git a/aten/src/ATen/native/LossCTC.cpp b/aten/src/ATen/native/LossCTC.cpp
index b6ad40b344b..227b3783260 100644
--- a/aten/src/ATen/native/LossCTC.cpp
+++ b/aten/src/ATen/native/LossCTC.cpp
@@ -260,6 +260,7 @@ Tensor ctc_loss_backward_cpu_template(const Tensor& grad_out, const Tensor& log_
   auto gp = grad.permute({1,0,2});
   auto grad_a_global = gp.accessor();
   auto targets_data = targets.data_ptr();
+  auto grad_out_a = grad_out.accessor();
 
   auto create_fill_iterator = [](const Tensor& tensor, IntArrayRef squash_dims) {
     return TensorIteratorConfig()
@@ -366,7 +367,7 @@ Tensor ctc_loss_backward_cpu_template(const Tensor& grad_out, const Tensor& log_
       // now we wrap up the calculation by adding in the remaining items of eq (16)
       // this could be a great target for further vectorization.
       // grad is the output gradient, nll is the loss. Note that the likelihood -nll is the Z of eq (16)
-      scalar_t gr = grad_out.accessor()[b];
+      scalar_t gr = grad_out_a[b];
       for (const auto t : c10::irange(input_length)) { // or go for the full thing?
         for (const auto c : c10::irange(num_labels)) {
           scalar_t& res = grad_a[t][c];
diff --git a/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp b/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp
index bd4a168552b..567a3754e99 100644
--- a/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp
+++ b/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp
@@ -298,6 +298,9 @@ void slow_conv_transpose2d_out_cpu_template(
   }
   columns.zero_();
 
+  // Materialize if COW, since we cannot do so during parallel_for
+  output.mutable_data_ptr();
+
   AT_DISPATCH_FLOATING_TYPES_AND3(at::ScalarType::Long, at::ScalarType::BFloat16,
       at::ScalarType::Half, input.scalar_type(), "slow_conv_transpose2d_out_cpu",
       [&] {
diff --git a/aten/src/ATen/native/QuantizedLinear.cpp b/aten/src/ATen/native/QuantizedLinear.cpp
index 73fd1c1a941..c2ccdc7ddfe 100644
--- a/aten/src/ATen/native/QuantizedLinear.cpp
+++ b/aten/src/ATen/native/QuantizedLinear.cpp
@@ -125,6 +125,9 @@ Tensor fbgemm_linear_int8_weight_fp32_activation(
   auto& pack_b =
       cpp_custom_type_hack::cast>(packed);
 
+  int32_t* col_offsets_data = col_offsets.data_ptr();
+  float* bias_contig_data = bias_contig.data_ptr();
+
   const int num_tasks = at::get_num_threads();
   at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
     // This operation does the following:
@@ -162,8 +165,8 @@ Tensor fbgemm_linear_int8_weight_fp32_activation(
         /*Aq_zero_point=*/q_params.zero_point,
         /*Bq_zero_point=*/&weight_zero_point_int32,
         /*row_offsets=*/pack_a.getRowOffsetBuffer(),
-        /*col_offsets=*/col_offsets.data_ptr(),
-        /*bias=*/bias_contig.data_ptr(),
+        /*col_offsets=*/col_offsets_data,
+        /*bias=*/bias_contig_data,
         /*nCol=*/N);
     // Do the GEMM
     fbgemm::fbgemmPacked(
diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
index 97b1771aa05..36fde7aecd7 100644
--- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp
+++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -2293,6 +2293,8 @@ Tensor& nonzero_out_cpu(const Tensor& self, Tensor& result) {
     return result;
   }
 
+  auto out_accessor = result.accessor();
+
   // Pass 2: Write indexes
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
       kComplexHalf, kHalf, kBFloat16, kBool, self.scalar_type(), "nonzero_cpu", [&] {
@@ -2313,7 +2315,6 @@ Tensor& nonzero_out_cpu(const Tensor& self, Tensor& result) {
           }
         }
 
-        auto out_accessor = result.accessor();
         auto out_ptr = out_accessor[thread_count_nonzero[tid]].data();
 
         auto loop = [&](char** data, const int64_t* strides, int64_t n1, int64_t n2) {
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index df4de0d14fb..c0d13c4bfdf 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -2066,7 +2066,7 @@ Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& in
           auto* ptr_tid_src_int_idx = src_int_idx.select(0, tid).data_ptr();
           auto* ptr_tid_sorted_int_idx = sorted_int_idx.select(0, tid).data_ptr();
           auto* ptr_tid_int_counts = int_counts.select(0, tid).data_ptr();
-          const auto* ptr_src = src.data_ptr() + start;
+          const auto* ptr_src = src.const_data_ptr() + start;
 
           for (const auto i : c10::irange(start, end)) {
             const auto src_val = *ptr_src++;
@@ -2124,9 +2124,9 @@ Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& in
           const auto start = tid * chunk_size_src;
           const auto end = std::min(start + chunk_size_src, src_len);
           const auto tid_offset = thread_offsets.const_data_ptr()[tid];
-          const auto* ptr_tid_src_int_idx = src_int_idx.select(0, tid).data_ptr();
-          const auto* ptr_tid_sorted_int_idx = sorted_int_idx.select(0, tid).data_ptr();
-          const auto* ptr_tid_int_counts = int_counts.select(0, tid).data_ptr();
+          const auto* ptr_tid_src_int_idx = src_int_idx.select(0, tid).const_data_ptr();
+          const auto* ptr_tid_sorted_int_idx = sorted_int_idx.select(0, tid).const_data_ptr();
+          const auto* ptr_tid_int_counts = int_counts.select(0, tid).const_data_ptr();
           auto* ptr_tid_selected_sorted = ptr_selected_sorted + tid_offset;
           auto* ptr_tid_selected_src = ptr_selected_src + tid_offset;
 
@@ -2324,9 +2324,9 @@ Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& in
           const auto end = std::min(start + chunk_size, src_len);
           auto* ptr_src_tid = ptr_src + start;
           const auto* ptr_src_counts_per_thread
-              = src_counts_per_thread.select(0, tid).data_ptr();
+              = src_counts_per_thread.select(0, tid).const_data_ptr();
           const auto* ptr_src_offset_counts_per_thread
-              = src_offset_counts_per_thread.select(0, tid).data_ptr();
+              = src_offset_counts_per_thread.select(0, tid).const_data_ptr();
           auto tid_counts = at::zeros({size}, src.options());
           auto* ptr_tid_counts = tid_counts.data_ptr();
 
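
The `const_data_ptr()` substitutions above are the read-only counterpart of the same rule: a const pointer never triggers COW materialization, so it is safe to take inside the parallel region. A minimal sketch of the distinction (illustrative only; the function and tensor are assumptions, not taken from the patch):

#include <functional>

#include <ATen/ATen.h>
#include <ATen/Parallel.h>

// Assumes `src` is a contiguous int64 tensor; it may be a lazy (COW) clone.
int64_t sum_elements(const at::Tensor& src) {
  return at::parallel_reduce(
      0, src.numel(), /*grain_size=*/2048, /*ident=*/int64_t{0},
      [&](int64_t begin, int64_t end, int64_t acc) {
        // const_data_ptr() only reads, so the COW storage stays shared and no
        // materialization can happen inside the loop function.
        const int64_t* p = src.const_data_ptr<int64_t>();
        for (int64_t i = begin; i < end; ++i) {
          acc += p[i];
        }
        return acc;
      },
      std::plus<int64_t>());
}
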
diff --git a/aten/src/ATen/native/TestOps.cpp b/aten/src/ATen/native/TestOps.cpp
index 6dd2d1aa555..e2fce123035 100644
--- a/aten/src/ATen/native/TestOps.cpp
+++ b/aten/src/ATen/native/TestOps.cpp
@@ -4,6 +4,7 @@
 #include
 #include
 #include
+#include <ATen/Parallel.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include
@@ -13,6 +14,7 @@
 #include
 #include
 #include
+#include <ATen/ops/_test_parallel_materialize_native.h>
 #include
 #include
 #include
@@ -111,6 +113,21 @@ Tensor _test_check_tensor(const Tensor& self) {
   return self.clone();
 }
 
+Tensor _test_parallel_materialize(const Tensor& self, int64_t num_parallel, bool skip_first) {
+  at::parallel_for(0, num_parallel, 1, [&](int64_t begin, int64_t end){
+    // NOTE: skip_first is meant to avoid triggering the materialization from
+    // the first thread, to ensure that the subthreads throw the error
+    // correctly. On some platforms, the first thread is the main thread and it
+    // begins executing the loop function much earlier than the subthreads.
+    if (skip_first && begin == 0 && end == 1) {
+      return;
+    } else {
+      self.mutable_data_ptr();
+    }
+  });
+  return self;
+}
+
 } // namespace at::native
 
 namespace at::functionalization {
diff --git a/aten/src/ATen/native/cpu/DepthwiseConvKernel.cpp b/aten/src/ATen/native/cpu/DepthwiseConvKernel.cpp
index 3a34ad3f7a6..9a618c143ba 100644
--- a/aten/src/ATen/native/cpu/DepthwiseConvKernel.cpp
+++ b/aten/src/ATen/native/cpu/DepthwiseConvKernel.cpp
@@ -292,16 +292,21 @@ Tensor _convolution_depthwise3x3_winograd(
                               bias_potentially_undefined :
                               at::zeros({kernel_sizes[0]}, input.options());
 
+  auto input_data = input.data_ptr();
+  auto kernel_data = kernel.data_ptr();
+  auto bias_data = bias.data_ptr();
+  auto output_data = output.data_ptr();
+
   at::parallel_for(0, args.batch * args.out_channels, 0, [&](int64_t start, int64_t end) {
     for (const auto k : c10::irange(start, end)) {
       const int64_t g = k % args.out_channels;
       const int64_t i = k / (args.out_channels / groups);
       convolution_depthwise3x3_winograd_impl(
           args,
-          input.data_ptr() + i * input_hxw,
-          kernel.data_ptr() + g * 3 * 3,
-          bias.data_ptr() + g,
-          output.data_ptr() + k * output_hxw);
+          input_data + i * input_hxw,
+          kernel_data + g * 3 * 3,
+          bias_data + g,
+          output_data + k * output_hxw);
     }
   });
 
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 749415479c2..961c706154a 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -14120,6 +14120,12 @@
 # It is undocumented and should not be used outside of tests.
 - func: _test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor
 
+# Note: for testing COW materialization within `at::parallel_for` loop function
+- func: _test_parallel_materialize(Tensor self, int num_parallel, bool skip_first=False) -> Tensor
+  variants: function
+  dispatch:
+    CompositeExplicitAutograd: _test_parallel_materialize
+
 # Note: this function is only for testing.
 - func: _test_optional_intlist(Tensor values, int[]? addends) -> Tensor
   python_module: nn
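
For reference, the newly registered operator can also be driven from C++, not only from the Python test further below. A rough sketch (illustrative; the wrapper function is invented, the two ATen calls come from the patch and the existing `_lazy_clone` op):

#include <ATen/ATen.h>

void reproduce_error() {
  at::Tensor base = at::zeros({2, 2});
  at::Tensor cow = at::_lazy_clone(base);  // storage is now a pending COW copy
  // Each parallel_for task calls self.mutable_data_ptr(); materializing the
  // COW storage while the ParallelGuard is active throws a c10::Error.
  at::_test_parallel_materialize(cow, /*num_parallel=*/8, /*skip_first=*/false);
}
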
diff --git a/aten/src/ATen/native/nested/NestedTensorUtils.h b/aten/src/ATen/native/nested/NestedTensorUtils.h
index 20689987515..22a27cad761 100644
--- a/aten/src/ATen/native/nested/NestedTensorUtils.h
+++ b/aten/src/ATen/native/nested/NestedTensorUtils.h
@@ -367,7 +367,7 @@ inline Tensor wrap_tensor_node(
       if (tensor_node.children(i).numel() > 0) {
         memcpy(
             nt_buffer.mutable_data_ptr() + start_offsets[i],
-            tensor_node.children(i).data_ptr(),
+            tensor_node.children(i).const_data_ptr(),
             tensor_node.children(i).numel() * sizeof(scalar_t));
       }
     }
diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
index 0057fea54c2..7e5083057a0 100644
--- a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
@@ -491,7 +491,7 @@ at::Tensor& embedding_bag_byte_impl(
           /*offsets_or_lengths=*/offsets_data + start_idx,
           /*weights=*/
           per_sample_weights_
-              ? per_sample_weights_.value().data_ptr() +
+              ? per_sample_weights_.value().const_data_ptr() +
                   offsets_data[start_idx]
               : nullptr,
           /*out=*/output_data + start_idx * D);
diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/quantized/cpu/qlinear.cpp
index 728b672dd47..7aa1058b5fd 100644
--- a/aten/src/ATen/native/quantized/cpu/qlinear.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qlinear.cpp
@@ -123,6 +123,8 @@ at::Tensor& PackedLinearWeight::apply_impl(
   // Allocate a buffer for fbgemmPacked to use
   auto buffer = at::empty(out_sizes, output.options().dtype(at::kInt));
 
+  auto output_data = reinterpret_cast(output.data_ptr());
+
   int num_tasks = at::get_num_threads();
   at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
     for (const auto task_id : c10::irange(begin, end)) {
@@ -184,7 +186,7 @@ at::Tensor& PackedLinearWeight::apply_impl(
       fbgemm::fbgemmPacked(
           /*packA=*/packA,
           /*packB=*/*packB,
-          /*C=*/reinterpret_cast(output.data_ptr()),
+          /*C=*/output_data,
           /*C_buffer=*/buffer.data_ptr(),
           /*ldc=*/N,
           /*outProcess=*/outputProcObj,
@@ -220,7 +222,7 @@ at::Tensor& PackedLinearWeight::apply_impl(
       fbgemm::fbgemmPacked(
          /*packA=*/packA,
          /*packB=*/*packB,
-         /*C=*/reinterpret_cast(output.data_ptr()),
+         /*C=*/output_data,
          /*C_buffer=*/buffer.data_ptr(),
          /*ldc=*/N,
          /*outProcess=*/outputProcObj,
@@ -358,6 +360,8 @@ at::Tensor PackedLinearWeight::apply_with_input_q_dq_qweight_dq_output_fp32_impl
       output.options().dtype(at::kInt),
       LEGACY_CONTIGUOUS_MEMORY_FORMAT);
 
+  auto output_data = output.data_ptr();
+
   int num_tasks = at::get_num_threads();
   at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
     fbgemm::PackAWithQuantRowOffset packA(
@@ -396,7 +400,7 @@ at::Tensor PackedLinearWeight::apply_with_input_q_dq_qweight_dq_output_fp32_impl
       fbgemm::fbgemmPacked(
           /*packA=*/packA,
           /*packB=*/*packB,
-          /*C=*/output.data_ptr(),
+          /*C=*/output_data,
           /*C_buffer=*/buffer.data_ptr(),
           /*ldc=*/N,
           /*outProcess=*/outputProcObj,
@@ -428,7 +432,7 @@ at::Tensor PackedLinearWeight::apply_with_input_q_dq_qweight_dq_output_fp32_impl
       fbgemm::fbgemmPacked(
           /*packA=*/packA,
           /*packB=*/*packB,
-          /*C=*/output.data_ptr(),
+          /*C=*/output_data,
           /*C_buffer=*/buffer.data_ptr(),
           /*ldc=*/N,
           /*outProcess=*/outputProcObj,
diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp
index 026fb1f74ee..cd17dbfe1d0 100644
--- a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp
@@ -423,6 +423,8 @@ at::Tensor& PackedLinearWeightFp16::apply_dynamic_impl(
   // Resize output Tensor
   output.resize_(output_sizes);
 
+  auto output_data = output.data_ptr();
+
   int num_tasks = at::get_num_threads();
   at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
     for (const auto task_id : c10::irange(begin, end)) {
@@ -433,7 +435,7 @@ at::Tensor& PackedLinearWeightFp16::apply_dynamic_impl(
           /*A=*/input_ptr,
           /*Bp=*/packed_weight_fp16,
           /*beta=*/0.0f,
-          /*C=*/output.data_ptr(),
+          /*C=*/output_data,
           /*thread_id=*/static_cast(task_id),
           /*num_threads=*/num_tasks);
     }
at::parallel_for is forbidden"); const at::DataPtr& data_ptr = storage.data_ptr(); auto* ctx = data_ptr.cast_context(cow::cow_deleter); diff --git a/c10/util/ParallelGuard.cpp b/c10/util/ParallelGuard.cpp new file mode 100644 index 00000000000..29d1b88dae3 --- /dev/null +++ b/c10/util/ParallelGuard.cpp @@ -0,0 +1,19 @@ +#include + +namespace c10 { + +thread_local bool in_at_parallel = false; + +bool ParallelGuard::is_enabled() { + return in_at_parallel; +} + +ParallelGuard::ParallelGuard(bool state) : previous_state_(is_enabled()) { + in_at_parallel = state; +} + +ParallelGuard::~ParallelGuard() { + in_at_parallel = previous_state_; +} + +} // namespace c10 diff --git a/c10/util/ParallelGuard.h b/c10/util/ParallelGuard.h new file mode 100644 index 00000000000..e28289745ae --- /dev/null +++ b/c10/util/ParallelGuard.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +namespace c10 { + +// RAII thread local guard that tracks whether code is being executed in +// `at::parallel_for` or `at::parallel_reduce` loop function. +class C10_API ParallelGuard { + public: + static bool is_enabled(); + + ParallelGuard(bool state); + ~ParallelGuard(); + + private: + bool previous_state_; +}; + +} // namespace c10 diff --git a/test/dynamo/test_trace_rules.py b/test/dynamo/test_trace_rules.py index 320b035f7ca..c3894752032 100644 --- a/test/dynamo/test_trace_rules.py +++ b/test/dynamo/test_trace_rules.py @@ -59,6 +59,7 @@ ignored_c_binding_in_graph_function_names = { "torch._C._data_address", "torch._C._is_cow_tensor", "torch._lazy_clone", + "torch._test_parallel_materialize", "torch._C._storage_address", "torch._C._pickle_save", "torch.cuda._get_device_properties", diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index bc5d0d0154c..cf607d6fd7a 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -540,6 +540,7 @@ aten::_test_optional_floatlist aten::_test_optional_floatlist.out aten::_test_optional_intlist aten::_test_optional_intlist.out +aten::_test_parallel_materialize aten::_test_warn_in_autograd aten::_test_warn_in_autograd.out aten::_thnn_fused_gru_cell diff --git a/test/test_torch.py b/test/test_torch.py index dbab8da4230..105488720a6 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -5211,6 +5211,48 @@ else: self.assertTrue(torch._C._is_cow_tensor(t)) self.assertTrue(torch._C._is_cow_tensor(clone)) + # This tests that if a COW materialization is attempted inside an + # `at::parallel_for` loop function, then an error is raised. This test is + # implemented in Python rather than C++ because the C++ tests are built + # without multithreading support in `at::parallel_for`. 
diff --git a/test/test_torch.py b/test/test_torch.py
index dbab8da4230..105488720a6 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -5211,6 +5211,48 @@ else:
         self.assertTrue(torch._C._is_cow_tensor(t))
         self.assertTrue(torch._C._is_cow_tensor(clone))
 
+    # This tests that if a COW materialization is attempted inside an
+    # `at::parallel_for` loop function, then an error is raised. This test is
+    # implemented in Python rather than C++ because the C++ tests are built
+    # without multithreading support in `at::parallel_for`.
+    @skipXLA
+    @skipIfTorchDynamo("Torchdynamo fails and we do not need to test it here anyway")
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
+    def test_parallel_cow_materialize_error(self, device, dtype):
+
+        def run(num_threads, num_parallel, skip_first, should_error):
+            orig_num_threads = torch.get_num_threads()
+
+            try:
+                torch.set_num_threads(num_threads)
+
+                a = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)._lazy_clone()
+
+                if should_error:
+                    with self.assertRaisesRegex(RuntimeError, r'Materializing a storage'):
+                        torch._test_parallel_materialize(
+                            a, num_parallel, skip_first)
+                else:
+                    torch._test_parallel_materialize(a, num_parallel, skip_first)
+
+                # Error should not raise in any case if the tensor is not COW
+                b = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype)
+                torch._test_parallel_materialize(b, num_parallel, skip_first)
+
+            finally:
+                torch.set_num_threads(orig_num_threads)
+
+        run(1, 1, False, True)
+        run(1, 1, True, False)
+        run(1, 10, False, True)
+        run(1, 10, True, True)
+        run(10, 1, False, True)
+        run(10, 1, True, False)
+        run(10, 10, False, True)
+        run(10, 10, True, True)
+        run(10, 2, False, True)
+        run(10, 2, True, True)
+
     # FIXME: move to test distributions
     @skipIfMps
     @dtypesIfCUDA(torch.float, torch.double, torch.half)
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index 2a4e9157b34..9146f40dcce 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -330,6 +330,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
         aten.isposinf,
         aten.l1_loss,
         aten._lazy_clone,
+        aten._test_parallel_materialize,
         aten.leaky_relu_,
         aten.leaky_relu_backward,
         aten.lerp,