mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Enable Windows Arm64 (#133088)
This PR enables Pytorch for Windows on Arm64 - CPU only. Currently, there aren't any checks in place to build and test for Windows on Arm64, but we're working to implement those as soon as possible. We recommend using [Arm Performance Libraries (APL)](https://developer.arm.com/Tools%20and%20Software/Arm%20Performance%20Libraries) as a BLAS option, which is introduced in this PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133088 Approved by: https://github.com/malfet Co-authored-by: cristian panaite <panaite.cristian2000@gmail.com> Co-authored-by: Stefan-Alin Pahontu <56953855+alinpahontu2912@users.noreply.github.com> Co-authored-by: Ozan Aydin <148207261+ozanMSFT@users.noreply.github.com>
This commit is contained in:
parent
f7bb11dcc2
commit
b021486405
|
|
@ -428,7 +428,7 @@ if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "^(s390x|ppc64le)$")
|
|||
list(APPEND ATen_CPU_DEPENDENCY_LIBS cpuinfo)
|
||||
endif()
|
||||
|
||||
if(NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
|
||||
if(NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE AND NOT (MSVC AND CMAKE_SYSTEM_PROCESSOR STREQUAL "ARM64"))
|
||||
if(NOT MSVC)
|
||||
# Bump up optimization level for sleef to -O1, since at -O0 the compiler
|
||||
# excessively spills intermediate vector registers to the stack
|
||||
|
|
|
|||
|
|
@ -132,11 +132,46 @@ extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, int *inf
|
|||
extern "C" void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info);
|
||||
|
||||
// potrs
|
||||
#if defined(_WIN32) && defined(_M_ARM64)
|
||||
|
||||
// The functions zpotrs, cpotrs, dpotrs, and spotrs are not directly available in LAPACKE on Windows on ARM,
|
||||
// so we need to have wrapper functions to call them.
|
||||
// The issue on ARM platform can be found below:
|
||||
// https://community.arm.com/support-forums/f/high-performance-computing-forum/56512/unable-to-use-lapack---potrs-functions
|
||||
|
||||
#define LAPACK_COL_MAJOR 102
|
||||
#define LAPACK_ROW_MAJOR 101
|
||||
|
||||
extern "C" int LAPACKE_zpotrs(int matrix_layout, char uplo, int n, int nrhs, const std::complex<double> *a, int lda, std::complex<double> *b, int ldb);
|
||||
extern "C" int LAPACKE_cpotrs(int matrix_layout, char uplo, int n, int nrhs, const std::complex<float> *a, int lda, std::complex<float> *b, int ldb);
|
||||
extern "C" int LAPACKE_dpotrs(int matrix_layout, char uplo, int n, int nrhs, const double *a, int lda, double *b, int ldb);
|
||||
extern "C" int LAPACKE_spotrs(int matrix_layout, char uplo, int n, int nrhs, const float *a, int lda, float *b, int ldb);
|
||||
|
||||
static inline void zpotrs_(char *uplo, int *n, int *nrhs, std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb, int *info) {
|
||||
*info = LAPACKE_zpotrs(LAPACK_COL_MAJOR, *uplo, *n, *nrhs, a, *lda, b, *ldb);
|
||||
}
|
||||
|
||||
static inline void cpotrs_(char *uplo, int *n, int *nrhs, std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb, int *info) {
|
||||
*info = LAPACKE_cpotrs(LAPACK_COL_MAJOR, *uplo, *n, *nrhs, a, *lda, b, *ldb);
|
||||
}
|
||||
|
||||
static inline void dpotrs_(char *uplo, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info){
|
||||
*info = LAPACKE_dpotrs(LAPACK_COL_MAJOR, *uplo, *n, *nrhs, a, *lda, b, *ldb);
|
||||
}
|
||||
|
||||
static inline void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info) {
|
||||
*info = LAPACKE_spotrs(LAPACK_COL_MAJOR, *uplo, *n, *nrhs, a, *lda, b, *ldb);
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
extern "C" void zpotrs_(char *uplo, int *n, int *nrhs, std::complex<double> *a, int *lda, std::complex<double> *b, int *ldb, int *info);
|
||||
extern "C" void cpotrs_(char *uplo, int *n, int *nrhs, std::complex<float> *a, int *lda, std::complex<float> *b, int *ldb, int *info);
|
||||
extern "C" void dpotrs_(char *uplo, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info);
|
||||
extern "C" void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info);
|
||||
|
||||
#endif
|
||||
|
||||
// potrf
|
||||
extern "C" void zpotrf_(char *uplo, int *n, std::complex<double> *a, int *lda, int *info);
|
||||
extern "C" void cpotrf_(char *uplo, int *n, std::complex<float> *a, int *lda, int *info);
|
||||
|
|
@ -284,11 +319,39 @@ extern "C" void dorgqr_(int *m, int *n, int *k, double *a, int *lda, double *tau
|
|||
extern "C" void sorgqr_(int *m, int *n, int *k, float *a, int *lda, float *tau, float *work, int *lwork, int *info);
|
||||
|
||||
// ormqr
|
||||
#if defined(_WIN32) && defined(_M_ARM64)
|
||||
|
||||
// The functions zunmqr, cunmqr, dormqr, and sormqr are not directly available in LAPACKE on Windows on ARM,
|
||||
// so we need to have wrapper functions to call them.
|
||||
// The issue on ARM platform can be found below:
|
||||
// https://community.arm.com/support-forums/f/high-performance-computing-forum/56512/unable-to-use-lapack---potrs-functions
|
||||
|
||||
extern "C" int LAPACKE_zunmqr_work(int matrix_layout, char side, char trans, int m, int n, int k, const std::complex<double> *a, int lda, const std::complex<double> *tau, std::complex<double> *c, int ldc, std::complex<double> *work, int lwork);
|
||||
extern "C" int LAPACKE_cunmqr_work(int matrix_layout, char side, char trans, int m, int n, int k, const std::complex<float> *a, int lda, const std::complex<float> *tau, std::complex<float> *c, int ldc, std::complex<float> *work, int lwork);
|
||||
extern "C" int LAPACKE_dormqr_work(int matrix_layout, char side, char trans, int m, int n, int k, const double *a, int lda, const double *tau, double *c, int ldc, double *work, int lwork);
|
||||
extern "C" int LAPACKE_sormqr_work(int matrix_layout, char side, char trans, int m, int n, int k, const float *a, int lda, const float *tau, float *c, int ldc, float *work, int lwork);
|
||||
|
||||
static inline void zunmqr_(char *side, char *trans, int *m, int *n, int *k, std::complex<double> *a, int *lda, std::complex<double> *tau, std::complex<double> *c, int *ldc, std::complex<double> *work, int *lwork, int *info) {
|
||||
*info = LAPACKE_zunmqr_work(LAPACK_COL_MAJOR, *side, *trans, *m, *n, *k, a, *lda, tau, c, *ldc, work, *lwork);
|
||||
}
|
||||
|
||||
static inline void cunmqr_(char *side, char *trans, int *m, int *n, int *k, std::complex<float> *a, int *lda, std::complex<float> *tau, std::complex<float> *c, int *ldc, std::complex<float> *work, int *lwork, int *info) {
|
||||
*info = LAPACKE_cunmqr_work(LAPACK_COL_MAJOR, *side, *trans, *m, *n, *k, a, *lda, tau, c, *ldc, work, *lwork);
|
||||
}
|
||||
|
||||
static inline void dormqr_(char *side, char *trans, int *m, int *n, int *k, double *a, int *lda, double *tau, double *c, int *ldc, double *work, int *lwork, int *info) {
|
||||
*info = LAPACKE_dormqr_work(LAPACK_COL_MAJOR, *side, *trans, *m, *n, *k, a, *lda, tau, c, *ldc, work, *lwork);
|
||||
}
|
||||
|
||||
static inline void sormqr_(char *side, char *trans, int *m, int *n, int *k, float *a, int *lda, float *tau, float *c, int *ldc, float *work, int *lwork, int *info) {
|
||||
*info = LAPACKE_sormqr_work(LAPACK_COL_MAJOR, *side, *trans, *m, *n, *k, a, *lda, tau, c, *ldc, work, *lwork);
|
||||
}
|
||||
#else
|
||||
extern "C" void zunmqr_(char *side, char *trans, int *m, int *n, int *k, std::complex<double> *a, int *lda, std::complex<double> *tau, std::complex<double> *c, int *ldc, std::complex<double> *work, int *lwork, int *info);
|
||||
extern "C" void cunmqr_(char *side, char *trans, int *m, int *n, int *k, std::complex<float> *a, int *lda, std::complex<float> *tau, std::complex<float> *c, int *ldc, std::complex<float> *work, int *lwork, int *info);
|
||||
extern "C" void dormqr_(char *side, char *trans, int *m, int *n, int *k, double *a, int *lda, double *tau, double *c, int *ldc, double *work, int *lwork, int *info);
|
||||
extern "C" void sormqr_(char *side, char *trans, int *m, int *n, int *k, float *a, int *lda, float *tau, float *c, int *ldc, float *work, int *lwork, int *info);
|
||||
|
||||
#endif
|
||||
// syevd
|
||||
extern "C" void zheevd_(char *jobz, char *uplo, int *n, std::complex<double> *a, int *lda, double *w, std::complex<double> *work, int *lwork, double *rwork, int *lrwork, int *iwork, int *liwork, int *info);
|
||||
extern "C" void cheevd_(char *jobz, char *uplo, int *n, std::complex<float> *a, int *lda, float *w, std::complex<float> *work, int *lwork, float *rwork, int *lrwork, int *iwork, int *liwork, int *info);
|
||||
|
|
|
|||
|
|
@ -1719,7 +1719,10 @@ if(BUILD_TEST)
|
|||
endif()
|
||||
else()
|
||||
add_executable(${test_name}_${CPU_CAPABILITY} "${test_src}")
|
||||
target_link_libraries(${test_name}_${CPU_CAPABILITY} torch_library sleef gtest_main)
|
||||
target_link_libraries(${test_name}_${CPU_CAPABILITY} torch_library gtest_main)
|
||||
if(NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "ARM64")
|
||||
target_link_libraries(${test_name}_${CPU_CAPABILITY} sleef)
|
||||
endif()
|
||||
endif()
|
||||
target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE $<INSTALL_INTERFACE:include>)
|
||||
target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE $<BUILD_INTERFACE:${CMAKE_BINARY_DIR}/include>)
|
||||
|
|
|
|||
|
|
@ -161,7 +161,7 @@ else()
|
|||
set(AT_MKLDNN_ENABLED 0)
|
||||
set(AT_MKL_ENABLED 0)
|
||||
endif()
|
||||
set_property(CACHE BLAS PROPERTY STRINGS "ATLAS;BLIS;Eigen;FLAME;Generic;MKL;OpenBLAS;vecLib")
|
||||
set_property(CACHE BLAS PROPERTY STRINGS "ATLAS;BLIS;Eigen;FLAME;Generic;MKL;OpenBLAS;vecLib;APL")
|
||||
message(STATUS "Trying to find preferred BLAS backend of choice: " ${BLAS})
|
||||
|
||||
if(BLAS STREQUAL "Eigen")
|
||||
|
|
@ -226,6 +226,12 @@ elseif(BLAS STREQUAL "FlexiBLAS")
|
|||
find_package(FlexiBLAS REQUIRED)
|
||||
include_directories(SYSTEM ${FlexiBLAS_INCLUDE_DIR})
|
||||
list(APPEND Caffe2_DEPENDENCY_LIBS ${FlexiBLAS_LIB})
|
||||
elseif(BLAS STREQUAL "APL")
|
||||
find_package(APL REQUIRED)
|
||||
include_directories(SYSTEM ${APL_INCLUDE_DIR})
|
||||
set(BLAS_INFO "apl")
|
||||
set(BLAS_FOUND 1)
|
||||
set(BLAS_LIBRARIES ${APL_LIBRARIES})
|
||||
elseif(BLAS STREQUAL "Generic")
|
||||
# On Debian family, the CBLAS ABIs have been merged into libblas.so
|
||||
if(ENV{GENERIC_BLAS_LIBRARIES} STREQUAL "")
|
||||
|
|
@ -246,7 +252,7 @@ endif()
|
|||
if(NOT INTERN_BUILD_MOBILE)
|
||||
set(AT_MKL_SEQUENTIAL 0)
|
||||
set(USE_BLAS 1)
|
||||
if(NOT (ATLAS_FOUND OR BLIS_FOUND OR GENERIC_BLAS_FOUND OR MKL_FOUND OR OpenBLAS_FOUND OR VECLIB_FOUND OR FlexiBLAS_FOUND OR NVPL_BLAS_FOUND))
|
||||
if(NOT (ATLAS_FOUND OR BLIS_FOUND OR GENERIC_BLAS_FOUND OR MKL_FOUND OR OpenBLAS_FOUND OR VECLIB_FOUND OR FlexiBLAS_FOUND OR NVPL_BLAS_FOUND OR APL_FOUND))
|
||||
message(WARNING "Preferred BLAS (" ${BLAS} ") cannot be found, now searching for a general BLAS library")
|
||||
find_package(BLAS)
|
||||
if(NOT BLAS_FOUND)
|
||||
|
|
|
|||
58
cmake/Modules/FindAPL.cmake
Normal file
58
cmake/Modules/FindAPL.cmake
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
# - Find APL (Arm Performance Libraries)
|
||||
#
|
||||
# This module sets the following variables:
|
||||
# APL_INCLUDE_SEARCH_PATHS - list of paths to search for APL include files
|
||||
# APL_LIB_SEARCH_PATHS - list of paths to search for APL libraries
|
||||
# APL_FOUND - set to true if APL is found
|
||||
# APL_INCLUDE_DIR - path to include dir.
|
||||
# APL_LIB_DIR - path to include dir.
|
||||
# APL_LIBRARIES - list of libraries for base APL
|
||||
|
||||
SET(APL_INCLUDE_SEARCH_PATHS $ENV{ARMPL_DIR}/include)
|
||||
SET(APL_LIB_SEARCH_PATHS $ENV{ARMPL_DIR}/lib)
|
||||
|
||||
SET(APL_FOUND ON)
|
||||
|
||||
# Check include file
|
||||
FIND_PATH(APL_INCLUDE_DIR NAMES armpl.h PATHS ${APL_INCLUDE_SEARCH_PATHS})
|
||||
IF(NOT APL_INCLUDE_DIR)
|
||||
SET(APL_FOUND OFF)
|
||||
MESSAGE(STATUS "Could not verify APL include directory. Turning APL_FOUND off")
|
||||
ENDIF()
|
||||
|
||||
# Check lib file
|
||||
FIND_PATH(APL_LIB_DIR NAMES libarmpl_lp64_mp.dll.lib libomp.dll.lib libarmpl_lp64_mp.a PATHS ${APL_LIB_SEARCH_PATHS})
|
||||
IF(NOT APL_LIB_DIR)
|
||||
SET(APL_FOUND OFF)
|
||||
MESSAGE(STATUS "Could not verify APL lib directory. Turning APL_FOUND off")
|
||||
ENDIF()
|
||||
|
||||
IF (APL_FOUND)
|
||||
IF(WIN32)
|
||||
set(APL_LIBRARIES
|
||||
"${APL_LIB_DIR}/libarmpl_lp64_mp.dll.lib"
|
||||
"${APL_LIB_DIR}/libomp.dll.lib"
|
||||
)
|
||||
ELSEIF(UNIX)
|
||||
set(APL_LIBRARIES
|
||||
"${APL_LIB_DIR}/libarmpl_lp64_mp.a"
|
||||
)
|
||||
ENDIF()
|
||||
MESSAGE(STATUS "Found APL header: ${APL_INCLUDE_DIR}")
|
||||
MESSAGE(STATUS "Found APL library: ${APL_LIB_DIR}")
|
||||
message(STATUS "APL_LIBRARIES: ${APL_LIBRARIES}")
|
||||
SET(CMAKE_REQUIRED_LIBRARIES ${APL_LIBRARIES})
|
||||
include(CheckCSourceRuns)
|
||||
CHECK_C_SOURCE_RUNS("
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
float x[4] = { 1, 2, 3, 4 };
|
||||
float y[4] = { .1, .01, .001, .0001 };
|
||||
extern float cblas_sdot();
|
||||
int main() {
|
||||
int i;
|
||||
double r = cblas_sdot(4, x, 1, y, 1);
|
||||
exit((float)r != (float).1234);
|
||||
}" BLAS_USE_CBLAS_DOT )
|
||||
MESSAGE(STATUS "BLAS_USE_CBLAS_DOT: ${BLAS_USE_CBLAS_DOT}")
|
||||
ENDIF (APL_FOUND)
|
||||
|
|
@ -223,6 +223,34 @@ if(BLAS_FOUND)
|
|||
endif(LAPACK_LIBRARIES)
|
||||
endif()
|
||||
|
||||
#Arm Performance Libraries
|
||||
IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "apl"))
|
||||
SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES})
|
||||
check_function_exists("cheev_" APL_LAPACK_WORKS)
|
||||
if(APL_LAPACK_WORKS)
|
||||
check_function_exists("cgesdd_" LAPACK_CGESDD_WORKS)
|
||||
if(NOT LAPACK_CGESDD_WORKS)
|
||||
find_library(GFORTRAN_LIBRARY
|
||||
NAMES libgfortran.a gfortran
|
||||
PATHS ${CMAKE_C_IMPLICIT_LINK_DIRECTORIES})
|
||||
list(APPEND CMAKE_REQUIRED_LIBRARIES "${GFORTRAN_LIBRARY}")
|
||||
unset(LAPACK_CGESDD_WORKS CACHE)
|
||||
check_function_exists("cgesdd_" LAPACK_CGESDD_WORKS)
|
||||
if(LAPACK_CGESDD_WORKS)
|
||||
list(APPEND LAPACK_LIBRARIES "${GFORTRAN_LIBRARY}")
|
||||
else()
|
||||
message(WARNING "APL has been compiled with Lapack support, but cgesdd can not be used")
|
||||
set(APL_LAPACK_WORKS NO)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
set(CMAKE_REQUIRED_LIBRARIES)
|
||||
if(APL_LAPACK_WORKS)
|
||||
SET(LAPACK_INFO "apl")
|
||||
else()
|
||||
message(STATUS "It seems APL has not been compiled with Lapack support")
|
||||
endif()
|
||||
endif()
|
||||
else(BLAS_FOUND)
|
||||
message(STATUS "LAPACK requires BLAS")
|
||||
endif(BLAS_FOUND)
|
||||
|
|
|
|||
|
|
@ -779,7 +779,7 @@ def _get_torch_related_args(
|
|||
if not aot_mode:
|
||||
libraries.append("torch_python")
|
||||
|
||||
if _IS_WINDOWS:
|
||||
if _IS_WINDOWS and platform.machine().lower() != "arm64":
|
||||
libraries.append("sleef")
|
||||
|
||||
return include_dirs, libraries_dirs, libraries
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import glob
|
|||
import importlib
|
||||
import importlib.abc
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import shlex
|
||||
import shutil
|
||||
|
|
@ -994,7 +995,7 @@ def CppExtension(name, sources, *args, **kwargs):
|
|||
libraries.append('torch')
|
||||
libraries.append('torch_cpu')
|
||||
libraries.append('torch_python')
|
||||
if IS_WINDOWS:
|
||||
if IS_WINDOWS and platform.machine().lower() != "arm64":
|
||||
libraries.append("sleef")
|
||||
|
||||
kwargs['libraries'] = libraries
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user