mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix & unit test for c10::ArrayRef constructed from user-defined types (#139758)
Fixes #139391 Pull Request resolved: https://github.com/pytorch/pytorch/pull/139758 Approved by: https://github.com/ezyang
This commit is contained in:
parent
d35a600b74
commit
1c63612567
45
c10/test/util/ArrayRef_test.cpp
Normal file
45
c10/test/util/ArrayRef_test.cpp
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
#include <c10/util/ArrayRef.h>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
class ctor_from_container_test_span_ {
|
||||
T* data_;
|
||||
std::size_t sz_;
|
||||
|
||||
public:
|
||||
template <typename V = std::vector<std::remove_const_t<T>>>
|
||||
constexpr explicit ctor_from_container_test_span_(
|
||||
std::conditional_t<std::is_const_v<T>, const V, V>& vec) noexcept
|
||||
: data_(vec.data()), sz_(vec.size()) {}
|
||||
|
||||
[[nodiscard]] constexpr auto data() const noexcept {
|
||||
return data_;
|
||||
}
|
||||
|
||||
[[nodiscard]] constexpr auto size() const noexcept {
|
||||
return sz_;
|
||||
}
|
||||
};
|
||||
|
||||
TEST(ArrayRefTest, ctor_from_container_test) {
|
||||
using value_type = int;
|
||||
std::vector<value_type> test_vec{1, 6, 32, 4, 68, 3, 7};
|
||||
const ctor_from_container_test_span_<value_type> test_mspan{test_vec};
|
||||
const ctor_from_container_test_span_<const value_type> test_cspan{
|
||||
std::as_const(test_vec)};
|
||||
|
||||
const auto test_ref_mspan = c10::ArrayRef<value_type>(test_mspan);
|
||||
const auto test_ref_cspan = c10::ArrayRef<value_type>(test_cspan);
|
||||
|
||||
EXPECT_EQ(std::as_const(test_vec), test_ref_mspan);
|
||||
EXPECT_EQ(std::as_const(test_vec), test_ref_cspan);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
@ -98,9 +98,9 @@ class ArrayRef final {
|
|||
|
||||
template <
|
||||
typename Container,
|
||||
typename = std::enable_if_t<std::is_same_v<
|
||||
std::remove_const_t<decltype(std::declval<Container>().data())>,
|
||||
T*>>>
|
||||
typename U = decltype(std::declval<Container>().data()),
|
||||
typename = std::enable_if_t<
|
||||
(std::is_same_v<U, T*> || std::is_same_v<U, T const*>)>>
|
||||
/* implicit */ ArrayRef(const Container& container)
|
||||
: Data(container.data()), Length(container.size()) {
|
||||
debugCheckNullptrInvariant();
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user