Summary: Skip building the extensions on Windows, following https://github.com/pytorch/pytorch/pull/67161#issuecomment-958062611.

Related issue: https://github.com/pytorch/pytorch/issues/67073

cc ngimel xwang233 ptrblck

Pull Request resolved: https://github.com/pytorch/pytorch/pull/67735
Reviewed By: bdhirsh
Differential Revision: D32141250
Pulled By: ngimel
fbshipit-source-id: 9bfdb7cf694c99f6fc8cbe9033a12429b6e4b6fe
This commit is contained in:
parent 8b0c2c18eb
commit 7c739e1ab9
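The substance of the change: `_cublasGetErrorEnum` and `cusolverGetErrorMessage` gain a `C10_EXPORT` annotation so that separately compiled test extensions can resolve them, and two minimal AOT extensions are added to exercise exactly that (skipped on Windows per the linked comment). As a minimal sketch of why the annotation matters — assuming the usual platform-conditional definition, condensed here from what `c10/macros/Export.h` provides:

// Simplified sketch of a C10_EXPORT-style visibility macro (assumption:
// condensed from the real c10/macros/Export.h, which also handles the
// dllimport side of the pairing). libtorch is built with hidden symbol
// visibility by default on Linux/macOS, so a function must opt in
// explicitly to be resolvable from a separately compiled extension module.
#if defined(_WIN32)
#define EXAMPLE_EXPORT __declspec(dllexport)
#else
#define EXAMPLE_EXPORT __attribute__((visibility("default")))
#endif

// Without the annotation, dlopen()-ing an extension whose inlined check
// macros reference this helper fails with an undefined-symbol error.
EXAMPLE_EXPORT const char* exampleGetErrorEnum(int error);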
aten/src/ATen/cuda/CUDABlas.cpp

@@ -5,6 +5,7 @@
 #include <ATen/cuda/CUDABlas.h>
 #include <ATen/cuda/Exceptions.h>
 #include <c10/util/irange.h>
+#include <c10/macros/Export.h>
 
 #define CUDABLAS_POSINT_CHECK(FD, X) \
   TORCH_CHECK( \
@@ -96,7 +97,7 @@ namespace at {
 namespace cuda {
 namespace blas {
 
-const char* _cublasGetErrorEnum(cublasStatus_t error) {
+C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error) {
   if (error == CUBLAS_STATUS_SUCCESS) {
     return "CUBLAS_STATUS_SUCCESS";
   }
aten/src/ATen/cuda/CUDASolver.cpp

@@ -2,6 +2,7 @@
 #include <ATen/NativeFunctions.h>
 #include <ATen/cuda/CUDASolver.h>
 #include <c10/cuda/CUDACachingAllocator.h>
+#include <c10/macros/Export.h>
 
 #ifdef CUDART_VERSION
 
@@ -9,7 +10,7 @@ namespace at {
 namespace cuda {
 namespace solver {
 
-const char* cusolverGetErrorMessage(cusolverStatus_t status) {
+C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status) {
   switch (status) {
     case CUSOLVER_STATUS_SUCCESS: return "CUSOLVER_STATUS_SUCCES";
     case CUSOLVER_STATUS_NOT_INITIALIZED: return "CUSOLVER_STATUS_NOT_INITIALIZED";
aten/src/ATen/cuda/Exceptions.h

@@ -2,6 +2,7 @@
 
 #include <cublas_v2.h>
 #include <cusparse.h>
+#include <c10/macros/Export.h>
 
 #ifdef CUDART_VERSION
 #include <cusolver_common.h>
@@ -39,7 +40,7 @@ class CuDNNError : public c10::Error {
 } while (0)
 
 namespace at { namespace cuda { namespace blas {
-const char* _cublasGetErrorEnum(cublasStatus_t error);
+C10_EXPORT const char* _cublasGetErrorEnum(cublasStatus_t error);
 }}} // namespace at::cuda::blas
 
 #define TORCH_CUDABLAS_CHECK(EXPR) \
@@ -66,7 +67,7 @@ const char *cusparseGetErrorString(cusparseStatus_t status);
 #ifdef CUDART_VERSION
 
 namespace at { namespace cuda { namespace solver {
-const char* cusolverGetErrorMessage(cusolverStatus_t status);
+C10_EXPORT const char* cusolverGetErrorMessage(cusolverStatus_t status);
 }}} // namespace at::cuda::solver
 
 #define TORCH_CUSOLVER_CHECK(EXPR) \
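Why an extension references these symbols at all: TORCH_CUDABLAS_CHECK and TORCH_CUSOLVER_CHECK are macros, so they expand inside the extension's own translation unit, and the expanded code calls the error-string helper across the shared-library boundary. An illustrative simplification of the pattern (an assumption about its shape, not the exact macro body from this header):

// Illustrative simplification (assumption: the real TORCH_CUDABLAS_CHECK in
// ATen/cuda/Exceptions.h carries more context in its error message). The key
// point is the call to _cublasGetErrorEnum: it is compiled into the extension
// but must resolve against libtorch at load time, hence the C10_EXPORT above.
#define EXAMPLE_CUDABLAS_CHECK(EXPR)                 \
  do {                                               \
    cublasStatus_t __err = (EXPR);                   \
    TORCH_CHECK(                                     \
        __err == CUBLAS_STATUS_SUCCESS,              \
        "CUDA error: ",                              \
        at::cuda::blas::_cublasGetErrorEnum(__err)); \
  } while (0)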
test/cpp_extensions/cublas_extension.cpp (new file, 17 lines)

@@ -0,0 +1,17 @@
+#include <iostream>
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <cublas_v2.h>
+
+torch::Tensor noop_cublas_function(torch::Tensor x) {
+  cublasHandle_t handle;
+  TORCH_CUDABLAS_CHECK(cublasCreate(&handle));
+  TORCH_CUDABLAS_CHECK(cublasDestroy(handle));
+  return x;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("noop_cublas_function", &noop_cublas_function, "a cublas function");
+}
test/cpp_extensions/cusolver_extension.cpp (new file, 17 lines)

@@ -0,0 +1,17 @@
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <cusolverDn.h>
+
+
+torch::Tensor noop_cusolver_function(torch::Tensor x) {
+  cusolverDnHandle_t handle;
+  TORCH_CUSOLVER_CHECK(cusolverDnCreate(&handle));
+  TORCH_CUSOLVER_CHECK(cusolverDnDestroy(handle));
+  return x;
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("noop_cusolver_function", &noop_cusolver_function, "a cusolver function");
+}
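Both extensions are deliberately no-ops: creating and immediately destroying a handle is the smallest call sequence that pulls the check macros, and therefore references to the exported error-string helpers, into the compiled module, so merely loading the extension exercises the fix.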
test/cpp_extensions/setup.py

@@ -4,6 +4,7 @@ import os
 from setuptools import setup
 from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
 from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
+from torch.testing._internal.common_utils import IS_WINDOWS
 
 if sys.platform == 'win32':
     vc_version = os.getenv('VCToolsVersion', '')
@@ -48,6 +49,20 @@ if torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None
                             'nvcc': ['-O2']})
     ext_modules.append(extension)
 
+# todo(mkozuki): Figure out the root cause
+if (not IS_WINDOWS) and torch.cuda.is_available() and CUDA_HOME is not None:
+    cublas_extension = CUDAExtension(
+        name='torch_test_cpp_extension.cublas_extension',
+        sources=['cublas_extension.cpp']
+    )
+    ext_modules.append(cublas_extension)
+
+    cusolver_extension = CUDAExtension(
+        name='torch_test_cpp_extension.cusolver_extension',
+        sources=['cusolver_extension.cpp']
+    )
+    ext_modules.append(cusolver_extension)
+
 setup(
     name='torch_test_cpp_extension',
     packages=['torch_test_cpp_extension'],
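Note the design choice: rather than diagnosing why these extensions fail to build on Windows (the open question in the todo above), the guard simply skips registering them when IS_WINDOWS is true, and the matching test-side skips below keep the suite consistent.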
test/test_cpp_extensions_aot.py

@@ -82,6 +82,26 @@ class TestCppExtensionAOT(common.TestCase):
         # 2 * sigmoid(0) = 2 * 0.5 = 1
         self.assertEqual(z, torch.ones_like(z))
 
+    @common.skipIfRocm
+    @unittest.skipIf(common.IS_WINDOWS, "Windows not supported")
+    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
+    def test_cublas_extension(self):
+        from torch_test_cpp_extension import cublas_extension
+
+        x = torch.zeros(100, device="cuda", dtype=torch.float32)
+        z = cublas_extension.noop_cublas_function(x)
+        self.assertEqual(z, x)
+
+    @common.skipIfRocm
+    @unittest.skipIf(common.IS_WINDOWS, "Windows not supported")
+    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
+    def test_cusolver_extension(self):
+        from torch_test_cpp_extension import cusolver_extension
+
+        x = torch.zeros(100, device="cuda", dtype=torch.float32)
+        z = cusolver_extension.noop_cusolver_function(x)
+        self.assertEqual(z, x)
+
     @unittest.skipIf(IS_WINDOWS, "Not available on Windows")
     def test_no_python_abi_suffix_sets_the_correct_library_name(self):
         # For this test, run_test.py will call `python setup.py install` in the
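The skip decorators mirror the build-time guard in setup.py: on ROCm, on Windows, or without CUDA the extensions were never built, so the tests must not attempt the import.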