Fix & unit test for c10::ArrayRef constructed from user-defined types (#139758)

Fixes #139391

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139758
Approved by: https://github.com/ezyang
This commit is contained in:
Judicaël Clair 2024-11-06 04:23:01 +00:00 committed by PyTorch MergeBot
parent d35a600b74
commit 1c63612567
2 changed files with 48 additions and 3 deletions

View File

@ -0,0 +1,45 @@
#include <c10/util/ArrayRef.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <utility>
#include <vector>
namespace {
template <typename T>
class ctor_from_container_test_span_ {
T* data_;
std::size_t sz_;
public:
template <typename V = std::vector<std::remove_const_t<T>>>
constexpr explicit ctor_from_container_test_span_(
std::conditional_t<std::is_const_v<T>, const V, V>& vec) noexcept
: data_(vec.data()), sz_(vec.size()) {}
[[nodiscard]] constexpr auto data() const noexcept {
return data_;
}
[[nodiscard]] constexpr auto size() const noexcept {
return sz_;
}
};
TEST(ArrayRefTest, ctor_from_container_test) {
using value_type = int;
std::vector<value_type> test_vec{1, 6, 32, 4, 68, 3, 7};
const ctor_from_container_test_span_<value_type> test_mspan{test_vec};
const ctor_from_container_test_span_<const value_type> test_cspan{
std::as_const(test_vec)};
const auto test_ref_mspan = c10::ArrayRef<value_type>(test_mspan);
const auto test_ref_cspan = c10::ArrayRef<value_type>(test_cspan);
EXPECT_EQ(std::as_const(test_vec), test_ref_mspan);
EXPECT_EQ(std::as_const(test_vec), test_ref_cspan);
}
} // namespace

View File

@ -98,9 +98,9 @@ class ArrayRef final {
template <
typename Container,
typename = std::enable_if_t<std::is_same_v<
std::remove_const_t<decltype(std::declval<Container>().data())>,
T*>>>
typename U = decltype(std::declval<Container>().data()),
typename = std::enable_if_t<
(std::is_same_v<U, T*> || std::is_same_v<U, T const*>)>>
/* implicit */ ArrayRef(const Container& container)
: Data(container.data()), Length(container.size()) {
debugCheckNullptrInvariant();