diff --git a/aten/src/ATen/cuda/CUDADataType.h b/aten/src/ATen/cuda/CUDADataType.h
index e042212f11d..33423344b54 100644
--- a/aten/src/ATen/cuda/CUDADataType.h
+++ b/aten/src/ATen/cuda/CUDADataType.h
@@ -78,7 +78,7 @@ inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type)
       return CUDA_R_64I;
     case c10::ScalarType::BFloat16:
       return CUDA_R_16BF;
-#if defined(CUDA_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 60300)
+#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11080) || (defined(USE_ROCM) && ROCM_VERSION >= 60300)
     case c10::ScalarType::Float8_e4m3fn:
       return CUDA_R_8F_E4M3;
     case c10::ScalarType::Float8_e5m2:
diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp
index cf3545c5ca9..3f2916862ca 100644
--- a/aten/src/ATen/cuda/CUDAGraph.cpp
+++ b/aten/src/ATen/cuda/CUDAGraph.cpp
@@ -139,7 +139,7 @@ void CUDAGraph::capture_end() {
   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1g1accfe1da0c605a577c22d9751a09597
   // cudaGraphInstantiateWithFlags
   // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__GRAPH.html#group__CUDART__GRAPH_1ga2c652a24ba93e52b99a47bec0888233
-#if (defined(CUDA_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 60200))
+#if ((defined(CUDA_VERSION) && CUDA_VERSION >= 11040) || (defined(USE_ROCM) && ROCM_VERSION >= 60200))
   int version = 0;
   AT_CUDA_CHECK(cudaDriverGetVersion(&version));
   if (version < 11040) {
@@ -154,7 +154,7 @@ void CUDAGraph::capture_end() {
 #endif
   //Since ROCm 6.2, we want to go down this path as hipGraphExecDestroy in the destructor will not immediately free the memory.
   //It will wait for the next sync operation. cudaGraphInstantiateFlagAutoFreeOnLaunch will add async frees after graph launch.
-#if (defined(CUDA_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 60200))
+#if ((defined(CUDA_VERSION) && CUDA_VERSION >= 11040) || (defined(USE_ROCM) && ROCM_VERSION >= 60200))
   } else {
     AT_CUDA_CHECK(cudaGraphInstantiateWithFlags(&graph_exec_,
                                                 graph_,
@@ -216,7 +216,7 @@ void CUDAGraph::enable_debug_mode() {
 }
 
 void CUDAGraph::debug_dump(const std::string& debug_path) {
-#if defined(CUDA_VERSION) || defined(USE_ROCM)
+#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11030)|| defined(USE_ROCM)
   if (_cuda_graphs_debug) {
     TORCH_WARN("DEBUG: calling debug_dump()");
     if (has_graph_) {
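
The CUDAGraph.cpp hunks above restore a CUDA_VERSION >= 11040 guard around graph instantiation, because cudaGraphInstantiateWithFlags and the cudaGraphInstantiateFlagAutoFreeOnLaunch flag only exist from CUDA 11.4 onward. For context, here is a minimal standalone sketch (not PyTorch code; the kernel, sizes, and omitted error handling are placeholders) of the capture-then-instantiate pattern being guarded; it builds as a .cu file with nvcc:

    #include <cuda.h>
    #include <cuda_runtime.h>
    #include <cstdio>

    __global__ void bump(float* p) { p[0] += 1.0f; }

    int main() {
      float* d = nullptr;
      cudaMalloc(&d, sizeof(float));
      cudaStream_t stream;
      cudaStreamCreate(&stream);

      // Record the work into a graph instead of running it eagerly.
      cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
      bump<<<1, 1, 0, stream>>>(d);
      cudaGraph_t graph;
      cudaStreamEndCapture(stream, &graph);

      // Mirrors the compile-time guard above: the *WithFlags entry point and
      // the auto-free flag appear in CUDA 11.4. PyTorch additionally checks
      // cudaDriverGetVersion() at runtime before taking this path.
      cudaGraphExec_t exec;
    #if defined(CUDA_VERSION) && CUDA_VERSION >= 11040
      cudaGraphInstantiateWithFlags(
          &exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch);
    #else
      cudaGraphInstantiate(&exec, graph, nullptr, nullptr, 0);  // pre-11.4 signature
    #endif

      cudaGraphLaunch(exec, stream);
      cudaStreamSynchronize(stream);
      std::printf("graph launched\n");
      cudaFree(d);
      return 0;
    }
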
diff --git a/aten/src/ATen/cuda/Exceptions.h b/aten/src/ATen/cuda/Exceptions.h
index cec013f006d..7a24151df20 100644
--- a/aten/src/ATen/cuda/Exceptions.h
+++ b/aten/src/ATen/cuda/Exceptions.h
@@ -117,11 +117,15 @@ constexpr const char* _cusolver_backend_suggestion = \
     "linear algebra operators with other supported backends. " \
     "See https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.preferred_linalg_library";
 
+// When cuda < 11.5, cusolver raises CUSOLVER_STATUS_EXECUTION_FAILED when input contains nan.
 // When cuda >= 11.5, cusolver normally finishes execution and sets info array indicating convergence issue.
 #define TORCH_CUSOLVER_CHECK(EXPR) \
   do { \
     cusolverStatus_t __err = EXPR; \
-    if (__err == CUSOLVER_STATUS_INVALID_VALUE) { \
+    if ((CUDA_VERSION < 11500 && \
+         __err == CUSOLVER_STATUS_EXECUTION_FAILED) || \
+        (CUDA_VERSION >= 11500 && \
+         __err == CUSOLVER_STATUS_INVALID_VALUE)) { \
       TORCH_CHECK_LINALG( \
           false, \
           "cusolver error: ", \
diff --git a/aten/src/ATen/cuda/cub.cuh b/aten/src/ATen/cuda/cub.cuh
index e1d452bac4a..2de9636fc91 100644
--- a/aten/src/ATen/cuda/cub.cuh
+++ b/aten/src/ATen/cuda/cub.cuh
@@ -291,7 +291,7 @@ inline void inclusive_scan(InputIteratorT input, OutputIteratorT output, ScanOpT
 #endif
 }
 
-# if defined(CUDA_VERSION) || defined(USE_ROCM)
+# if (defined(CUDA_VERSION) && CUDA_VERSION > 11040) || defined(USE_ROCM)
 
 template
 struct BlockPrefixCallbackOp
diff --git a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp
index e00726ba534..dcef7a290b1 100644
--- a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp
+++ b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp
@@ -146,8 +146,10 @@ nvrtcResult nvrtcCreateProgram(nvrtcProgram *prog,
 NVRTC_STUB1(nvrtcDestroyProgram, nvrtcProgram *)
 NVRTC_STUB2(nvrtcGetPTXSize, nvrtcProgram, size_t *)
 NVRTC_STUB2(nvrtcGetPTX, nvrtcProgram, char *)
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
 NVRTC_STUB2(nvrtcGetCUBINSize, nvrtcProgram, size_t *)
 NVRTC_STUB2(nvrtcGetCUBIN, nvrtcProgram, char *)
+#endif
 NVRTC_STUB3(nvrtcCompileProgram, nvrtcProgram, int, const char * const *)
 _STUB_1(NVRTC, nvrtcGetErrorString, const char *, nvrtcResult)
 NVRTC_STUB2(nvrtcGetProgramLogSize,nvrtcProgram, size_t*)
diff --git a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h
index d89875865b8..22f71c35dc2 100644
--- a/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h
+++ b/aten/src/ATen/cuda/nvrtc_stub/ATenNVRTC.h
@@ -76,7 +76,7 @@ namespace at::cuda {
   AT_FORALL_NVRTC_BASE(_)
 #endif
 
-#if defined(CUDA_VERSION)
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
 #define AT_FORALL_NVRTC(_) \
   AT_FORALL_NVRTC_EXTENDED(_) \
   _(nvrtcGetCUBINSize) \
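
The LazyNVRTC.cpp and ATenNVRTC.h hunks above only generate stubs for nvrtcGetCUBINSize/nvrtcGetCUBIN when the toolkit is at least CUDA 11.1, since those NVRTC entry points do not exist in older releases. Below is a rough standalone sketch of the choice those stubs enable (not PyTorch's lazy-loading code; the kernel source and the sm_80/compute_75 targets are made-up placeholders, link with -lnvrtc):

    #include <cuda.h>      // CUDA_VERSION
    #include <nvrtc.h>
    #include <cstdio>
    #include <vector>

    int main() {
      const char* src =
          "extern \"C\" __global__ void k(float* p) { p[0] = 1.0f; }";
      nvrtcProgram prog;
      nvrtcCreateProgram(&prog, src, "k.cu", 0, nullptr, nullptr);

    #if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
      // 11.1+ toolkits can lower straight to SASS for a concrete sm_XX target.
      const char* opts[] = {"--gpu-architecture=sm_80"};
    #else
      const char* opts[] = {"--gpu-architecture=compute_75"};
    #endif
      if (nvrtcCompileProgram(prog, 1, opts) != NVRTC_SUCCESS) return 1;

      size_t n = 0;
      std::vector<char> image;
    #if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
      nvrtcGetCUBINSize(prog, &n);   // the two entry points stubbed above
      image.resize(n);
      nvrtcGetCUBIN(prog, image.data());
    #else
      nvrtcGetPTXSize(prog, &n);     // PTX is the only option on older toolkits
      image.resize(n);
      nvrtcGetPTX(prog, image.data());
    #endif
      std::printf("compiled image: %zu bytes\n", n);
      nvrtcDestroyProgram(&prog);
      return 0;
    }
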
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index b3b56eae764..831261c7985 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -359,7 +359,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   bool is_float_output_with_half_input = (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::BFloat16) && result.scalar_type() == at::ScalarType::Float;
   c10::MaybeOwned self_;
   if (&result != &self) {
-#if defined(CUDA_VERSION) || defined(USE_ROCM)
+#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11040)) || defined(USE_ROCM)
     // Strangely, if mat2 has only 1 row or column, we get
     // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
     // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
@@ -495,6 +495,15 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   }
 #else
   auto activation_epilogue = activation_to_gemm_and_blas_arg(activation);
+#if (defined(CUDA_VERSION) && (CUDA_VERSION < 11080))
+  // GELU is not supported (and does not compile!) prior
+  // to CUDA 11.4. Have observed accuracy issues with
+  // GELU epilogue in 11.4; disabling the GELU epilogue
+  // path for CUDA version < 11.8.
+  if (activation == Activation::GELU)
+    activation_epilogue = cuda::blas::GEMMAndBiasActivationEpilogue::None;
+#endif
+
   bool okay = true;
   if (is_float_output_with_half_input) {
     AT_DISPATCH_REDUCED_FLOATING_TYPES(
@@ -637,7 +646,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
   // gating activation_to_gemm_and_blas_arg above; here we are manually
   // performing a post-GELU because we weren't able to use the GELU
   // epilogue above.
-#if !defined(CUDA_VERSION) && !defined(USE_ROCM)
+#if !(defined(CUDA_VERSION) && CUDA_VERSION >= 11080) && !defined(USE_ROCM)
   if (useLtInterface && activation == Activation::GELU) {
     at::gelu_(const_cast(*args.result), "tanh");
   }
@@ -1008,7 +1017,7 @@ Tensor& _int_mm_out_cuda(const Tensor& self, const Tensor& mat2, Tensor& result)
 
   TORCH_CHECK(result.is_contiguous(), "Expected result to be contiguous.");
 
-#if defined(CUDA_VERSION) || defined(USE_ROCM)
+#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11070)) || defined(USE_ROCM)
   cublasCommonArgs args(self, mat2, result);
 
   at::cuda::blas::int8_gemm(
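
The Blas.cpp change above disables the cuBLASLt GELU epilogue below CUDA 11.8 and leans on the existing fallback that applies GELU after the matmul (the at::gelu_(..., "tanh") call visible in the later hunk). The sketch below restates that fallback at the ATen level; it is illustrative only, the helper name and shapes are assumptions, and the real code applies the GELU in place on the Lt result:

    #include <ATen/ATen.h>

    // out = GELU(input @ weight^T + bias), computed without the fused epilogue.
    at::Tensor linear_gelu_fallback(const at::Tensor& input,
                                    const at::Tensor& weight,
                                    const at::Tensor& bias) {
      // 1) plain biased matmul (what the Lt path would do with only a bias epilogue)
      auto out = at::addmm(bias, input, weight.t());
      // 2) separate in-place GELU; "tanh" matches the approximation the
      //    epilogue would have applied, as in the fallback above
      at::gelu_(out, "tanh");
      return out;
    }
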
diff --git a/aten/src/ATen/native/cuda/MixedDtypesLinear.cu b/aten/src/ATen/native/cuda/MixedDtypesLinear.cu
index 06db07bca54..db4cfb4de78 100644
--- a/aten/src/ATen/native/cuda/MixedDtypesLinear.cu
+++ b/aten/src/ATen/native/cuda/MixedDtypesLinear.cu
@@ -2,7 +2,7 @@
 #include
 #include
 
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
 // Doesn't work on ROCm or Windows yet
 // TODO: Add compiler warning? Add PyTorch config flag?
 #else
@@ -20,7 +20,7 @@
 #include
 #include
 #endif
 
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
 // Doesn't work on ROCm or Windows yet
 #else
 #define CUTLASS_STATUS_CHECK(status) \
@@ -32,7 +32,7 @@
 
 namespace at::native {
 
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
 // Doesn't work on ROCm or Windows yet or old compiler
 #else
 template
@@ -198,7 +198,7 @@ _mixed_dtypes_linear(const Tensor& input, const Tensor& weight,
                      const Tensor& scale,
                      const std::optional& bias_opt,
                      const std::optional activation_opt) {
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
   TORCH_CHECK(false, "_mixed_dtypes_linear: not compiled for this platform");
   return Tensor{};
 #else
diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu
index d61b99fb5a3..8f68dcc969b 100644
--- a/aten/src/ATen/native/cuda/Nonzero.cu
+++ b/aten/src/ATen/native/cuda/Nonzero.cu
@@ -300,7 +300,7 @@ void nonzero_static_cuda_out_impl(
     int64_t size,
     int64_t fill_value,
     Tensor& out) {
-#if defined(CUDA_VERSION) || defined(USE_ROCM)
+# if (defined(CUDA_VERSION) && CUDA_VERSION > 11040) || defined(USE_ROCM)
   Tensor self_contiguous_ = self.contiguous();
 
   // see comment in nonzero_cuda_out_impl on reqs for out
diff --git a/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu b/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu
index 3f45c8d7b10..14807c0200e 100644
--- a/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu
+++ b/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu
@@ -17,7 +17,7 @@ void addcmul_cuda_scalar_tensor2_kernel(
     const Scalar& value
 );
 
-#if AT_USE_JITERATOR()
+#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
 constexpr char addcmul_name[] = "addcmul";
 #endif
 void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
@@ -37,7 +37,10 @@ void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
   auto dtype = iter.common_dtype();
   if (at::isComplexType(dtype)) {
-  #if AT_USE_JITERATOR()
+    // When using Jiterator, addcmul and addcdiv kernels get stuck during a
+    // promotion test on CUDA 11.3, so only enable that from CUDA 11.5:
+    // https://github.com/pytorch/pytorch/pull/74234#issuecomment-1100932209
+    #if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
     AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_cuda", [&]() {
       auto alpha = value.to();
       static const auto addcmul_string = jiterator_stringify(
@@ -90,14 +93,17 @@ void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
   }
 }
 
-#if AT_USE_JITERATOR()
+#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
 constexpr char addcmul_scalar_tensor2_name[] = "addcmul_scalar_tensor2";
 #endif
 void addcmul_cuda_scalar_tensor2_kernel(TensorIteratorBase& iter, const Scalar& scalar_tensor2, const Scalar& value) {
   auto dtype = iter.common_dtype();
   if (at::isComplexType(dtype)) {
-  #if AT_USE_JITERATOR()
+    // When using Jiterator, addcmul and addcdiv kernels get stuck during a
+    // promotion test on CUDA 11.3, so only enable that from CUDA 11.5:
+    // https://github.com/pytorch/pytorch/pull/74234#issuecomment-1100932209
+    #if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
     AT_DISPATCH_COMPLEX_TYPES(dtype, "addcmul_cuda", [&]() {
       auto c = scalar_tensor2.to();
       auto alpha = value.to();
@@ -139,14 +145,17 @@ void addcmul_cuda_scalar_tensor2_kernel(TensorIteratorBase& iter, const Scalar&
   }
 }
 
-#if AT_USE_JITERATOR()
+#if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
 // return a + alpha * (b / static_cast(c));
 constexpr char addcdiv_name[] = "addcdiv";
 #endif
 void addcdiv_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) {
   auto dtype = iter.common_dtype();
   if (at::isComplexType(dtype)) {
-  #if AT_USE_JITERATOR()
+    // When using Jiterator, addcmul and addcdiv kernels get stuck during a
+    // promotion test on CUDA 11.3, so only enable that from CUDA 11.5:
+    // https://github.com/pytorch/pytorch/pull/74234#issuecomment-1100932209
+    #if AT_USE_JITERATOR() && CUDA_VERSION >= 11050
     AT_DISPATCH_COMPLEX_TYPES(dtype, "addcdiv_cuda", [&]() {
       auto alpha = value.to();
       static const auto addcdiv_string =
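
For reference, the complex-dtype kernels being fenced off above compute the usual addcmul/addcdiv definitions. The host-side sketch below states those semantics in plain C++; it is not the jiterator kernel, just the formula the guarded code implements:

    #include <complex>

    // addcmul: out = self + value * tensor1 * tensor2
    template <typename T>
    T addcmul_ref(T self, T t1, T t2, T value) {
      return self + value * t1 * t2;
    }

    // addcdiv: out = self + value * (tensor1 / tensor2)
    template <typename T>
    T addcdiv_ref(T self, T t1, T t2, T value) {
      return self + value * (t1 / t2);
    }

    int main() {
      using C = std::complex<float>;
      C self{1.f, 2.f}, t1{3.f, -1.f}, t2{0.5f, 0.5f}, value{2.f, 0.f};
      C a = addcmul_ref(self, t1, t2, value);
      C b = addcdiv_ref(self, t1, t2, value);
      return (a == a && b == b) ? 0 : 1;  // keep the results observable
    }
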
diff --git a/aten/src/ATen/native/cuda/ScanUtils.cuh b/aten/src/ATen/native/cuda/ScanUtils.cuh
index c4d86acb43e..1bb64473009 100644
--- a/aten/src/ATen/native/cuda/ScanUtils.cuh
+++ b/aten/src/ATen/native/cuda/ScanUtils.cuh
@@ -453,7 +453,7 @@ void scan_dim(const TensorBase& self, const TensorBase& result,
   if (self.numel() == self.size(dim)) {
     if constexpr (std::is_same_v>) {
       if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms()) && (self.is_floating_point() || self.is_complex())) {
-#if defined(CUDA_VERSION) || defined(USE_ROCM)
+# if (defined(CUDA_VERSION) && CUDA_VERSION > 11040) || defined(USE_ROCM)
         cuda::cub::inclusive_deterministic_scan(self_->const_data_ptr(), result.mutable_data_ptr(), binary_op, self.numel());
 #else
         globalContext().alertNotDeterministic("cumsum_cuda_kernel");
diff --git a/aten/src/ATen/native/cuda/TensorModeKernel.cuh b/aten/src/ATen/native/cuda/TensorModeKernel.cuh
index e9c1bfd6d00..fb43e0d8f34 100644
--- a/aten/src/ATen/native/cuda/TensorModeKernel.cuh
+++ b/aten/src/ATen/native/cuda/TensorModeKernel.cuh
@@ -193,7 +193,9 @@ __device__ inline void bitonicSortKeys(
 // dimension as the innermost dim, such that we can get the particular slice for
 // a Tensor via its linear block dimension * the slice size.
 template
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11070
 __launch_bounds__(1024, 1)
+#endif
 __global__ void compute_mode(
     const T* input,
     at::cuda::detail::TensorInfo values,
diff --git a/aten/src/ATen/native/cuda/jit_utils.cpp b/aten/src/ATen/native/cuda/jit_utils.cpp
index 19d04473699..0d49ec9c187 100644
--- a/aten/src/ATen/native/cuda/jit_utils.cpp
+++ b/aten/src/ATen/native/cuda/jit_utils.cpp
@@ -912,6 +912,10 @@ void codegenOutputQuery(
     compile_to_sass = true;
   }
 
+  #if defined(CUDA_VERSION) && CUDA_VERSION < 11010
+  // compile to sass is not allowed prior to CUDA 11.1
+  compile_to_sass = false;
+  #endif
 #endif
 }
 
@@ -1620,7 +1624,7 @@ NvrtcFunction jit_pwise_function(
   size_t ptx_size = 0;
   std::vector ptx;
-  #if !defined(USE_ROCM)
+  #if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
   // compile_to_sass determines whether we are generating SASS or PTX, hence
   // the different API.
   const auto getSize = compile_to_sass
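
The TensorModeKernel.cuh hunk above compiles the __launch_bounds__(1024, 1) qualifier only on CUDA 11.7 and newer. Independent of that version gate, here is a minimal sketch of what the qualifier itself does (placeholder kernel, not the mode kernel):

    #include <cuda_runtime.h>

    // __launch_bounds__(maxThreadsPerBlock, minBlocksPerMultiprocessor)
    // promises the compiler this kernel is never launched with more than 1024
    // threads per block, and asks it to keep register use low enough to fit
    // at least one resident block per SM.
    __global__ void
    __launch_bounds__(1024, 1)
    scale(float* data, float s, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        data[i] *= s;
      }
    }

    // A launch then has to respect the declared bound, e.g.:
    //   scale<<<num_blocks, 1024, 0, stream>>>(d_ptr, 2.0f, n);
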
diff --git a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredApplyDense.cu b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredApplyDense.cu
index 364f33c9066..925a33b0bbd 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredApplyDense.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredApplyDense.cu
@@ -4,7 +4,7 @@
 #include
 #include
 
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
 #else
 #include
 #include
@@ -12,7 +12,7 @@
 
 namespace at::native {
 
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
 #else
 struct Params {
   uint64_t const* threads_masks;
@@ -123,7 +123,7 @@ Tensor _sparse_semi_structured_apply_dense(
     const Tensor& input,
     const Tensor& threads_masks) {
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
   TORCH_CHECK(false, "_sparse_semi_structured_apply_dense: not supported");
   return Tensor{};
 #else
diff --git a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu
index 61457a009f5..75d4e8c75c9 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu
@@ -3,7 +3,7 @@
 #include
 #include
 
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
 #else
 #include
 #include
@@ -16,7 +16,7 @@
 #include
 #include
 
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
 #else
 #define CUTLASS_STATUS_CHECK(status) \
   { \
@@ -31,7 +31,7 @@ namespace {
 
 namespace at::native {
 
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
 #else
 // Wrapper function for CUTLASS sparse GEMM implementation, used
 // solely to simplify dispatching from
@@ -613,7 +613,7 @@ Tensor _sparse_semi_structured_linear(
       "removed in a future PyTorch release. Please use "
       "_sparse_semi_structured_mm/_sparse_semi_structured_addmm "
      "instead.");
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
   TORCH_CHECK(false, "_sparse_semi_structured_linear: CUTLASS not supported");
   return Tensor{};
 #else
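
The *_semi_structured kernels in these files are CUTLASS-based 2:4 ("semi-structured") sparse kernels, which is why they are compiled out below CUDA 11.8 alongside ROCm and MSVC. As a reminder of the data layout they assume, here is a tiny host-side sketch (not PyTorch code) of the 2:4 invariant, i.e. at most two nonzeros in every aligned group of four values along the reduction dimension:

    #include <cstddef>

    // Returns true if every aligned group of four values keeps at most two
    // nonzeros - the 2:4 pattern the CUTLASS sparse kernels expect.
    bool is_two_four_sparse(const float* row, std::size_t n) {
      for (std::size_t i = 0; i + 4 <= n; i += 4) {
        int nonzeros = 0;
        for (std::size_t j = 0; j < 4; ++j) {
          nonzeros += (row[i + j] != 0.0f) ? 1 : 0;
        }
        if (nonzeros > 2) {
          return false;
        }
      }
      return true;
    }
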
diff --git a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu
index 45c70f9764b..47a19d26342 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu
@@ -3,7 +3,7 @@
 #include
 #include
 
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
 #else
 #include
 #include
@@ -16,7 +16,7 @@
 #include
 #include
 
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
 #else
 #define CUTLASS_STATUS_CHECK(status) \
   { \
@@ -28,7 +28,7 @@
 
 namespace at::native {
 
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
 #else
 // Wrapper function for CUTLASS sparse GEMM implementation, used
 // solely to simplify dispatching from
@@ -526,7 +526,7 @@ Tensor sparse_semi_structured_mad_op(
     const Tensor& mat1, const Tensor& mat1_meta, const Tensor& mat2,
     const std::optional& input_opt, const Scalar& alpha, const Scalar& beta,
     const std::optional out_dtype_opt) {
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
   TORCH_CHECK(false, __func__, " : CUTLASS not supported");
   return Tensor{};
 #else
@@ -818,7 +818,7 @@ Tensor _sparse_semi_structured_addmm(
 // Following is just for testing purposes.
 namespace at::native {
 
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
 #else
 // Copied from tools/util/include/host_reorder.h, from CUTLASS source
 // tree. This is for simplicity - namely, this file is not under
@@ -856,7 +856,7 @@ static void reorder_meta(cutlass::TensorRef dest,
 std::tuple _to_sparse_semi_structured(const Tensor& dense) {
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
   TORCH_CHECK(false, __func__, " : CUTLASS not supported");
   return std::make_tuple(Tensor{}, Tensor{});
 #else
diff --git a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredTile.cu b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredTile.cu
index 599a59a1831..5f94e013f3f 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredTile.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredTile.cu
@@ -8,7 +8,7 @@
 #include
 #include
 
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
 #else
 #include
 #include
@@ -17,7 +17,7 @@
 
 namespace at::native {
 
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
 #else
 struct MetadataCuSparseLt {
   // Format used by cuSparseLt
@@ -280,7 +280,7 @@ std::tuple _sparse_semi_structured_tile(
     std::string_view algorithm,
     bool use_cutlass) {
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
   TORCH_CHECK(false, "_sparse_semi_structured_tile: not supported");
   return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}, Tensor{});
 #else
diff --git a/aten/src/ATen/native/sparse/cuda/SparseSemiSturcturedApply.cu b/aten/src/ATen/native/sparse/cuda/SparseSemiSturcturedApply.cu
index 5ead412bd2d..9b9b1bc0cc6 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseSemiSturcturedApply.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseSemiSturcturedApply.cu
@@ -5,14 +5,14 @@
 #include
 #include
 
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
 #else
 #include
 #endif
 
 namespace at::native {
 
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
 #else
 template
 __global__ void __launch_bounds__(32 /* num_threads */)
@@ -89,7 +89,7 @@ std::tuple _sparse_semi_structured_apply_typed(Tensor input, Ten
 std::tuple _sparse_semi_structured_apply(const Tensor& input, const Tensor& threads_masks) // Returned by `_sparse_semi_structured_tile`
 {
-#if defined(USE_ROCM) || defined(_MSC_VER)
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
   TORCH_CHECK(false, "_sparse_semi_structured_apply: not supported");
   return std::make_tuple(Tensor{}, Tensor{});
 #else
diff --git a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp
index 7f02f46411c..b9e25430421 100644
--- a/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp
+++ b/torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp
@@ -125,7 +125,7 @@ FusedKernelCUDA::FusedKernelCUDA(
   args.push_back("-hip-pch");
 #else
   const std::string compute = std::string("--gpu-architecture=") +
-#if defined(CUDA_VERSION)
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
       // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX (compute_)
       // which gives better backwards compatibility to work on older driver,
      // (since older driver doesn't necessrily recognize PTX emitted by new
@@ -156,7 +156,7 @@ FusedKernelCUDA::FusedKernelCUDA(
       [&] { AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); });
   AT_CUDA_NVRTC_CHECK(result);
   size_t ptx_size = 0;
-#if !defined(USE_ROCM)
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
   // compile_to_sass determines whether we are generating SASS or PTX, hence
   // the different API.
   const auto getSize = compile_to_sass
diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
index 55ec8376f82..89638d5a2de 100644
--- a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
+++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp
@@ -1286,7 +1286,7 @@ void CudaCodeGen::CompileToNVRTC(
   args.push_back("-hip-pch");
 #else
   const std::string compute = std::string("--gpu-architecture=") +
-#if defined(CUDA_VERSION)
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
       // CUDA 11.1 allows going directly to SASS (sm_) instead of PTX (compute_)
       // which gives better backwards compatibility to work on older driver,
      // (since older driver doesn't necessarily recognize PTX emitted by new
@@ -1321,7 +1321,7 @@ void CudaCodeGen::CompileToNVRTC(
   AT_CUDA_NVRTC_CHECK(result);
   size_t ptx_size = 0;
   std::vector ptx;
-#if !defined(USE_ROCM)
+#if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
   // compile_to_sass determines whether we are generating SASS or PTX, hence
   // the different API.
   auto getSize = compile_to_sass
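
Both fused_kernel.cpp and cuda_codegen.cpp build the NVRTC --gpu-architecture option, and the hunks above make the SASS (sm_) spelling conditional on CUDA 11.1+, falling back to PTX (compute_) otherwise, for the reasons the in-line comments give. A hedged sketch of that selection, with a made-up helper name and device 0 assumed:

    #include <cuda.h>            // CUDA_VERSION
    #include <cuda_runtime.h>
    #include <string>

    std::string gpu_arch_option(int device = 0) {
      cudaDeviceProp prop{};
      cudaGetDeviceProperties(&prop, device);
      const std::string cc = std::to_string(prop.major) + std::to_string(prop.minor);
    #if defined(CUDA_VERSION) && CUDA_VERSION >= 11010
      // SASS target: the resulting binary no longer depends on the driver's
      // PTX JIT (an older driver may not understand PTX from a newer toolkit).
      return "--gpu-architecture=sm_" + cc;
    #else
      // Older toolkits stop at PTX and let the driver JIT it at load time.
      return "--gpu-architecture=compute_" + cc;
    #endif
    }
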