Revert D34011981: [pytorch][PR] remove some spurious warnings fixing
Test Plan: revert-hammer
Differential Revision: D34011981 (1bad3c4a84)
Original commit changeset: 55bedc8a4092
Original Phabricator Diff: D34011981 (1bad3c4a84)
fbshipit-source-id: 216643e251597cd7086e7854426f4f189a77adc9
(cherry picked from commit bb39550500)
This commit is contained in:
parent d50211860a
commit 5654b68731
@@ -13,22 +13,9 @@ namespace utils {
 inline bool obeys_layout_contract(const at::Tensor& grad, const at::Tensor& variable) {
   TORCH_INTERNAL_ASSERT(!grad.is_sparse());
   TORCH_INTERNAL_ASSERT(!variable.is_sparse());
-  if (variable.is_non_overlapping_and_dense()) {
-    // Only look at stride for dimensions that are not of size 1
-    const auto& grad_sizes = grad.sizes();
-    const auto& grad_strides = grad.strides();
-    const auto& variable_strides = variable.strides();
-    for (const auto idx : c10::irange(grad_sizes.size())) {
-      if (grad_sizes[idx] != 1) {
-        if (grad_strides[idx] != variable_strides[idx]) {
-          return false;
-        }
-      }
-    }
-    return true;
-  } else {
-    return grad.is_contiguous(at::MemoryFormat::Contiguous);
-  }
+  return variable.is_non_overlapping_and_dense() ?
+      (grad.strides() == variable.strides()) :
+      grad.is_contiguous(at::MemoryFormat::Contiguous);
 }

 // Creates a clone of new_grad that obeys the contract with variable.
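For context, here is a minimal standalone sketch contrasting the two stride checks in the hunk above. It does not use the at::Tensor API; the plain vectors, the example sizes and strides, and the helper name matches_ignoring_size1_dims are illustrative assumptions. The reverted loop ignores strides of size-1 dimensions, while the restored ternary requires the stride vectors to match exactly.

// Standalone sketch, not PyTorch code: compares the two layout checks
// on hand-written size/stride vectors.
#include <cstdint>
#include <iostream>
#include <vector>

// Reverted behavior: strides must agree only where the dimension size is not 1.
bool matches_ignoring_size1_dims(const std::vector<int64_t>& sizes,
                                 const std::vector<int64_t>& grad_strides,
                                 const std::vector<int64_t>& variable_strides) {
  for (size_t idx = 0; idx < sizes.size(); ++idx) {
    if (sizes[idx] != 1 && grad_strides[idx] != variable_strides[idx]) {
      return false;
    }
  }
  return true;
}

int main() {
  // A size-1 middle dimension: its stride is arbitrary, so these two stride
  // vectors describe the same memory layout even though they are not equal.
  std::vector<int64_t> sizes{2, 1, 3};
  std::vector<int64_t> grad_strides{3, 3, 1};
  std::vector<int64_t> variable_strides{3, 1, 1};

  std::cout << std::boolalpha
            << "exact stride equality (restored check): "
            << (grad_strides == variable_strides) << '\n'
            << "ignoring size-1 dims (reverted check):  "
            << matches_ignoring_size1_dims(sizes, grad_strides, variable_strides)
            << '\n';
  return 0;
}

On this layout the program prints false for the exact-equality check and true for the size-1-aware check, which is the behavioral difference the revert restores.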