From 33403f484874dba95cae2a91736a9fd04cd70eba Mon Sep 17 00:00:00 2001 From: Jonathan Colen Date: Tue, 25 Jan 2022 12:45:58 -0800 Subject: [PATCH] edge_order check in torch.gradient only applies to dim argument (#67926) Summary: Fixes https://github.com/pytorch/pytorch/issues/67919 The compatibility check on `edge_order` in `pre_check_gradient` now looks only at dim argument if it is present, otherwise it checks all dimensions. Previously, it would check all dimensions regardless of the dim argument and throw unnecessary errors. Pull Request resolved: https://github.com/pytorch/pytorch/pull/67926 Reviewed By: albanD Differential Revision: D33760621 Pulled By: mruberry fbshipit-source-id: d490cd8610c68ff3787e670fc947de3cbf2db062 (cherry picked from commit 45bc56de9e287f715186378682e22bc6ac7a6ccc) --- aten/src/ATen/native/ReduceOps.cpp | 10 +++++++--- test/test_torch.py | 2 ++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index b128e4b1b9d..38eafedbeeb 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -841,13 +841,17 @@ void pre_check_gradient(const Tensor& self, c10::optional spacing_size, "torch.gradient expected spacing to be unspecified, a scalar or it's spacing and dim arguments to have the same length, but got a spacing argument of length ", spacing_size.value(), " and a dim argument of length ", dim.value().size(), "." ); } TORCH_CHECK(edge_order == 1 || edge_order == 2, "torch.gradient only supports edge_order=1 and edge_order=2."); - for (const auto i : c10::irange(self.dim())) { - TORCH_CHECK(self.size(i) >= edge_order + 1, "torch.gradient expected each dimension size to be at least edge_order+1"); - } if (dim.has_value()) { // The following function get called to check whether dim argument satisfies prerequisites. // The output of the function is not used for the computation of gradient. dim_list_to_bitset(dim.value(), self.dim()); + for (const auto i : c10::irange(dim.value().size())) { + TORCH_CHECK(self.size(dim.value()[i]) >= edge_order + 1, "torch.gradient expected each dimension size to be at least edge_order+1"); + } + } else { + for (const auto i : c10::irange(self.dim())) { + TORCH_CHECK(self.size(i) >= edge_order + 1, "torch.gradient expected each dimension size to be at least edge_order+1"); + } } } diff --git a/test/test_torch.py b/test/test_torch.py index 28e03ccc78a..5c66d5bc42c 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -2463,6 +2463,8 @@ else: ((4, 5, 3, 4, 3), (1, 2)), ((4, 3, 6, 5, 3), (2, 4)), ((4, 3, 3, 5, 3), (0, 1, 2, 3, 4)), + ((1, 3, 3), (1, 2)), + ((1, 5), (1,)), ) for case, contig, edge_order, space_fn in product(test_cases, [True, False], [1, 2],