Revert "Revert D31227448: [pytorch][PR] fixing sorting in stride indices" (#66176)

Summary:
enabling https://github.com/pytorch/pytorch/issues/63940

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66176

Reviewed By: ngimel

Differential Revision: D31423920

Pulled By: dzhulgakov

fbshipit-source-id: 06b1e0f757f4fb5b31ee1fa464bcd689df919b9c
Authored by jiej on 2021-10-07 22:05:35 -07:00; committed by Facebook GitHub Bot
parent 74477ba243
commit 321345d7c9
4 changed files with 154 additions and 16 deletions

@@ -10,6 +10,26 @@
 namespace c10 {
 
+namespace {
+
+inline bool is_contiguous_strides(
+    const IntArrayRef sizes,
+    const IntArrayRef strides) {
+  int n_dim = static_cast<int>(sizes.size());
+  if (n_dim == 0 || strides[n_dim - 1] != 1) {
+    return false;
+  }
+  for (int i = n_dim - 2; i >= 0; i--) {
+    if (strides[i] != strides[i + 1] * sizes[i + 1]) {
+      return false;
+    }
+  }
+  return true;
+}
+
+} // namespace
+
 TypeVerbosity type_verbosity() {
   static const char* c_verbosity = std::getenv("PYTORCH_JIT_TYPE_VERBOSITY");
   static TypeVerbosity verbosity = c_verbosity ?
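
For intuition, here is a minimal standalone sketch of the contiguity check added above, with std::vector<int64_t> standing in for c10::IntArrayRef so it compiles without ATen; the main driver and example shapes are illustrative, not part of the commit.

#include <cstdint>
#include <iostream>
#include <vector>

// Same logic as the is_contiguous_strides helper in the hunk above:
// the innermost stride must be 1, and every outer stride must equal
// the next (inner) stride times the next (inner) size.
bool is_contiguous_strides(
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& strides) {
  int n_dim = static_cast<int>(sizes.size());
  if (n_dim == 0 || strides[n_dim - 1] != 1) {
    return false;
  }
  for (int i = n_dim - 2; i >= 0; i--) {
    if (strides[i] != strides[i + 1] * sizes[i + 1]) {
      return false;
    }
  }
  return true;
}

int main() {
  // A contiguous 2x3x4 tensor has strides {3*4, 4, 1} = {12, 4, 1}.
  std::cout << is_contiguous_strides({2, 3, 4}, {12, 4, 1}) << "\n"; // 1
  // The same sizes with permuted strides are not contiguous.
  std::cout << is_contiguous_strides({2, 3, 4}, {12, 1, 3}) << "\n"; // 0
}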
@@ -1407,21 +1427,67 @@ VaryingShape<Stride> TensorType::computeStrideProps(
     at::IntArrayRef sizes,
     at::IntArrayRef strides,
     bool tensor_contiguity) {
-  std::vector<size_t> stride_indices(sizes.size());
-  std::iota(stride_indices.begin(), stride_indices.end(), 0);
-
-  std::sort(
-      stride_indices.begin(),
-      stride_indices.end(),
-      [&strides](const int& a, const int& b) {
-        // break ties in case of unsqueezed dims
-        // i.e. (1, 1, 5)
-        if (strides[a] == strides[b]) {
-          return a > b;
-        }
-        return strides[a] < strides[b];
-      });
+  int n_dim = static_cast<int>(sizes.size());
+  std::vector<size_t> stride_indices(n_dim);
+  // Sorting strides in ascending order
+  // Example:
+  //  Prior to sorting
+  //  Idx:     [0,   1,  2,  3]
+  //  sizes:   [8,   1, 10, 16]
+  //  Strides: [160, 1, 16,  1]
+  //  After sorting
+  //  Idx:     [1,  3,  2,   0]
+  //  sizes:   [1, 16, 10,   8]
+  //  Strides: [1,  1, 16, 160]
+  //
+  // The logic below follows what TensorIterator uses in its logic:
+  //   1. Fast_set_up is the short-cut to identify a. channels_last and
+  //      b. contiguous format, which is what we have in the below logic.
+  //   2. In more general cases, it does best effort to preserve permutation.
+  if (is_channels_last_strides_2d(sizes, strides) || is_channels_last_strides_3d(sizes, strides)) {
+    // case 1.a. short cut channels last
+    std::iota(stride_indices.rbegin() + 1, stride_indices.rend() - 1, 2);
+    stride_indices[0] = 1;
+    stride_indices[n_dim - 1] = 0;
+  } else if (is_contiguous_strides(sizes, strides)) {
+    // case 1.b. short cut contiguous
+    std::iota(stride_indices.rbegin(), stride_indices.rend(), 0);
+  } else {
+    std::iota(stride_indices.begin(), stride_indices.end(), 0);
+    // case 2.
+    //
+    // For broadcasted dimension where stride is 0, we have to stick to
+    // TensorIterator behavior in eager, where they introduce an ambiguous
+    // comparison result to preserve permutation by best effort.
+    // For more details, see NOTE: [Computing output strides]
+    auto should_swap = [&](size_t a, size_t b) {
+      if (strides[a] == 0 || strides[b] == 0) {
+        return 0;
+      } else if (strides[a] < strides[b]) {
+        return -1;
+      } else if (strides[a] > strides[b]) {
+        return 1;
+      } else { // strides[a] == strides[b]
+        if (sizes[a] < sizes[b] || a > b) {
+          return 1;
+        }
+      }
+      return 0;
+    };
+    for (int i = 1; i < n_dim; i++) {
+      int dim1 = i;
+      for (int dim0 = i - 1; dim0 >= 0; dim0--) {
+        int comparison = should_swap(stride_indices[dim0], stride_indices[dim1]);
+        if (comparison > 0) {
+          std::swap(stride_indices[dim0], stride_indices[dim1]);
+          dim1 = dim0;
+        } else if (comparison < 0) {
+          break;
+        }
+      }
+    }
+  }
 
   std::vector<Stride> stride_properties;
   for (size_t i = 0; i < stride_indices.size(); i++) {
     bool contiguous_ = tensor_contiguity;
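
To see the best-effort permutation preservation of case 2 at work, below is a self-contained trace of the should_swap insertion sort on a hypothetical permuted tensor: a contiguous {5, 4, 3} tensor with its last two dims swapped, giving sizes {5, 3, 4} and strides {12, 1, 3}. The input is illustrative; the comparator and sort are copied from the hunk above.

#include <cstdint>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

int main() {
  std::vector<int64_t> sizes = {5, 3, 4};
  std::vector<int64_t> strides = {12, 1, 3};
  int n_dim = static_cast<int>(sizes.size());

  std::vector<size_t> stride_indices(n_dim);
  std::iota(stride_indices.begin(), stride_indices.end(), 0);

  // Same comparator as in the hunk above: 0-stride (broadcast) dims are
  // ambiguous (no swap); otherwise sort strides ascending, breaking ties
  // by size and original position.
  auto should_swap = [&](size_t a, size_t b) {
    if (strides[a] == 0 || strides[b] == 0) {
      return 0;
    } else if (strides[a] < strides[b]) {
      return -1;
    } else if (strides[a] > strides[b]) {
      return 1;
    } else if (sizes[a] < sizes[b] || a > b) {
      return 1;
    }
    return 0;
  };

  // Insertion sort, as in the hunk above.
  for (int i = 1; i < n_dim; i++) {
    int dim1 = i;
    for (int dim0 = i - 1; dim0 >= 0; dim0--) {
      int comparison = should_swap(stride_indices[dim0], stride_indices[dim1]);
      if (comparison > 0) {
        std::swap(stride_indices[dim0], stride_indices[dim1]);
        dim1 = dim0;
      } else if (comparison < 0) {
        break;
      }
    }
  }

  for (auto idx : stride_indices) {
    std::cout << idx << " "; // prints: 1 2 0 (strides 1, 3, 12 ascending)
  }
  std::cout << "\n";
}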

@@ -42,7 +42,8 @@ list(APPEND ATen_CPU_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/ivalue_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/vmap_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/type_test.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/dispatch_key_set_test.cpp)
+  ${CMAKE_CURRENT_SOURCE_DIR}/dispatch_key_set_test.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/stride_properties_test.cpp)
 
 list(APPEND ATen_CUDA_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_atomic_ops_test.cu

@@ -0,0 +1,69 @@
+#include <gtest/gtest.h>
+
+#include <ATen/ATen.h>
+
+#include <numeric>
+
+using namespace at;
+
+// TODO: failing sizes {4, 1, 4, 1}
+std::vector<std::vector<int64_t>> sizes = {
+    {4, 4, 4, 4}, {4, 4, 1, 1}, {4, 1, 4, 4}, {4, 1, 1, 4}, {1, 4, 1, 4}, {1, 4, 4, 1}};
+
+inline bool CheckStrideIndices(const Tensor& t, at::MemoryFormat format) {
+  size_t n_dim = t.dim();
+  std::vector<size_t> stride_indices(n_dim);
+  if (format == at::MemoryFormat::ChannelsLast) {
+    // stride_indices_ should be {1, n-1, n-2, ..., 2, 0}
+    std::iota(stride_indices.rbegin() + 1, stride_indices.rend() - 1, 2);
+    stride_indices[0] = 1;
+    stride_indices[n_dim - 1] = 0;
+  } else if (format == at::MemoryFormat::Contiguous) {
+    // stride_indices_ should be {n-1, n-2, n-3, ..., 0}
+    std::iota(stride_indices.rbegin(), stride_indices.rend(), 0);
+  } else {
+    TORCH_INTERNAL_ASSERT(false, "unrecognized memory format");
+  }
+
+  // Testing computeStrideProps with `IValue ival(t)` somehow doesn't work on
+  // CI with onnx; the function works fine within, but stride properties are
+  // somehow altered in ival->type()->cast<TensorType>();
+  auto tt = TensorType::create(c10::nullopt, c10::nullopt, t.sizes(), t.strides(), c10::nullopt);
+  TORCH_INTERNAL_ASSERT(tt->stride_properties().isComplete(), "complete stride properties are needed for the test");
+
+  auto index_iter = stride_indices.begin();
+  for (const auto& opt_stride : *tt->stride_properties().sizes()) {
+    if (*index_iter++ != opt_stride->stride_index_.value()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+TEST(StridePropertiesTest, StrideIndicesTest) {
+  // NOLINTNEXTLINE(performance-for-range-copy)
+  for (const auto& size : sizes) {
+    Tensor t = at::rand(size);
+    for (auto memory_format : {at::MemoryFormat::ChannelsLast, at::MemoryFormat::Contiguous}) {
+      t.resize_(size, memory_format);
+      EXPECT_TRUE(CheckStrideIndices(t, memory_format));
+    }
+  }
+}
+
+TEST(StridePropertiesTest, ZeroStrideIndicesEagerConsistencyTest) {
+  auto permuted_tensor = at::rand({6, 3, 1, 5, 2}).permute({0, 3, 2, 1, 4}); // permute dim-1 & dim-3
+  auto tensor = permuted_tensor.expand({6, 5, 4, 3, 2}); // expand dim-2
+  auto temp = TensorType::create(c10::nullopt, c10::nullopt, tensor.sizes(), tensor.strides(), c10::nullopt);
+  // TensorIterator preserves stride order, so this is the eager reference
+  auto eager_tensor = tensor.relu();
+  auto ref_type = TensorType::create(c10::nullopt, c10::nullopt, eager_tensor.sizes(), eager_tensor.strides(), c10::nullopt);
+  TORCH_INTERNAL_ASSERT(temp->stride_properties().isComplete() &&
+      ref_type->stride_properties().isComplete(), "complete stride properties are needed for the test");
+
+  auto ref_iter = (*(ref_type->stride_properties().sizes())).begin();
+  for (const auto& opt_stride : *temp->stride_properties().sizes()) {
+    EXPECT_TRUE(opt_stride->stride_index_.value() == (*ref_iter)->stride_index_.value());
+    ref_iter++;
+  }
+}
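
As a usage note, the expand in the second test is what introduces a 0-stride dimension. A minimal ATen snippet to inspect it (assuming a linked ATen/libtorch build; the printed strides follow deterministically from the contiguous layout of {6, 3, 1, 5, 2}):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Same setup as ZeroStrideIndicesEagerConsistencyTest above.
  auto permuted = at::rand({6, 3, 1, 5, 2}).permute({0, 3, 2, 1, 4});
  auto expanded = permuted.expand({6, 5, 4, 3, 2});

  // A contiguous {6, 3, 1, 5, 2} tensor has strides {30, 10, 10, 2, 1};
  // the permute reorders them to {30, 2, 10, 10, 1}, and expanding the
  // size-1 dim-2 sets its stride to 0: {30, 2, 0, 10, 1}.
  for (auto s : expanded.strides()) {
    std::cout << s << " ";
  }
  std::cout << "\n";
}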

@@ -124,8 +124,10 @@ bool complyWith(
     if (j != 0 && inner_dim != -1) {
       // we are not looking at dim-j, but dim-sorted_index, which
       // is the j-th fastest dim;
-      // TODO: merge this with above and put a long comment there
-      if (t_strides[sorted_index] < t_strides[inner_dim]) {
+      // Note: we ignore 0-stride dimensions, since the eager logic on
+      // stride indices is ambiguous for them
+      if (t_strides[sorted_index] != 0 && t_strides[inner_dim] != 0 &&
+          t_strides[sorted_index] < t_strides[inner_dim]) {
         return false;
       }
     }
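
The guard added above only flags a stride-order violation when both dimensions have nonzero strides. A small sketch of that predicate in isolation (the function name and driver are hypothetical, for illustration only):

#include <cstdint>
#include <iostream>

// True when dim `sorted_index` (expected slower, the j-th fastest) and
// its faster neighbor `inner_dim` respect ascending stride order.
// Broadcasted 0-stride dims are skipped, mirroring the guarded check
// in the hunk above.
bool stride_order_ok(int64_t slower_stride, int64_t faster_stride) {
  return !(slower_stride != 0 && faster_stride != 0 &&
           slower_stride < faster_stride);
}

int main() {
  std::cout << stride_order_ok(16, 1) << "\n"; // 1: proper ascending order
  std::cout << stride_order_ok(1, 16) << "\n"; // 0: order violated
  std::cout << stride_order_ok(0, 16) << "\n"; // 1: 0-stride dim is ignored
}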