edge_order check in torch.gradient only applies to dim argument (#67926)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/67919

The compatibility check on `edge_order` in `pre_check_gradient` now looks only at dim argument if it is present, otherwise it checks all dimensions.

Previously, it would check all dimensions regardless of the dim argument and throw unnecessary errors.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67926

Reviewed By: albanD

Differential Revision: D33760621

Pulled By: mruberry

fbshipit-source-id: d490cd8610c68ff3787e670fc947de3cbf2db062
(cherry picked from commit 45bc56de9e)
This commit is contained in:
Jonathan Colen 2022-01-25 12:45:58 -08:00 committed by PyTorch MergeBot
parent f3e81f3eed
commit 33403f4848
2 changed files with 9 additions and 3 deletions

View File

@ -841,13 +841,17 @@ void pre_check_gradient(const Tensor& self, c10::optional<int64_t> spacing_size,
"torch.gradient expected spacing to be unspecified, a scalar or it's spacing and dim arguments to have the same length, but got a spacing argument of length ", spacing_size.value(), " and a dim argument of length ", dim.value().size(), "." );
}
TORCH_CHECK(edge_order == 1 || edge_order == 2, "torch.gradient only supports edge_order=1 and edge_order=2.");
for (const auto i : c10::irange(self.dim())) {
TORCH_CHECK(self.size(i) >= edge_order + 1, "torch.gradient expected each dimension size to be at least edge_order+1");
}
if (dim.has_value()) {
// The following function get called to check whether dim argument satisfies prerequisites.
// The output of the function is not used for the computation of gradient.
dim_list_to_bitset(dim.value(), self.dim());
for (const auto i : c10::irange(dim.value().size())) {
TORCH_CHECK(self.size(dim.value()[i]) >= edge_order + 1, "torch.gradient expected each dimension size to be at least edge_order+1");
}
} else {
for (const auto i : c10::irange(self.dim())) {
TORCH_CHECK(self.size(i) >= edge_order + 1, "torch.gradient expected each dimension size to be at least edge_order+1");
}
}
}

View File

@ -2463,6 +2463,8 @@ else:
((4, 5, 3, 4, 3), (1, 2)),
((4, 3, 6, 5, 3), (2, 4)),
((4, 3, 3, 5, 3), (0, 1, 2, 3, 4)),
((1, 3, 3), (1, 2)),
((1, 5), (1,)),
)
for case, contig, edge_order, space_fn in product(test_cases, [True, False], [1, 2],