Revert "Revert D31227448: [pytorch][PR] fixing sorting in stride indices" (#66176)
Summary: enabling https://github.com/pytorch/pytorch/issues/63940

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66176

Reviewed By: ngimel
Differential Revision: D31423920
Pulled By: dzhulgakov

fbshipit-source-id: 06b1e0f757f4fb5b31ee1fa464bcd689df919b9c
This commit is contained in:
parent 74477ba243
commit 321345d7c9
@@ -10,6 +10,26 @@
 namespace c10 {

+namespace {
+
+inline bool is_contiguous_strides(
+    const IntArrayRef sizes,
+    const IntArrayRef strides) {
+  int n_dim = static_cast<int>(sizes.size());
+
+  if (n_dim == 0 || strides[n_dim-1] != 1) {
+    return false;
+  }
+
+  for (int i = n_dim - 2; i >= 0; i--) {
+    if (strides[i] != strides[i+1] * sizes[i+1]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+} // namespace
+
 TypeVerbosity type_verbosity() {
   static const char* c_verbosity = std::getenv("PYTORCH_JIT_TYPE_VERBOSITY");
   static TypeVerbosity verbosity = c_verbosity ?
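For reference, the new helper accepts only dense row-major layouts: the innermost stride must be 1, and every outer stride must equal the next stride times the next size. A minimal standalone sketch of the same check (hypothetical driver; std::vector in place of IntArrayRef so it builds without ATen):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Same logic as the is_contiguous_strides helper above, restated over
    // std::vector so it runs outside the patch.
    bool is_contiguous(const std::vector<int64_t>& sizes,
                       const std::vector<int64_t>& strides) {
      int n_dim = static_cast<int>(sizes.size());
      if (n_dim == 0 || strides[n_dim - 1] != 1) {
        return false;
      }
      for (int i = n_dim - 2; i >= 0; i--) {
        if (strides[i] != strides[i + 1] * sizes[i + 1]) {
          return false;
        }
      }
      return true;
    }

    int main() {
      // Row-major 8x10x16 tensor: each stride is the product of the sizes
      // to its right, so [160, 16, 1] passes.
      printf("%d\n", is_contiguous({8, 10, 16}, {160, 16, 1})); // 1
      // Permuted layout: innermost stride is not 1, rejected immediately.
      printf("%d\n", is_contiguous({8, 10, 16}, {160, 1, 16})); // 0
      return 0;
    }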
@@ -1407,21 +1427,67 @@ VaryingShape<Stride> TensorType::computeStrideProps(
     at::IntArrayRef sizes,
     at::IntArrayRef strides,
     bool tensor_contiguity) {
-  std::vector<size_t> stride_indices(sizes.size());
+  int n_dim = static_cast<int>(sizes.size());
+  std::vector<size_t> stride_indices(n_dim);
+
+  // Sorting strides in ascending order
+  // Example:
+  //  Prior to sorting
+  //  Idx:     [0,   1,  2,  3]
+  //  sizes:   [8,   1, 10, 16]
+  //  Strides: [160, 1, 16,  1]
+  //  After sorting
+  //  Idx:     [1,  3,  2,   0]
+  //  sizes:   [1, 16, 10,   8]
+  //  Strides: [1,  1, 16, 160]
+  //
+  // The logic below follows what TensorIterator uses in its logic:
+  //   1. Fast_set_up is the short-cut to identify a. channels_last and
+  //      b. contiguous format, which is what we have in the below logic.
+  //   2. In more general cases, it does best effort to preserve permutation.
+  if (is_channels_last_strides_2d(sizes, strides) || is_channels_last_strides_3d(sizes, strides)) {
+    // case 1.a. short cut channels last
+    std::iota(stride_indices.rbegin() + 1, stride_indices.rend() - 1, 2);
+    stride_indices[0] = 1;
+    stride_indices[n_dim - 1] = 0;
+  } else if (is_contiguous_strides(sizes, strides)) {
+    // case 1.b. short cut contiguous
+    std::iota(stride_indices.rbegin(), stride_indices.rend(), 0);
+  } else {
     std::iota(stride_indices.begin(), stride_indices.end(), 0);
-  std::sort(
-      stride_indices.begin(),
-      stride_indices.end(),
-      [&strides](const int& a, const int& b) {
-        // break ties in case of unsqueezed dims
-        // i.e. (1, 1, 5)
-        if (strides[a] == strides[b]) {
-          return a > b;
-        }
-        return strides[a] < strides[b];
-      });
+    // case 2.
+    //
+    // For broadcasted dimension where stride is 0, we have to stick to
+    // TensorIterator behavior in eager, where they introduce an ambiguous
+    // comparison result to preserve permutation by best effort.
+    // For more details, see NOTE: [Computing output strides]
+    auto should_swap = [&](size_t a, size_t b) {
+      if (strides[a] == 0 || strides[b] == 0) {
+        return 0;
+      } else if (strides[a] < strides[b]) {
+        return -1;
+      } else if (strides[a] > strides[b]) {
+        return 1;
+      } else { // strides[a] == strides[b]
+        if (sizes[a] < sizes[b] || a > b) {
+          return 1;
+        }
+      }
+      return 0;
+    };
+    for (int i = 1; i < n_dim; i++) {
+      int dim1 = i;
+      for (int dim0 = i - 1; dim0 >= 0; dim0--) {
+        int comparison = should_swap(stride_indices[dim0], stride_indices[dim1]);
+        if (comparison > 0) {
+          std::swap(stride_indices[dim0], stride_indices[dim1]);
+          dim1 = dim0;
+        } else if (comparison < 0) {
+          break;
+        }
+      }
+    }
+  }

   std::vector<Stride> stride_properties;
   for (size_t i = 0; i < stride_indices.size(); i++) {
     bool contiguous_ = tensor_contiguity;
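The comparator deliberately returns 0 ("no opinion") when either stride is 0, so a broadcast dimension keeps its original position instead of being sorted as the smallest stride. Only a stable insertion sort can honor that, which is why the patch replaces std::sort, whose comparator must be a strict weak ordering and cannot express an ambiguous result. A minimal standalone sketch of this behavior, restated over plain vectors with hypothetical names so it compiles without ATen:

    #include <cstdint>
    #include <cstdio>
    #include <numeric>
    #include <vector>

    // Mirror of the patch's should_swap + insertion sort.
    std::vector<size_t> sort_stride_indices(const std::vector<int64_t>& sizes,
                                            const std::vector<int64_t>& strides) {
      int n_dim = static_cast<int>(sizes.size());
      std::vector<size_t> stride_indices(n_dim);
      std::iota(stride_indices.begin(), stride_indices.end(), 0);
      auto should_swap = [&](size_t a, size_t b) {
        if (strides[a] == 0 || strides[b] == 0) return 0; // ambiguous: keep order
        if (strides[a] < strides[b]) return -1;
        if (strides[a] > strides[b]) return 1;
        if (sizes[a] < sizes[b] || a > b) return 1;       // tie-break equal strides
        return 0;
      };
      // Insertion sort: a 0 comparison leaves the pair where it is.
      for (int i = 1; i < n_dim; i++) {
        int dim1 = i;
        for (int dim0 = i - 1; dim0 >= 0; dim0--) {
          int comparison = should_swap(stride_indices[dim0], stride_indices[dim1]);
          if (comparison > 0) {
            std::swap(stride_indices[dim0], stride_indices[dim1]);
            dim1 = dim0;
          } else if (comparison < 0) {
            break;
          }
        }
      }
      return stride_indices;
    }

    int main() {
      // Distinct strides sort ascending: prints "2 0 1".
      for (size_t i : sort_stride_indices({10, 8, 16}, {16, 160, 1})) printf("%zu ", i);
      printf("\n");
      // A 0-stride (broadcast) dim keeps its place: prints "0 1", not "1 0".
      for (size_t i : sort_stride_indices({4, 3}, {1, 0})) printf("%zu ", i);
      printf("\n");
      return 0;
    }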
@@ -42,7 +42,8 @@ list(APPEND ATen_CPU_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/ivalue_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/vmap_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/type_test.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/dispatch_key_set_test.cpp)
+  ${CMAKE_CURRENT_SOURCE_DIR}/dispatch_key_set_test.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/stride_properties_test.cpp)

 list(APPEND ATen_CUDA_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_atomic_ops_test.cu
aten/src/ATen/test/stride_properties_test.cpp (new file, 69 lines)

@@ -0,0 +1,69 @@
+#include <gtest/gtest.h>
+
+#include <ATen/ATen.h>
+
+using namespace at;
+
+// TODO: failing sizes {4, 1, 4, 1}
+std::vector<std::vector<int64_t>> sizes = {{4, 4, 4, 4}, {4, 4, 1, 1}, {4, 1, 4, 4}, {4, 1, 1, 4}, {1, 4, 1, 4}, {1, 4, 4, 1}};
+
+inline bool CheckStrideIndices(const Tensor& t, at::MemoryFormat format) {
+  size_t n_dim = t.dim();
+  std::vector<size_t> stride_indices(n_dim);
+  if (format == at::MemoryFormat::ChannelsLast) {
+    // stride_indices_ should be {1, n-1, n-2, ..., 2, 0}
+    std::iota(stride_indices.rbegin() + 1, stride_indices.rend() - 1, 2);
+    stride_indices[0] = 1;
+    stride_indices[n_dim - 1] = 0;
+  } else if (format == at::MemoryFormat::Contiguous) {
+    // stride_indices_ should be {n-1, n-2, n-3, ..., 0}
+    std::iota(stride_indices.rbegin(), stride_indices.rend(), 0);
+  } else {
+    TORCH_INTERNAL_ASSERT(false, "not recognized memory format");
+  }
+
+  // testing computeStrideProps with `IValue ival(t)` somehow doesn't work on CI
+  // with onnx; the function works fine in isolation, but stride properties are
+  // somehow altered in ival->type()->cast<TensorType>();
+  auto tt = TensorType::create(c10::nullopt, c10::nullopt, t.sizes(), t.strides(), c10::nullopt);
+  TORCH_INTERNAL_ASSERT(tt->stride_properties().isComplete(), "complete stride properties are needed for the test");
+
+  auto index_iter = stride_indices.begin();
+  for (const auto& opt_stride : *tt->stride_properties().sizes()) {
+    if (*index_iter++ != opt_stride->stride_index_.value()) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+TEST(StridePropertiesTest, StrideIndicesTest) {
+  // NOLINTNEXTLINE(performance-for-range-copy)
+  for (const auto& size : sizes) {
+    Tensor t = at::rand(size);
+    for (auto memory_format : {at::MemoryFormat::ChannelsLast, at::MemoryFormat::Contiguous}) {
+      t.resize_(size, memory_format);
+      EXPECT_TRUE(CheckStrideIndices(t, memory_format));
+    }
+  }
+}
+
+TEST(StridePropertiesTest, ZeroStrideIndicesEagerConsistencyTest) {
+  auto permuted_tensor = at::rand({6, 3, 1, 5, 2}).permute({0, 3, 2, 1, 4}); // permute dim-1 & dim-3
+  auto tensor = permuted_tensor.expand({6, 5, 4, 3, 2}); // expand dim-2
+
+  auto temp = TensorType::create(c10::nullopt, c10::nullopt, tensor.sizes(), tensor.strides(), c10::nullopt);
+
+  // TensorIterator preserves stride order; this is the eager reference
+  auto eager_tensor = tensor.relu();
+  auto ref_type = TensorType::create(c10::nullopt, c10::nullopt, eager_tensor.sizes(), eager_tensor.strides(), c10::nullopt);
+
+  TORCH_INTERNAL_ASSERT(temp->stride_properties().isComplete() &&
+      ref_type->stride_properties().isComplete(), "complete stride properties are needed for the test");
+  auto ref_iter = (*(ref_type->stride_properties().sizes())).begin();
+  for (const auto& opt_stride : *temp->stride_properties().sizes()) {
+    EXPECT_TRUE(opt_stride->stride_index_.value() == (*ref_iter)->stride_index_.value());
+    ref_iter++;
+  }
+}
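The second test relies on eager ops routed through TensorIterator picking output strides that preserve the input permutation by best effort, even across the 0-stride dim created by expand(). A hypothetical snippet (not part of the patch) to observe that behavior from ATen directly:

    #include <ATen/ATen.h>
    #include <iostream>

    int main() {
      // permute, then expand dim-2 from 1 to 4: the expanded dim gets stride 0.
      auto t = at::rand({6, 3, 1, 5, 2}).permute({0, 3, 2, 1, 4}).expand({6, 5, 4, 3, 2});
      std::cout << "input strides:  " << t.strides() << "\n";
      // relu() allocates a fresh output through TensorIterator; its stride
      // order should mirror the input permutation, which is what
      // ZeroStrideIndicesEagerConsistencyTest asserts via computeStrideProps.
      std::cout << "output strides: " << t.relu().strides() << "\n";
      return 0;
    }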
@@ -124,8 +124,10 @@ bool complyWith(
     if (j != 0 && inner_dim != -1) {
       // we are not looking at dim-j, but dim-sorted_index, which
       // is the j-th fastest dim;
-      // TODO: merge this with above and put a long comment there
-      if (t_strides[sorted_index] < t_strides[inner_dim]) {
+      // Note: we ignore 0-stride dimensions, since eager logic on stride
+      // indices is ambiguous for them
+      if (t_strides[sorted_index] != 0 && t_strides[inner_dim] != 0 &&
+          t_strides[sorted_index] < t_strides[inner_dim]) {
         return false;
       }
     }
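The complyWith change applies the same rule on the consumer side: a 0-stride dimension is skipped rather than reported as an ordering violation. A tiny sketch of the patched predicate (names here are illustrative, not from the patch):

    #include <cstdint>
    #include <cstdio>

    // True only when both dims have real (non-zero) strides and the supposedly
    // faster dim is in fact slower; 0-stride broadcast dims are never flagged.
    bool ordering_violated(int64_t stride_sorted, int64_t stride_inner) {
      return stride_sorted != 0 && stride_inner != 0 && stride_sorted < stride_inner;
    }

    int main() {
      printf("%d\n", ordering_violated(16, 160)); // 1: genuinely out of order
      printf("%d\n", ordering_violated(0, 160));  // 0: broadcast dim is skipped
      return 0;
    }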