mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix at::optional compile problems on Windows CUDA.
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10909 Differential Revision: D9516837 Pulled By: gchanan fbshipit-source-id: fad7e3284e74c599b873ebaae2dcdf5013505855
This commit is contained in:
parent
6ce799edd6
commit
7842b6d0f7
|
|
@ -301,6 +301,14 @@ using OptionalBase = typename std::conditional<
|
|||
|
||||
template <class T>
|
||||
class optional : private OptionalBase<T> {
|
||||
|
||||
template <class U> // re-declaration for nvcc on Windows.
|
||||
using OptionalBase = typename std::conditional<
|
||||
std::is_trivially_destructible<U>::value, // if possible
|
||||
constexpr_optional_base<typename std::remove_const<
|
||||
U>::type>, // use base with trivial destructor
|
||||
optional_base<typename std::remove_const<U>::type>>::type;
|
||||
|
||||
static_assert(
|
||||
!std::is_same<typename std::decay<T>::type, nullopt_t>::value,
|
||||
"bad T");
|
||||
|
|
|
|||
|
|
@ -26,7 +26,8 @@ list(APPEND ATen_CUDA_TEST_SRCS
|
|||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_rng_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/apply_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/stream_test.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_half_test.cu)
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_half_test.cu
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cuda_optional_test.cu)
|
||||
if (CUDNN_FOUND)
|
||||
list(APPEND ATen_CUDA_TEST_SRCS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/cudnn_test.cpp)
|
||||
|
|
|
|||
22
aten/src/ATen/test/cuda_optional_test.cu
Normal file
22
aten/src/ATen/test/cuda_optional_test.cu
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
#define CATCH_CONFIG_MAIN
|
||||
#include "catch.hpp"
|
||||
|
||||
#include "ATen/ATen.h"
|
||||
#include "ATen/optional.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
using namespace at;
|
||||
|
||||
TEST_CASE( "optional in cuda files", "[cuda]" ) {
|
||||
at::optional<int64_t> trivially_destructible;
|
||||
at::optional<std::vector<int64_t>> non_trivially_destructible;
|
||||
REQUIRE(!trivially_destructible.has_value());
|
||||
REQUIRE(!non_trivially_destructible.has_value());
|
||||
|
||||
trivially_destructible = {5};
|
||||
non_trivially_destructible = std::vector<int64_t>{5, 10};
|
||||
REQUIRE(trivially_destructible.has_value());
|
||||
REQUIRE(non_trivially_destructible.has_value());
|
||||
}
|
||||
|
||||
|
|
@ -31,6 +31,9 @@ fi
|
|||
if [[ -x ./cuda_half_test ]]; then
|
||||
./cuda_half_test
|
||||
fi
|
||||
if [[ -x ./cuda_optional_test ]]; then
|
||||
./cuda_optional_test
|
||||
fi
|
||||
if [ "$VALGRIND" == "ON" ]
|
||||
then
|
||||
valgrind --suppressions="$VALGRIND_SUP" --error-exitcode=1 ./basic "[cpu]"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user