mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
edge_order check in torch.gradient only applies to dim argument (#67926)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/67919
The compatibility check on `edge_order` in `pre_check_gradient` now looks only at dim argument if it is present, otherwise it checks all dimensions.
Previously, it would check all dimensions regardless of the dim argument and throw unnecessary errors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67926
Reviewed By: albanD
Differential Revision: D33760621
Pulled By: mruberry
fbshipit-source-id: d490cd8610c68ff3787e670fc947de3cbf2db062
(cherry picked from commit 45bc56de9e)
This commit is contained in:
parent
f3e81f3eed
commit
33403f4848
|
|
@ -841,13 +841,17 @@ void pre_check_gradient(const Tensor& self, c10::optional<int64_t> spacing_size,
|
|||
"torch.gradient expected spacing to be unspecified, a scalar or it's spacing and dim arguments to have the same length, but got a spacing argument of length ", spacing_size.value(), " and a dim argument of length ", dim.value().size(), "." );
|
||||
}
|
||||
TORCH_CHECK(edge_order == 1 || edge_order == 2, "torch.gradient only supports edge_order=1 and edge_order=2.");
|
||||
for (const auto i : c10::irange(self.dim())) {
|
||||
TORCH_CHECK(self.size(i) >= edge_order + 1, "torch.gradient expected each dimension size to be at least edge_order+1");
|
||||
}
|
||||
if (dim.has_value()) {
|
||||
// The following function get called to check whether dim argument satisfies prerequisites.
|
||||
// The output of the function is not used for the computation of gradient.
|
||||
dim_list_to_bitset(dim.value(), self.dim());
|
||||
for (const auto i : c10::irange(dim.value().size())) {
|
||||
TORCH_CHECK(self.size(dim.value()[i]) >= edge_order + 1, "torch.gradient expected each dimension size to be at least edge_order+1");
|
||||
}
|
||||
} else {
|
||||
for (const auto i : c10::irange(self.dim())) {
|
||||
TORCH_CHECK(self.size(i) >= edge_order + 1, "torch.gradient expected each dimension size to be at least edge_order+1");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2463,6 +2463,8 @@ else:
|
|||
((4, 5, 3, 4, 3), (1, 2)),
|
||||
((4, 3, 6, 5, 3), (2, 4)),
|
||||
((4, 3, 3, 5, 3), (0, 1, 2, 3, 4)),
|
||||
((1, 3, 3), (1, 2)),
|
||||
((1, 5), (1,)),
|
||||
)
|
||||
|
||||
for case, contig, edge_order, space_fn in product(test_cases, [True, False], [1, 2],
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user