Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
remove some spurious warnings fixing (#72352)
Summary: Fixes https://github.com/pytorch/pytorch/issues/70389
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72352
Reviewed By: jbschlosser
Differential Revision: D34011981
Pulled By: albanD
fbshipit-source-id: 55bedc8a40929bc5b49cb6d7d7d51a3750f2ff27
This commit is contained in:
parent 95cc102bc2
commit a6657a9071
@@ -13,9 +13,22 @@ namespace utils {
 inline bool obeys_layout_contract(const at::Tensor& grad, const at::Tensor& variable) {
   TORCH_INTERNAL_ASSERT(!grad.is_sparse());
   TORCH_INTERNAL_ASSERT(!variable.is_sparse());
-  return variable.is_non_overlapping_and_dense() ?
-      (grad.strides() == variable.strides()) :
-      grad.is_contiguous(at::MemoryFormat::Contiguous);
+  if (variable.is_non_overlapping_and_dense()) {
+    // Only look at stride for dimensions that are not of size 1
+    const auto& grad_sizes = grad.sizes();
+    const auto& grad_strides = grad.strides();
+    const auto& variable_strides = variable.strides();
+    for (const auto idx : c10::irange(grad_sizes.size())) {
+      if (grad_sizes[idx] != 1) {
+        if (grad_strides[idx] != variable_strides[idx]) {
+          return false;
+        }
+      }
+    }
+    return true;
+  } else {
+    return grad.is_contiguous(at::MemoryFormat::Contiguous);
+  }
 }
 
 // Creates a clone of new_grad that obeys the contract with variable.
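For context, a minimal standalone sketch of the idea behind the change (not the PyTorch helper itself; the function name strides_match_ignoring_size_one and the example sizes and strides are made up for illustration): the stride of a size-1 dimension never affects which memory locations a tensor addresses, so two tensors with the same layout may legitimately report different strides on such dimensions. Comparing strides per dimension while skipping size-1 dimensions is therefore more permissive than the old exact strides() comparison.

// Standalone illustration; assumptions: plain std::vector stand-ins for
// sizes()/strides(), example values chosen for demonstration only.
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the relaxed comparison from the diff above: strides only have to
// match on dimensions whose size is not 1.
bool strides_match_ignoring_size_one(
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& grad_strides,
    const std::vector<int64_t>& variable_strides) {
  for (size_t idx = 0; idx < sizes.size(); ++idx) {
    if (sizes[idx] != 1 && grad_strides[idx] != variable_strides[idx]) {
      return false;
    }
  }
  return true;
}

int main() {
  // Shape (3, 1, 4): the middle dimension has size 1, so its stride is
  // irrelevant to the memory layout.
  std::vector<int64_t> sizes{3, 1, 4};
  std::vector<int64_t> grad_strides{4, 4, 1};      // canonical contiguous strides
  std::vector<int64_t> variable_strides{4, 1, 1};  // same layout, different size-1 stride

  std::cout << "exact match: " << (grad_strides == variable_strides) << "\n";       // 0
  std::cout << "relaxed match: "
            << strides_match_ignoring_size_one(sizes, grad_strides, variable_strides)
            << "\n";                                                                // 1
  return 0;
}

Under the old exact comparison, a pair like the one above would have been flagged as violating the gradient layout contract even though both stride tuples describe the same contiguous buffer; the per-dimension check accepts it, which is how this change removes the spurious warnings.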